Skip to content

Commit

Permalink
fix: terminate s8s session on kernel termination
Browse files Browse the repository at this point in the history
  • Loading branch information
isha97 committed Jan 22, 2025
1 parent 61e7b47 commit 8883caa
Showing 1 changed file with 56 additions and 35 deletions.
91 changes: 56 additions & 35 deletions google/cloud/spark_connect/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import json
import logging
import os
Expand Down Expand Up @@ -213,6 +214,16 @@ def __create(self) -> "SparkSession":
logger.info(
"Creating Spark session. It may take few minutes."
)
# Register the cleanup callback exactly once. atexit.register() returns the
# callable it was given, so nesting it inside a second atexit.register()
# call registers the same callback twice and makes the interpreter issue a
# redundant terminate request for the session at exit.
atexit.register(
    lambda: ServerlessSessionHelper.terminate_s8s_session(
        self._project_id,
        self._region,
        session_id,
        self._client_options,
    )
)
operation = SessionControllerClient(
client_options=self._client_options
).create_session(session_request)
Expand Down Expand Up @@ -473,42 +484,12 @@ def _remove_stoped_session_from_file(self):
def stop(self) -> None:
with GoogleSparkSession._lock:
if GoogleSparkSession._active_s8s_session_id is not None:
from google.cloud.dataproc_v1 import SessionControllerClient

logger.debug(
f"Terminating serverless session: {GoogleSparkSession._active_s8s_session_id}"
ServerlessSessionHelper.terminate_s8s_session(
GoogleSparkSession._project_id,
GoogleSparkSession._region,
GoogleSparkSession._active_s8s_session_id,
self._client_options,
)
terminate_session_request = TerminateSessionRequest()
session_name = f"projects/{GoogleSparkSession._project_id}/locations/{GoogleSparkSession._region}/sessions/{GoogleSparkSession._active_s8s_session_id}"
terminate_session_request.name = session_name
state = None
try:
SessionControllerClient(
client_options=self._client_options
).terminate_session(terminate_session_request)
get_session_request = GetSessionRequest()
get_session_request.name = session_name
state = Session.State.ACTIVE
while (
state != Session.State.TERMINATING
and state != Session.State.TERMINATED
and state != Session.State.FAILED
):
session = SessionControllerClient(
client_options=self._client_options
).get_session(get_session_request)
state = session.state
sleep(1)
except NotFound:
logger.debug(
f"Session {GoogleSparkSession._active_s8s_session_id} already deleted"
)
except FailedPrecondition:
logger.debug(
f"Session {GoogleSparkSession._active_s8s_session_id} already terminated manually or terminated automatically through session ttl limits"
)
if state is not None and state == Session.State.FAILED:
raise RuntimeError("Serverless session termination failed")

self._remove_stoped_session_from_file()
GoogleSparkSession._active_s8s_session_uuid = None
Expand All @@ -524,3 +505,43 @@ def stop(self) -> None:
GoogleSparkSession._active_session, "session", None
):
GoogleSparkSession._active_session.session = None


class ServerlessSessionHelper:
    """Namespace for helpers that operate on serverless (s8s) sessions."""

    @staticmethod
    def terminate_s8s_session(
        project_id, region, active_s8s_session_id, client_options=None
    ):
        """Request termination of a serverless session and wait until it takes effect.

        Sends a terminate request for the session named
        ``projects/{project_id}/locations/{region}/sessions/{active_s8s_session_id}``
        and then polls the session once per second until its state is
        ``TERMINATING``, ``TERMINATED`` or ``FAILED``. Note that the wait ends
        as soon as the session is ``TERMINATING``; it does not wait for
        ``TERMINATED``.

        Args:
            project_id: Project that owns the session.
            region: Location of the session.
            active_s8s_session_id: ID of the session to terminate.
            client_options: Options forwarded to ``SessionControllerClient``.

        Raises:
            RuntimeError: If the session is observed in the ``FAILED`` state.

        ``NotFound`` and ``FailedPrecondition`` errors raised by the service
        calls are logged at debug level and otherwise ignored.
        """
        from google.cloud.dataproc_v1 import SessionControllerClient

        logger.debug(f"Terminating serverless session: {active_s8s_session_id}")
        session_name = f"projects/{project_id}/locations/{region}/sessions/{active_s8s_session_id}"

        terminate_session_request = TerminateSessionRequest()
        terminate_session_request.name = session_name
        get_session_request = GetSessionRequest()
        get_session_request.name = session_name

        # States at which polling stops.
        final_states = (
            Session.State.TERMINATING,
            Session.State.TERMINATED,
            Session.State.FAILED,
        )

        state = None
        try:
            # Build the client once and reuse it for the terminate call and
            # every poll, instead of constructing a new client (and transport)
            # on each loop iteration.
            client = SessionControllerClient(client_options=client_options)
            client.terminate_session(terminate_session_request)
            while True:
                state = client.get_session(get_session_request).state
                if state in final_states:
                    break
                # Only sleep when another poll is actually needed, so the
                # caller is not delayed after a final state has been seen.
                sleep(1)
        except NotFound:
            logger.debug(f"Session {active_s8s_session_id} already deleted")
        except FailedPrecondition:
            logger.debug(
                f"Session {active_s8s_session_id} already terminated manually or terminated automatically through session ttl limits"
            )

        if state == Session.State.FAILED:
            raise RuntimeError("Serverless session termination failed")

0 comments on commit 8883caa

Please sign in to comment.