Commit fcea5c0

Format with pyink

Deependra Patel committed Feb 5, 2025
1 parent b876f00 commit fcea5c0
Showing 2 changed files with 23 additions and 10 deletions.
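
pyink is Google's fork of the Black code formatter, so the diff below is a purely mechanical formatting pass: long statements are re-wrapped to fit the line-length limit and string quotes are normalized to double quotes, with no behavior change.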
google/cloud/spark_connect/session.py (23 changes: 15 additions & 8 deletions)
@@ -488,6 +488,7 @@ def _remove_stoped_session_from_file(self):
            logger.error(
                f"Exception while removing active session in file {file_path} , {e}"
            )
+
    """
    Install PyPi packages (with their dependencies) in the active Spark
    session on the driver and executors.
@@ -505,24 +505,30 @@ def _remove_stoped_session_from_file(self):
    This is an API available only in Google Spark Session as of today.
    If there are conflicts/package doesn't exist, it throws an exception.
    """
+
    def addArtifact(self, package: str) -> None:
        if package in self._installed_pypi_libs:
            logger.info("Ignoring as artifact has already been added earlier")
            return

        if package.startswith("pypi://") is False:
-            raise ValueError("Only PyPi packages are supported in format `pypi://spacy`")
+            raise ValueError(
+                "Only PyPi packages are supported in format `pypi://spacy`"
+            )

-        dependencies = {
-            "Version": "1.0",
-            "packages": [package]
-        }
+        dependencies = {"Version": "1.0", "packages": [package]}

        # Can't use the same file as Spark throws exception that file already exists
-        file_path = tempfile.tempdir + \
-            "/.deps-" + self._active_s8s_session_uuid + "-" + package.removeprefix("pypi://") + ".json"
+        file_path = (
+            tempfile.tempdir
+            + "/.deps-"
+            + self._active_s8s_session_uuid
+            + "-"
+            + package.removeprefix("pypi://")
+            + ".json"
+        )

-        with open(file_path, 'w') as json_file:
+        with open(file_path, "w") as json_file:
            json.dump(dependencies, json_file, indent=4)
        self.addArtifacts(file_path, file=True)
        self._installed_pypi_libs.add(package)
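
For context, the addArtifact method being reformatted above is the entry point for installing PyPI packages into a running session. Below is a minimal usage sketch (not part of this commit); the only assumption beyond the diff is that the session is obtained through the usual SparkSession-style builder exposed by GoogleSparkSession.

from google.cloud.spark_connect import GoogleSparkSession

# Assumed entry point: GoogleSparkSession mirrors the standard SparkSession
# builder API.
spark = GoogleSparkSession.builder.getOrCreate()

# Only the pypi:// scheme is accepted; the package and its dependencies are
# installed on the driver and executors of the active session.
spark.addArtifact("pypi://spacy")

# Re-adding the same package is a no-op: it is tracked in _installed_pypi_libs
# and skipped with an info log.
spark.addArtifact("pypi://spacy")

# Anything that does not start with pypi:// raises ValueError.
spark.addArtifact("spacy")  # raises ValueError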
tests/unit/test_session.py (10 changes: 8 additions & 2 deletions)
@@ -145,6 +145,7 @@ def test_create_spark_session_with_default_notebook_behavior(
        mock_session_controller_client_instance.get_session.assert_called_once_with(
            get_session_request
        )
+
    @mock.patch("google.cloud.dataproc_v1.SessionControllerClient")
    def test_custom_add_artifact(
        self,
@@ -168,13 +168,18 @@ def test_custom_add_artifact(
        self.assertTrue(isinstance(session, GoogleSparkSession))

        # Invalid input format throws Error
-        with self.assertRaises(ValueError, msg="Only PyPi packages are supported in format `pypi://spacy`"):
+        with self.assertRaises(
+            ValueError,
+            msg="Only PyPi packages are supported in format `pypi://spacy`",
+        ):
            session.addArtifact("spacy")

        # Happy case, also validating content of the file
        session.addArtifacts = mock.MagicMock()
        session.addArtifact("pypi://spacy")
-        file_name = tempfile.tempdir + '/.deps-' + session_response.uuid + '-spacy.json'
+        file_name = (
+            tempfile.tempdir + "/.deps-" + session_response.uuid + "-spacy.json"
+        )
        session.addArtifacts.assert_called_once_with(file_name, file=True)
        expected_file_content = {"Version": "1.0", "packages": ["pypi://spacy"]}
        self.assertEqual(json.load(open(file_name)), expected_file_content)
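
To make the test's expectation concrete, here is a small standalone sketch (not library code) of the manifest that addArtifact writes and then uploads via addArtifacts. The session UUID is made up, and tempfile.gettempdir() stands in for the tempfile.tempdir attribute that session.py concatenates directly.

import json
import tempfile

# Hypothetical values for illustration only.
session_uuid = "00000000-fake-uuid"
package = "pypi://spacy"

# Same manifest shape the test asserts on.
dependencies = {"Version": "1.0", "packages": [package]}

# Mirrors the file naming in session.py: .deps-<session uuid>-<package>.json
# inside the temp directory.
file_path = (
    tempfile.gettempdir()
    + "/.deps-"
    + session_uuid
    + "-"
    + package.removeprefix("pypi://")
    + ".json"
)

with open(file_path, "w") as json_file:
    json.dump(dependencies, json_file, indent=4)

print(file_path)  # e.g. /tmp/.deps-00000000-fake-uuid-spacy.json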
