diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index 5aa82432bb..18cf8b3dfb 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -81,7 +81,8 @@ public DispatchQueryResponse submit( if (dispatchQueryRequest.getSessionId() != null) { // get session from request SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId()); - Optional createdSession = sessionManager.getSession(sessionId); + Optional createdSession = + sessionManager.getSession(sessionId, dispatchQueryRequest.getDatasource()); if (createdSession.isPresent()) { session = createdSession.get(); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index 0f0a4ce373..e0989e30c8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -35,9 +35,27 @@ public Session createSession(CreateSessionRequest request) { return session; } - public Optional getSession(SessionId sid) { + /** + * Retrieves the session associated with the given session ID. + * + *

<p>Use this overload when the data source encoded in the session ID cannot be trusted. The + * session is looked up under the explicitly supplied {@code dataSourceName} rather than the + * data source name embedded in {@code sid}, which may be outdated or not match the current + * request. + * + *

For more context on the use case and implementation, refer to the documentation here: + * https://tinyurl.com/bdh6s834 + * + * @param sid The unique identifier of the session. It is used to fetch the corresponding session + * details. + * @param dataSourceName The name of the data source. This parameter is utilized in the session + * retrieval process. + * @return An Optional containing the session associated with the provided session ID. Returns an + * empty Optional if no matching session is found. + */ + public Optional getSession(SessionId sid, String dataSourceName) { Optional model = - StateStore.getSession(stateStore, sid.getDataSourceName()).apply(sid.getSessionId()); + StateStore.getSession(stateStore, dataSourceName).apply(sid.getSessionId()); if (model.isPresent()) { InteractiveSession session = InteractiveSession.builder() @@ -51,6 +69,22 @@ public Optional getSession(SessionId sid) { return Optional.empty(); } + /** + * Retrieves the session associated with the provided session ID. + * + *

This method is utilized specifically in scenarios where the data source information encoded + * in the session ID is considered trustworthy. It ensures the retrieval of session details based + * on the session ID, relying on the integrity of the data source information contained within it. + * + * @param sid The session ID used to identify and retrieve the corresponding session. It is + * expected to contain valid and trusted data source information. + * @return An Optional containing the session associated with the provided session ID. If no + * session is found that matches the session ID, an empty Optional is returned. + */ + public Optional getSession(SessionId sid) { + return getSession(sid, sid.getDataSourceName()); + } + // todo, keep it only for testing, will remove it later. public boolean isEnabled() { return true; diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 56ee56ea5e..0207480048 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -383,6 +383,44 @@ public void recreateSessionIfNotReady() { assertNotEquals(second.getSessionId(), third.getSessionId()); } + @Test + public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + // 1. create async query. 
+ CreateAsyncQueryResponse first = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "SHOW SCHEMAS IN " + DATASOURCE, DATASOURCE, LangType.SQL, null)); + assertNotNull(first.getSessionId()); + + // set sessionState to RUNNING + setSessionState(first.getSessionId(), SessionState.RUNNING); + + // 2. reuse session id + CreateAsyncQueryResponse second = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "SHOW SCHEMAS IN " + DATASOURCE, DATASOURCE, LangType.SQL, first.getSessionId())); + + assertEquals(first.getSessionId(), second.getSessionId()); + + // set sessionState to RUNNING + setSessionState(second.getSessionId(), SessionState.RUNNING); + + // 3. given different source, create a new session id + CreateAsyncQueryResponse third = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "SHOW SCHEMAS IN " + DSOTHER, DSOTHER, LangType.SQL, second.getSessionId())); + assertNotEquals(second.getSessionId(), third.getSessionId()); + } + @Test public void submitQueryInInvalidSessionWillCreateNewSession() { LocalEMRSClient emrsClient = new LocalEMRSClient(); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 2a76eabe6a..d8ec782e9b 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -314,7 +314,7 @@ void testDispatchSelectQueryReuseSession() { doReturn(true).when(sessionManager).isEnabled(); doReturn(Optional.of(session)) .when(sessionManager) - .getSession(eq(new SessionId(MOCK_SESSION_ID))); + .getSession(eq(new SessionId(MOCK_SESSION_ID)), eq("my_glue")); doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId(); doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any()); 
when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);