Add session and statement state for all query types #2413

Closed
AsyncQueryHandler.java
@@ -8,43 +8,95 @@
 import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD;
+import static org.opensearch.sql.spark.execution.session.SessionId.newSessionId;
+import static org.opensearch.sql.spark.execution.session.SessionModel.initSession;
+import static org.opensearch.sql.spark.execution.statement.StatementId.newStatementId;
+import static org.opensearch.sql.spark.execution.statement.StatementModel.submitStatement;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatementModelByQueryId;
 
 import com.amazonaws.services.emrserverless.model.JobRunState;
+import java.util.Optional;
+import lombok.RequiredArgsConstructor;
 import org.json.JSONObject;
+import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
+import org.opensearch.sql.spark.execution.session.SessionId;
+import org.opensearch.sql.spark.execution.session.SessionType;
+import org.opensearch.sql.spark.execution.statement.Statement;
+import org.opensearch.sql.spark.execution.statement.StatementState;
+import org.opensearch.sql.spark.execution.statestore.StateStore;
+import org.opensearch.sql.spark.response.JobExecutionResponseReader;
 
 /** Process async query request. */
+@RequiredArgsConstructor
 public abstract class AsyncQueryHandler {
+  private final JobExecutionResponseReader jobExecutionResponseReader;
+  private final StateStore stateStore;
 
   public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) {
-    JSONObject result = getResponseFromResultIndex(asyncQueryJobMetadata);
-    if (result.has(DATA_FIELD)) {
-      JSONObject items = result.getJSONObject(DATA_FIELD);
-
-      // If items have STATUS_FIELD, use it; otherwise, mark failed
-      String status = items.optString(STATUS_FIELD, JobRunState.FAILED.toString());
-      result.put(STATUS_FIELD, status);
-
-      // If items have ERROR_FIELD, use it; otherwise, set empty string
-      String error = items.optString(ERROR_FIELD, "");
-      result.put(ERROR_FIELD, error);
-      return result;
-    } else {
-      return getResponseFromExecutor(asyncQueryJobMetadata);
-    }
-  }
-
-  protected abstract JSONObject getResponseFromResultIndex(
-      AsyncQueryJobMetadata asyncQueryJobMetadata);
-
-  protected abstract JSONObject getResponseFromExecutor(
-      AsyncQueryJobMetadata asyncQueryJobMetadata);
+    Statement statement = getStatementByQueryId(asyncQueryJobMetadata.getQueryId());
+    StatementState statementState = statement.getStatementState();
+    if (statementState == StatementState.SUCCESS || statementState == StatementState.FAILED) {
+      String queryId = asyncQueryJobMetadata.getQueryId().getId();
+      JSONObject result =
+          jobExecutionResponseReader.getResultWithQueryId(
+              queryId, asyncQueryJobMetadata.getResultIndex());
+      if (result.has(DATA_FIELD)) {
+        JSONObject items = result.getJSONObject(DATA_FIELD);
+
+        // If items have STATUS_FIELD, use it; otherwise, mark failed
+        String status = items.optString(STATUS_FIELD, JobRunState.FAILED.toString());
+        result.put(STATUS_FIELD, status);
+
+        // If items have ERROR_FIELD, use it; otherwise, set empty string
+        String error = items.optString(ERROR_FIELD, "");
+        result.put(ERROR_FIELD, error);
+        return result;
+      }
+    }
+    JSONObject result = new JSONObject();
+    result.put(STATUS_FIELD, statementState.getState());
+    result.put(ERROR_FIELD, Optional.of(statement.getStatementModel().getError()).orElse(""));
Member:
Is the error written to the result index, or to the statement model in the request index?
Are these different cases?

Collaborator (Author):
Both. The Flint Spark job:

  • writes the result, state, and error to the result index;
  • writes the state and error to the request index.

@kaituo please help confirm as well.

Contributor:
Yes.

+    return result;
+  }
 
-  public abstract String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata);
+  public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
+    getStatementByQueryId(asyncQueryJobMetadata.getQueryId()).cancel();
+    return asyncQueryJobMetadata.getQueryId().getId();
+  }
 
   public abstract DispatchQueryResponse submit(
       DispatchQueryRequest request, DispatchQueryContext context);
+
+  protected Statement getStatementByQueryId(AsyncQueryId queryId) {
+    return new Statement(stateStore, getStatementModelByQueryId(stateStore, queryId));
+  }
+
+  // todo, refactor code after extract StartJobRequest logic in configuration file.
+  public void createSessionAndStatement(
+      DispatchQueryRequest dispatchQueryRequest,
+      String appId,
+      String jobId,
+      SessionType sessionType,
+      String datasourceName,
+      AsyncQueryId queryId) {
+    String qid = queryId.getId();
+    SessionId sessionId = newSessionId(datasourceName);
+    StateStore.createSession(stateStore, datasourceName)
+        .apply(initSession(appId, jobId, sessionId, sessionType, datasourceName));
+    StateStore.createStatement(stateStore, datasourceName)
+        .apply(
+            submitStatement(
+                sessionId,
+                appId,
+                jobId,
+                newStatementId(qid),
+                dispatchQueryRequest.getLangType(),
+                datasourceName,
+                dispatchQueryRequest.getQuery(),
+                queryId.getId()));
+  }
 }
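
To make the two write paths from the review thread concrete: query status is now resolved from the request index (the statement state) first, and the result index is only consulted once the statement is terminal. Below is a minimal, self-contained Java sketch of that fallback order; the types and the readResultIndex helper are simplified stand-ins for StatementModel and JobExecutionResponseReader.getResultWithQueryId, not this PR's code.

import java.util.Map;
import java.util.Optional;

public class QueryResponseSketch {
  // Stand-ins for StatementState / StatementModel stored in the request index.
  enum State { WAITING, RUNNING, SUCCESS, FAILED }

  record StatementModel(State state, String error) {}

  // Stand-in for reading the result index: empty until the Flint Spark job
  // has written the result document for this query.
  static Optional<Map<String, Object>> readResultIndex(String queryId) {
    return Optional.empty();
  }

  static Map<String, Object> getQueryResponse(String queryId, StatementModel stmt) {
    // Only consult the result index once the statement reached a terminal state.
    if (stmt.state() == State.SUCCESS || stmt.state() == State.FAILED) {
      Optional<Map<String, Object>> result = readResultIndex(queryId);
      if (result.isPresent()) {
        return result.get(); // data plus status/error written by the Flint Spark job
      }
    }
    // Otherwise serve state and error from the request index.
    return Map.of(
        "status", stmt.state().toString(),
        "error", Optional.ofNullable(stmt.error()).orElse(""));
  }

  public static void main(String[] args) {
    // While the job runs, the response comes from the statement model alone
    // (map printing order is unspecified).
    System.out.println(getQueryResponse("q1", new StatementModel(State.RUNNING, null)));
  }
}

In the real handler the result document's own status and error fields stay authoritative when present, which is why getQueryResponse copies STATUS_FIELD and ERROR_FIELD out of the data payload before returning.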
BatchQueryHandler.java
@@ -5,61 +5,41 @@
 
 package org.opensearch.sql.spark.dispatcher;
 
-import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD;
-import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD;
 import static org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher.JOB_TYPE_TAG_KEY;
 
-import com.amazonaws.services.emrserverless.model.GetJobRunResult;
 import java.util.Map;
-import lombok.RequiredArgsConstructor;
-import org.json.JSONObject;
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
 import org.opensearch.sql.legacy.metrics.MetricName;
 import org.opensearch.sql.legacy.utils.MetricUtils;
-import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters;
 import org.opensearch.sql.spark.client.EMRServerlessClient;
 import org.opensearch.sql.spark.client.StartJobRequest;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
 import org.opensearch.sql.spark.dispatcher.model.JobType;
+import org.opensearch.sql.spark.execution.session.SessionType;
+import org.opensearch.sql.spark.execution.statestore.StateStore;
 import org.opensearch.sql.spark.leasemanager.LeaseManager;
 import org.opensearch.sql.spark.leasemanager.model.LeaseRequest;
 import org.opensearch.sql.spark.response.JobExecutionResponseReader;
 
-@RequiredArgsConstructor
 public class BatchQueryHandler extends AsyncQueryHandler {
   private final EMRServerlessClient emrServerlessClient;
   private final JobExecutionResponseReader jobExecutionResponseReader;
+  private final StateStore stateStore;
   protected final LeaseManager leaseManager;
 
-  @Override
-  protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) {
-    // either empty json when the result is not available or data with status
-    // Fetch from Result Index
-    return jobExecutionResponseReader.getResultFromOpensearchIndex(
-        asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
-  }
-
-  @Override
-  protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) {
-    JSONObject result = new JSONObject();
-    // make call to EMR Serverless when related result index documents are not available
-    GetJobRunResult getJobRunResult =
-        emrServerlessClient.getJobRunResult(
-            asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId());
-    String jobState = getJobRunResult.getJobRun().getState();
-    result.put(STATUS_FIELD, jobState);
-    result.put(ERROR_FIELD, "");
-    return result;
-  }
-
-  @Override
-  public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
-    emrServerlessClient.cancelJobRun(
-        asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId());
-    return asyncQueryJobMetadata.getQueryId().getId();
-  }
+  public BatchQueryHandler(
+      EMRServerlessClient emrServerlessClient,
+      StateStore stateStore,
+      JobExecutionResponseReader jobExecutionResponseReader,
+      LeaseManager leaseManager) {
+    super(jobExecutionResponseReader, stateStore);
+    this.emrServerlessClient = emrServerlessClient;
+    this.jobExecutionResponseReader = jobExecutionResponseReader;
+    this.stateStore = stateStore;
+    this.leaseManager = leaseManager;
+  }

   @Override
@@ -87,6 +67,13 @@ public DispatchQueryResponse submit(
         false,
         dataSourceMetadata.getResultIndex());
     String jobId = emrServerlessClient.startJobRun(startJobRequest);
+    createSessionAndStatement(
+        dispatchQueryRequest,
+        dispatchQueryRequest.getApplicationId(),
+        jobId,
+        SessionType.BATCH,
+        dataSourceMetadata.getName(),
+        context.getQueryId());
     MetricUtils.incrementNumericalMetric(MetricName.EMR_BATCH_QUERY_JOBS_CREATION_COUNT);
     return new DispatchQueryResponse(
         context.getQueryId(), jobId, dataSourceMetadata.getResultIndex(), null);
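
With this change a batch submission records a session and a statement alongside the EMR Serverless job, so status and cancellation flow through the same statement lifecycle as interactive queries; note that cancelJob is no longer overridden here and now cancels the statement rather than calling emrServerlessClient.cancelJobRun. A small Java sketch of that shared lifecycle follows; the state names mirror StatementState, and the rule that terminal statements reject cancel() is an assumption about Statement.cancel(), not code from this PR.

public class StatementLifecycleSketch {
  // States mirroring StatementState; CANCELLED is assumed to be the terminal
  // state produced by Statement.cancel().
  enum State { WAITING, RUNNING, SUCCESS, FAILED, CANCELLED }

  static State cancel(State current) {
    // Assumption: statements already in a terminal state cannot be cancelled.
    if (current == State.SUCCESS || current == State.FAILED || current == State.CANCELLED) {
      throw new IllegalStateException("can't cancel statement in state " + current);
    }
    return State.CANCELLED;
  }

  public static void main(String[] args) {
    System.out.println(cancel(State.WAITING)); // CANCELLED
  }
}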
IndexDMLHandler.java
@@ -6,12 +6,12 @@
 package org.opensearch.sql.spark.dispatcher;
 
 import static org.opensearch.sql.spark.execution.statestore.StateStore.createIndexDMLResult;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatementModelByQueryId;
 
 import com.amazonaws.services.emrserverless.model.JobRunState;
-import lombok.RequiredArgsConstructor;
+import java.util.Locale;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.json.JSONObject;
 import org.opensearch.client.Client;
 import org.opensearch.sql.datasource.DataSourceService;
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
@@ -24,6 +24,8 @@
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse;
 import org.opensearch.sql.spark.dispatcher.model.IndexDMLResult;
 import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails;
+import org.opensearch.sql.spark.execution.session.SessionType;
+import org.opensearch.sql.spark.execution.statement.StatementState;
 import org.opensearch.sql.spark.execution.statestore.StateStore;
 import org.opensearch.sql.spark.flint.FlintIndexMetadata;
 import org.opensearch.sql.spark.flint.FlintIndexMetadataReader;
@@ -33,7 +35,6 @@
 import org.opensearch.sql.spark.response.JobExecutionResponseReader;
 
 /** Handle Index DML query. includes * DROP * ALT? */
-@RequiredArgsConstructor
 public class IndexDMLHandler extends AsyncQueryHandler {
   private static final Logger LOG = LogManager.getLogger();
 
@@ -53,6 +54,24 @@ public class IndexDMLHandler extends AsyncQueryHandler {
 
   private final StateStore stateStore;
 
+  public IndexDMLHandler(
+      EMRServerlessClient emrServerlessClient,
+      DataSourceService dataSourceService,
+      DataSourceUserAuthorizationHelperImpl dataSourceUserAuthorizationHelper,
+      JobExecutionResponseReader jobExecutionResponseReader,
+      FlintIndexMetadataReader flintIndexMetadataReader,
+      Client client,
+      StateStore stateStore) {
+    super(jobExecutionResponseReader, stateStore);
+    this.emrServerlessClient = emrServerlessClient;
+    this.dataSourceService = dataSourceService;
+    this.dataSourceUserAuthorizationHelper = dataSourceUserAuthorizationHelper;
+    this.jobExecutionResponseReader = jobExecutionResponseReader;
+    this.flintIndexMetadataReader = flintIndexMetadataReader;
+    this.client = client;
+    this.stateStore = stateStore;
+  }
+
   public static boolean isIndexDMLQuery(String jobId) {
     return DROP_INDEX_JOB_ID.equalsIgnoreCase(jobId);
   }
@@ -93,22 +112,20 @@ public DispatchQueryResponse submit(
         System.currentTimeMillis());
     String resultIndex = dataSourceMetadata.getResultIndex();
     createIndexDMLResult(stateStore, resultIndex).apply(indexDMLResult);
-
+    createSessionAndStatement(
+        dispatchQueryRequest,
+        dispatchQueryRequest.getApplicationId(),
+        DROP_INDEX_JOB_ID,
Contributor:
Does that mean there is only one job id for all DML queries? Would it cause issues if there are concurrent DML queries running?

+        SessionType.BATCH,
+        dataSourceMetadata.getName(),
+        asyncQueryId);
+    StateStore.updateStatementState(stateStore, asyncQueryId.getDataSourceName())
+        .apply(
+            getStatementModelByQueryId(stateStore, asyncQueryId),
+            StatementState.fromString(status.toLowerCase(Locale.ROOT)));
     return new DispatchQueryResponse(asyncQueryId, DROP_INDEX_JOB_ID, resultIndex, null);
   }
 
-  @Override
-  protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) {
-    String queryId = asyncQueryJobMetadata.getQueryId().getId();
-    return jobExecutionResponseReader.getResultWithQueryId(
-        queryId, asyncQueryJobMetadata.getResultIndex());
-  }
-
-  @Override
-  protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) {
-    throw new IllegalStateException("[BUG] can't fetch result of index DML query form server");
-  }
-
   @Override
   public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
     throw new IllegalArgumentException("can't cancel index DML query");
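
Index DML queries (DROP) execute synchronously inside the plugin, so submit() writes the result document and immediately moves the statement to its terminal state; the shared getQueryResponse path then resolves on the first poll. A minimal Java sketch of that double write, with plain maps standing in for the two OpenSearch indices (an illustration, not the PR's code):

import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

public class IndexDmlSubmitSketch {
  enum State { SUCCESS, FAILED }

  public static void main(String[] args) {
    // Stand-ins for the result index and the request index (both OpenSearch
    // indices in the real plugin; plain maps here).
    Map<String, Map<String, Object>> resultIndex = new HashMap<>();
    Map<String, State> requestIndex = new HashMap<>();

    String queryId = "q-drop-1"; // hypothetical query id
    String status = "SUCCESS";   // outcome of the synchronous DROP

    // 1. Persist the result document (createIndexDMLResult in the PR).
    resultIndex.put(queryId, Map.of("status", status, "error", ""));
    // 2. Move the statement straight to its terminal state
    //    (updateStatementState + StatementState.fromString in the PR).
    requestIndex.put(queryId, State.valueOf(status.toUpperCase(Locale.ROOT)));

    // getQueryResponse() now sees a terminal statement and reads the result.
    System.out.println(requestIndex.get(queryId) + " -> " + resultIndex.get(queryId));
  }
}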
InteractiveQueryHandler.java
@@ -5,19 +5,14 @@
 
 package org.opensearch.sql.spark.dispatcher;
 
-import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME;
-import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD;
 import static org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher.JOB_TYPE_TAG_KEY;
 
 import java.util.Map;
 import java.util.Optional;
-import lombok.RequiredArgsConstructor;
-import org.json.JSONObject;
 import org.opensearch.sql.datasource.model.DataSourceMetadata;
 import org.opensearch.sql.legacy.metrics.MetricName;
 import org.opensearch.sql.legacy.utils.MetricUtils;
-import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext;
 import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest;
@@ -28,42 +23,27 @@
 import org.opensearch.sql.spark.execution.session.SessionId;
 import org.opensearch.sql.spark.execution.session.SessionManager;
 import org.opensearch.sql.spark.execution.statement.QueryRequest;
-import org.opensearch.sql.spark.execution.statement.Statement;
-import org.opensearch.sql.spark.execution.statement.StatementId;
-import org.opensearch.sql.spark.execution.statement.StatementState;
 import org.opensearch.sql.spark.execution.statestore.StateStore;
 import org.opensearch.sql.spark.leasemanager.LeaseManager;
 import org.opensearch.sql.spark.leasemanager.model.LeaseRequest;
 import org.opensearch.sql.spark.response.JobExecutionResponseReader;
 
-@RequiredArgsConstructor
 public class InteractiveQueryHandler extends AsyncQueryHandler {
   private final SessionManager sessionManager;
+  private final StateStore stateStore;
   private final JobExecutionResponseReader jobExecutionResponseReader;
   private final LeaseManager leaseManager;
 
-  @Override
-  protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) {
-    String queryId = asyncQueryJobMetadata.getQueryId().getId();
-    return jobExecutionResponseReader.getResultWithQueryId(
-        queryId, asyncQueryJobMetadata.getResultIndex());
-  }
-
-  @Override
-  protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJobMetadata) {
-    JSONObject result = new JSONObject();
-    String queryId = asyncQueryJobMetadata.getQueryId().getId();
-    Statement statement = getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId);
-    StatementState statementState = statement.getStatementState();
-    result.put(STATUS_FIELD, statementState.getState());
-    result.put(ERROR_FIELD, Optional.of(statement.getStatementModel().getError()).orElse(""));
-    return result;
-  }
-
-  @Override
-  public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) {
-    String queryId = asyncQueryJobMetadata.getQueryId().getId();
-    getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId).cancel();
-    return queryId;
-  }
+  public InteractiveQueryHandler(
+      SessionManager sessionManager,
+      StateStore stateStore,
+      JobExecutionResponseReader jobExecutionResponseReader,
+      LeaseManager leaseManager) {
+    super(jobExecutionResponseReader, stateStore);
+    this.sessionManager = sessionManager;
+    this.stateStore = stateStore;
+    this.jobExecutionResponseReader = jobExecutionResponseReader;
+    this.leaseManager = leaseManager;
+  }
 
   @Override
@@ -115,21 +95,4 @@ public DispatchQueryResponse submit(
         dataSourceMetadata.getResultIndex(),
         session.getSessionId().getSessionId());
   }
-
-  private Statement getStatementByQueryId(String sid, String qid) {
-    SessionId sessionId = new SessionId(sid);
-    Optional<Session> session = sessionManager.getSession(sessionId);
-    if (session.isPresent()) {
-      // todo, statementId == jobId if statement running in session.
-      StatementId statementId = new StatementId(qid);
-      Optional<Statement> statement = session.get().get(statementId);
-      if (statement.isPresent()) {
-        return statement.get();
-      } else {
-        throw new IllegalArgumentException("no statement found. " + statementId);
-      }
-    } else {
-      throw new IllegalArgumentException("no session found. " + sessionId);
-    }
-  }
 }
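
The private session-scoped lookup removed above is superseded by the base class's getStatementByQueryId(AsyncQueryId), which resolves a statement by query id alone. A short Java sketch of that simplification, assuming one statement per query id in the request index (stand-in types, not the plugin's API):

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

public class StatementLookupSketch {
  record Statement(String queryId, String sessionId) {}

  // Request-index documents keyed by query id (assumption: one statement per query).
  static final Map<String, Statement> requestIndex = new HashMap<>();

  // Old shape: resolve the session first, then the statement within it.
  // New shape (base class): a single lookup by query id.
  static Statement getStatementByQueryId(String queryId) {
    return Optional.ofNullable(requestIndex.get(queryId))
        .orElseThrow(() -> new IllegalArgumentException("no statement found. " + queryId));
  }

  public static void main(String[] args) {
    requestIndex.put("q1", new Statement("q1", "s1"));
    System.out.println(getStatementByQueryId("q1"));
  }
}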