Skip to content

Commit

Permalink
create new session if current session not ready (#2363)
Browse files Browse the repository at this point in the history
Signed-off-by: Peng Huo <[email protected]>
  • Loading branch information
penghuo authored Oct 25, 2023
1 parent 886c2fc commit a5512f5
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ
Map<String, String> tags = getDefaultTagsForJobSubmission(dispatchQueryRequest);

if (sessionManager.isEnabled()) {
Session session;
Session session = null;

if (dispatchQueryRequest.getSessionId() != null) {
// get session from request
SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId());
Expand All @@ -222,8 +223,9 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ
throw new IllegalArgumentException("no session found. " + sessionId);
}
session = createdSession.get();
} else {
// create session if not exist
}
if (session == null || !session.isReady()) {
// create session if not exist or session dead/fail
tags.put(JOB_TYPE_TAG_KEY, JobType.INTERACTIVE.getText());
session =
sessionManager.createSession(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package org.opensearch.sql.spark.execution.session;

import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession;
import static org.opensearch.sql.spark.execution.session.SessionState.DEAD;
import static org.opensearch.sql.spark.execution.session.SessionState.END_STATE;
import static org.opensearch.sql.spark.execution.session.SessionState.FAIL;
import static org.opensearch.sql.spark.execution.statement.StatementId.newStatementId;
import static org.opensearch.sql.spark.execution.statestore.StateStore.createSession;
import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession;
Expand Down Expand Up @@ -130,4 +132,9 @@ public Optional<Statement> get(StatementId stID) {
.statementModel(model)
.build());
}

@Override
public boolean isReady() {
return sessionModel.getSessionState() != DEAD && sessionModel.getSessionState() != FAIL;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,7 @@ public interface Session {
SessionModel getSessionModel();

SessionId getSessionId();

/** return true if session is ready to use. */
boolean isReady();
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.opensearch.sql.spark.client.StartJobRequest;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
import org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher;
import org.opensearch.sql.spark.execution.session.SessionId;
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.execution.session.SessionModel;
import org.opensearch.sql.spark.execution.session.SessionState;
Expand Down Expand Up @@ -390,6 +391,7 @@ public void withSessionCreateAsyncQueryFailed() {
assertEquals("mock error", asyncQueryResults.getError());
}

// https://github.com/opensearch-project/sql/issues/2344
@Test
public void createSessionMoreThanLimitFailed() {
LocalEMRSClient emrsClient = new LocalEMRSClient();
Expand Down Expand Up @@ -419,6 +421,65 @@ public void createSessionMoreThanLimitFailed() {
"The maximum number of active sessions can be supported is 1", exception.getMessage());
}

// https://github.com/opensearch-project/sql/issues/2360
@Test
public void recreateSessionIfNotReady() {
LocalEMRSClient emrsClient = new LocalEMRSClient();
AsyncQueryExecutorService asyncQueryExecutorService =
createAsyncQueryExecutorService(emrsClient);

// enable session
enableSession(true);

// 1. create async query.
CreateAsyncQueryResponse first =
asyncQueryExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));
assertNotNull(first.getSessionId());

// set sessionState to FAIL
setSessionState(first.getSessionId(), SessionState.FAIL);

// 2. reuse session id
CreateAsyncQueryResponse second =
asyncQueryExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest(
"select 1", DATASOURCE, LangType.SQL, first.getSessionId()));

assertNotEquals(first.getSessionId(), second.getSessionId());

// set sessionState to FAIL
setSessionState(second.getSessionId(), SessionState.DEAD);

// 3. reuse session id
CreateAsyncQueryResponse third =
asyncQueryExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest(
"select 1", DATASOURCE, LangType.SQL, second.getSessionId()));
assertNotEquals(second.getSessionId(), third.getSessionId());
}

@Test
public void submitQueryInInvalidSessionThrowException() {
LocalEMRSClient emrsClient = new LocalEMRSClient();
AsyncQueryExecutorService asyncQueryExecutorService =
createAsyncQueryExecutorService(emrsClient);

// enable session
enableSession(true);

// 1. create async query.
SessionId sessionId = SessionId.newSessionId(DATASOURCE);
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
asyncQueryExecutorService.createAsyncQuery(
new CreateAsyncQueryRequest(
"select 1", DATASOURCE, LangType.SQL, sessionId.getSessionId())));
assertEquals("no session found. " + sessionId, exception.getMessage());
}

private DataSourceServiceImpl createDataSourceService() {
String masterKey = "a57d991d9b573f75b9bba1df";
DataSourceMetadataStorage dataSourceMetadataStorage =
Expand Down Expand Up @@ -536,6 +597,6 @@ void setSessionState(String sessionId, SessionState sessionState) {
Optional<SessionModel> model = getSession(stateStore, DATASOURCE).apply(sessionId);
SessionModel updated =
updateSessionState(stateStore, DATASOURCE).apply(model.get(), sessionState);
assertEquals(SessionState.RUNNING, updated.getSessionState());
assertEquals(sessionState, updated.getSessionState());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ void testDispatchSelectQueryReuseSession() {
doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId();
doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any());
when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);
when(session.isReady()).thenReturn(true);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.getRawDataSourceMetadata("my_glue")).thenReturn(dataSourceMetadata);
doNothing().when(dataSourceUserAuthorizationHelper).authorizeDataSource(dataSourceMetadata);
Expand Down

0 comments on commit a5512f5

Please sign in to comment.