diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index 5aa82432bb..8b40c2e2f8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -81,7 +81,7 @@ public DispatchQueryResponse submit( if (dispatchQueryRequest.getSessionId() != null) { // get session from request SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId()); - Optional createdSession = sessionManager.getSession(sessionId); + Optional createdSession = sessionManager.getSession(sessionId, dispatchQueryRequest.getDatasource()); if (createdSession.isPresent()) { session = createdSession.get(); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index 0f0a4ce373..5bf71b8482 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -35,9 +35,21 @@ public Session createSession(CreateSessionRequest request) { return session; } - public Optional getSession(SessionId sid) { + /** + * Retrieves the session associated with the given session ID. + * + * This method is particularly used in scenarios where the data source encoded in the session ID is deemed untrustworthy. + * It allows for the safe retrieval of session details based on a known and validated session ID, rather than relying on potentially outdated data source information. + * + * For more context on the use case and implementation, refer to the documentation here: https://tinyurl.com/bdh6s834 + * + * @param sid The unique identifier of the session. It is used to fetch the corresponding session details. + * @param dataSourceName The name of the data source. This parameter is utilized in the session retrieval process. + * @return An Optional containing the session associated with the provided session ID. Returns an empty Optional if no matching session is found. + */ + public Optional getSession(SessionId sid, String dataSourceName) { Optional model = - StateStore.getSession(stateStore, sid.getDataSourceName()).apply(sid.getSessionId()); + StateStore.getSession(stateStore, dataSourceName).apply(sid.getSessionId()); if (model.isPresent()) { InteractiveSession session = InteractiveSession.builder() @@ -51,6 +63,18 @@ public Optional getSession(SessionId sid) { return Optional.empty(); } + /** + * Retrieves the session associated with the provided session ID. + * + * This method is utilized specifically in scenarios where the data source information encoded in the session ID is considered trustworthy. It ensures the retrieval of session details based on the session ID, relying on the integrity of the data source information contained within it. + * + * @param sid The session ID used to identify and retrieve the corresponding session. It is expected to contain valid and trusted data source information. + * @return An Optional containing the session associated with the provided session ID. If no session is found that matches the session ID, an empty Optional is returned. + */ + public Optional getSession(SessionId sid) { + return getSession(sid, sid.getDataSourceName()); + } + // todo, keep it only for testing, will remove it later. public boolean isEnabled() { return true; diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 56ee56ea5e..bd851b2e56 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -383,6 +383,43 @@ public void recreateSessionIfNotReady() { assertNotEquals(second.getSessionId(), third.getSessionId()); } + @Test + public void submitQueryWithDifferentDataSourceSessionWillCreateNewSession() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // enable session + enableSession(true); + + // 1. create async query. + CreateAsyncQueryResponse first = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("SHOW SCHEMAS IN " + DATASOURCE, DATASOURCE, LangType.SQL, null)); + assertNotNull(first.getSessionId()); + + // set sessionState to RUNNING + setSessionState(first.getSessionId(), SessionState.RUNNING); + + // 2. reuse session id + CreateAsyncQueryResponse second = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "SHOW SCHEMAS IN " + DATASOURCE, DATASOURCE, LangType.SQL, first.getSessionId())); + + assertEquals(first.getSessionId(), second.getSessionId()); + + // set sessionState to RUNNING + setSessionState(second.getSessionId(), SessionState.RUNNING); + + // 3. given different source, create a new session id + CreateAsyncQueryResponse third = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest( + "SHOW SCHEMAS IN " + DSOTHER, DSOTHER, LangType.SQL, second.getSessionId())); + assertNotEquals(second.getSessionId(), third.getSessionId()); + } + @Test public void submitQueryInInvalidSessionWillCreateNewSession() { LocalEMRSClient emrsClient = new LocalEMRSClient(); diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 2a76eabe6a..d8ec782e9b 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -314,7 +314,7 @@ void testDispatchSelectQueryReuseSession() { doReturn(true).when(sessionManager).isEnabled(); doReturn(Optional.of(session)) .when(sessionManager) - .getSession(eq(new SessionId(MOCK_SESSION_ID))); + .getSession(eq(new SessionId(MOCK_SESSION_ID)), eq("my_glue")); doReturn(new SessionId(MOCK_SESSION_ID)).when(session).getSessionId(); doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any()); when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);