diff --git a/google/cloud/spark_connect/session.py b/google/cloud/spark_connect/session.py index eadddbf..3971918 100644 --- a/google/cloud/spark_connect/session.py +++ b/google/cloud/spark_connect/session.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import atexit import json import logging import os @@ -213,6 +214,16 @@ def __create(self) -> "SparkSession": logger.info( "Creating Spark session. It may take few minutes." ) + atexit.register( + atexit.register( + lambda: ServerlessSessionHelper.terminate_s8s_session( + self._project_id, + self._region, + session_id, + self._client_options, + ) + ) + ) operation = SessionControllerClient( client_options=self._client_options ).create_session(session_request) @@ -473,42 +484,12 @@ def _remove_stoped_session_from_file(self): def stop(self) -> None: with GoogleSparkSession._lock: if GoogleSparkSession._active_s8s_session_id is not None: - from google.cloud.dataproc_v1 import SessionControllerClient - - logger.debug( - f"Terminating serverless session: {GoogleSparkSession._active_s8s_session_id}" + ServerlessSessionHelper.terminate_s8s_session( + GoogleSparkSession._project_id, + GoogleSparkSession._region, + GoogleSparkSession._active_s8s_session_id, + self._client_options, ) - terminate_session_request = TerminateSessionRequest() - session_name = f"projects/{GoogleSparkSession._project_id}/locations/{GoogleSparkSession._region}/sessions/{GoogleSparkSession._active_s8s_session_id}" - terminate_session_request.name = session_name - state = None - try: - SessionControllerClient( - client_options=self._client_options - ).terminate_session(terminate_session_request) - get_session_request = GetSessionRequest() - get_session_request.name = session_name - state = Session.State.ACTIVE - while ( - state != Session.State.TERMINATING - and state != Session.State.TERMINATED - and state != Session.State.FAILED - ): - session = SessionControllerClient( - client_options=self._client_options - ).get_session(get_session_request) - state = session.state - sleep(1) - except NotFound: - logger.debug( - f"Session {GoogleSparkSession._active_s8s_session_id} already deleted" - ) - except FailedPrecondition: - logger.debug( - f"Session {GoogleSparkSession._active_s8s_session_id} already terminated manually or terminated automatically through session ttl limits" - ) - if state is not None and state == Session.State.FAILED: - raise RuntimeError("Serverless session termination failed") self._remove_stoped_session_from_file() GoogleSparkSession._active_s8s_session_uuid = None @@ -524,3 +505,43 @@ def stop(self) -> None: GoogleSparkSession._active_session, "session", None ): GoogleSparkSession._active_session.session = None + + +class ServerlessSessionHelper: + + @staticmethod + def terminate_s8s_session( + project_id, region, active_s8s_session_id, client_options=None + ): + from google.cloud.dataproc_v1 import SessionControllerClient + + logger.debug(f"Terminating serverless session: {active_s8s_session_id}") + terminate_session_request = TerminateSessionRequest() + session_name = f"projects/{project_id}/locations/{region}/sessions/{active_s8s_session_id}" + terminate_session_request.name = session_name + state = None + try: + SessionControllerClient( + client_options=client_options + ).terminate_session(terminate_session_request) + get_session_request = GetSessionRequest() + get_session_request.name = session_name + state = Session.State.ACTIVE + while ( + state != Session.State.TERMINATING + and state != Session.State.TERMINATED + and state != Session.State.FAILED + ): + session = SessionControllerClient( + client_options=client_options + ).get_session(get_session_request) + state = session.state + sleep(1) + except NotFound: + logger.debug(f"Session {active_s8s_session_id} already deleted") + except FailedPrecondition: + logger.debug( + f"Session {active_s8s_session_id} already terminated manually or terminated automatically through session ttl limits" + ) + if state is not None and state == Session.State.FAILED: + raise RuntimeError("Serverless session termination failed")