Commit

Restored formatting changes
stefannica committed Jan 11, 2025
1 parent 3efcfda commit 603a38d
Showing 7 changed files with 32 additions and 57 deletions.
8 changes: 3 additions & 5 deletions examples/quickstart/pipelines/training.py
@@ -47,11 +47,9 @@ def english_translation_pipeline(
     tokenized_dataset, tokenizer = tokenize_data(
         dataset=full_dataset, model_type=model_type
     )
-    (
-        tokenized_train_dataset,
-        tokenized_eval_dataset,
-        tokenized_test_dataset,
-    ) = split_dataset(tokenized_dataset)
+    tokenized_train_dataset, tokenized_eval_dataset, tokenized_test_dataset = (
+        split_dataset(tokenized_dataset)
+    )
     model = train_model(
         tokenized_dataset=tokenized_train_dataset,
         model_type=model_type,
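For orientation, the two assignment layouts in this hunk are equivalent; below is a minimal standalone sketch of both, where split_dataset and the data are hypothetical stand-ins rather than the quickstart step:

def split_dataset(dataset: list) -> tuple[list, list, list]:
    # Hypothetical stand-in that splits a dataset into three equal parts.
    third = len(dataset) // 3
    return dataset[:third], dataset[third : 2 * third], dataset[2 * third :]


data = list(range(9))

# Formatting removed by this commit: targets wrapped in parentheses.
(
    train_split,
    eval_split,
    test_split,
) = split_dataset(data)

# Formatting restored by this commit: targets on one line, only the call wrapped.
train_split, eval_split, test_split = (
    split_dataset(data)
)

assert (train_split, eval_split, test_split) == ([0, 1, 2], [3, 4, 5], [6, 7, 8])

The two forms behave identically; only the placement of the parentheses differs.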
@@ -632,11 +632,9 @@ def _compute_orchestrator_url(
             the URL to the dashboard view in SageMaker.
         """
         try:
-            (
-                region_name,
-                pipeline_name,
-                execution_id,
-            ) = dissect_pipeline_execution_arn(pipeline_execution.arn)
+            region_name, pipeline_name, execution_id = (
+                dissect_pipeline_execution_arn(pipeline_execution.arn)
+            )
 
             # Get the Sagemaker session
             session = pipeline_execution.sagemaker_session
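As background for the hunk above, here is a rough, self-contained sketch of what dissecting a SageMaker pipeline execution ARN into region, pipeline name and execution id can look like. The helper name, regex and example ARN are assumptions made for illustration, not ZenML's dissect_pipeline_execution_arn implementation:

import re
from typing import Optional, Tuple


def dissect_arn_sketch(arn: str) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Hypothetical helper: pull region, pipeline name and execution id out of an ARN."""
    match = re.match(
        r"arn:aws:sagemaker:(?P<region>[^:]+):\d+:"
        r"pipeline/(?P<pipeline>[^/]+)/execution/(?P<execution>.+)",
        arn,
    )
    if not match:
        return None, None, None
    return match.group("region"), match.group("pipeline"), match.group("execution")


region_name, pipeline_name, execution_id = (
    dissect_arn_sketch(
        "arn:aws:sagemaker:eu-west-1:123456789012:pipeline/training/execution/abc123"
    )
)
assert region_name == "eu-west-1"
assert pipeline_name == "training"
assert execution_id == "abc123"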
@@ -23,7 +23,7 @@
     cast,
 )
 
-from pydantic import BaseModel, field_validator
+from pydantic import field_validator, BaseModel
 
 from zenml.config.base_settings import BaseSettings
 from zenml.experiment_trackers.base_experiment_tracker import (
@@ -69,8 +69,8 @@ def _convert_settings(cls, value: Any) -> Any:
         import wandb
 
         if isinstance(value, wandb.Settings):
-            # Depending on the wandb version, either `model_dump`,
-            # `make_static` or `to_dict` is available to convert the settings
+            # Depending on the wandb version, either `model_dump`,
+            # `make_static` or `to_dict` is available to convert the settings
             # to a dictionary
             if isinstance(value, BaseModel):
                 return value.model_dump()
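The second hunk above touches only the formatting of a comment, but for context, here is a minimal sketch of the surrounding pattern: a pydantic field_validator running in "before" mode that converts a settings object into a plain dict via model_dump. FakeRunSettings and TrackerSettings are invented for the example; the real code checks wandb.Settings and falls back to make_static or to_dict on older wandb versions:

from typing import Any, Dict, Optional

from pydantic import BaseModel, field_validator


class FakeRunSettings(BaseModel):
    # Stand-in for an SDK settings object such as wandb.Settings.
    project: str = "demo"
    tags: list = []


class TrackerSettings(BaseModel):
    settings: Optional[Dict[str, Any]] = None

    @field_validator("settings", mode="before")
    @classmethod
    def _convert_settings(cls, value: Any) -> Any:
        # Pydantic models expose `model_dump`; older SDK objects may only
        # offer `make_static` or `to_dict`, hence the version checks upstream.
        if isinstance(value, BaseModel):
            return value.model_dump()
        return value


tracker_settings = TrackerSettings(settings=FakeRunSettings())
assert tracker_settings.settings == {"project": "demo", "tags": []}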
7 changes: 3 additions & 4 deletions src/zenml/zen_server/template_execution/utils.py
@@ -149,10 +149,9 @@ def run_template(
     )
 
     def _task() -> None:
-        (
-            pypi_requirements,
-            apt_packages,
-        ) = requirements_utils.get_requirements_for_stack(stack=stack)
+        pypi_requirements, apt_packages = (
+            requirements_utils.get_requirements_for_stack(stack=stack)
+        )
 
         if build.python_version:
             version_info = version.parse(build.python_version)
@@ -23,31 +23,25 @@ def upgrade() -> None:
         batch_op.add_column(sa.Column("save_type", sa.TEXT(), nullable=True))
 
     # Step 2: Move data from step_run_output_artifact.type to artifact_version.save_type
-    op.execute(
-        """
+    op.execute("""
         UPDATE artifact_version
         SET save_type = (
             SELECT max(step_run_output_artifact.type)
             FROM step_run_output_artifact
             WHERE step_run_output_artifact.artifact_id = artifact_version.id
             GROUP BY artifact_id
         )
-        """
-    )
-    op.execute(
-        """
+    """)
+    op.execute("""
         UPDATE artifact_version
         SET save_type = 'step_output'
         WHERE artifact_version.save_type = 'default'
-        """
-    )
-    op.execute(
-        """
+    """)
+    op.execute("""
         UPDATE artifact_version
         SET save_type = 'external'
         WHERE save_type is NULL
-        """
-    )
+    """)
 
     # # Step 3: Set save_type to non-nullable
     with op.batch_alter_table("artifact_version", schema=None) as batch_op:
@@ -75,24 +69,20 @@ def downgrade() -> None:
     )
 
     # Move data back from artifact_version.save_type to step_run_output_artifact.type
-    op.execute(
-        """
+    op.execute("""
         UPDATE step_run_output_artifact
         SET type = (
             SELECT max(artifact_version.save_type)
             FROM artifact_version
             WHERE step_run_output_artifact.artifact_id = artifact_version.id
             GROUP BY artifact_id
         )
-        """
-    )
-    op.execute(
-        """
+    """)
+    op.execute("""
         UPDATE step_run_output_artifact
         SET type = 'default'
         WHERE step_run_output_artifact.type = 'step_output'
-        """
-    )
+    """)
 
     # Set type to non-nullable
     with op.batch_alter_table(
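The two migration hunks above change only how the triple-quoted SQL is wrapped; the statements themselves are untouched. For reference, a minimal sketch of an Alembic data migration written in the restored style; the revision ids, table and column names are invented for the example:

"""Example data migration (hypothetical revision)."""

from alembic import op

# Invented identifiers for the sketch.
revision = "aaaa00000000"
down_revision = "bbbb11111111"


def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision."""
    op.execute("""
        UPDATE example_table
        SET status = 'active'
        WHERE status IS NULL
    """)


def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision."""
    op.execute("""
        UPDATE example_table
        SET status = NULL
        WHERE status = 'active'
    """)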
@@ -17,21 +17,17 @@
 
 def upgrade() -> None:
     """Upgrade database schema and/or data, creating a new revision."""
-    op.execute(
-        """
+    op.execute("""
         UPDATE step_run_input_artifact
         SET type = 'step_output'
         WHERE type = 'default'
-        """
-    )
+    """)
 
 
 def downgrade() -> None:
     """Downgrade database schema and/or data back to the previous revision."""
-    op.execute(
-        """
+    op.execute("""
         UPDATE step_run_input_artifact
         SET type = 'default'
         WHERE type = 'step_output'
-        """
-    )
+    """)
@@ -41,12 +41,10 @@ def upgrade() -> None:
     connection = op.get_bind()
 
     run_metadata_data = connection.execute(
-        sa.text(
-            """
+        sa.text("""
             SELECT id, resource_id, resource_type
             FROM run_metadata
-            """
-        )
+        """)
     ).fetchall()
 
     # Prepare data with new UUIDs for bulk insert
@@ -109,24 +107,20 @@ def downgrade() -> None:
 
     # Fetch data from `run_metadata_resource`
     run_metadata_resource_data = connection.execute(
-        sa.text(
-            """
+        sa.text("""
             SELECT resource_id, resource_type, run_metadata_id
             FROM run_metadata_resource
-            """
-        )
+        """)
     ).fetchall()
 
     # Update `run_metadata` with the data from `run_metadata_resource`
     for row in run_metadata_resource_data:
         connection.execute(
-            sa.text(
-                """
+            sa.text("""
                 UPDATE run_metadata
                 SET resource_id = :resource_id, resource_type = :resource_type
                 WHERE id = :run_metadata_id
-                """
-            ),
+            """),
             {
                 "resource_id": row.resource_id,
                 "resource_type": row.resource_type,
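To round out the last two hunks, a small self-contained sketch of the sa.text(...) pattern with bound parameters, run against an in-memory SQLite database; the table and column names are invented and do not reflect ZenML's schema:

import sqlalchemy as sa

engine = sa.create_engine("sqlite:///:memory:")

with engine.begin() as connection:
    connection.execute(sa.text("CREATE TABLE run_info (id TEXT, label TEXT)"))
    connection.execute(
        sa.text("INSERT INTO run_info (id, label) VALUES ('run-1', NULL)")
    )

    # Same shape as the migration above: textual SQL plus a dict of bound parameters.
    connection.execute(
        sa.text("""
            UPDATE run_info
            SET label = :label
            WHERE id = :id
        """),
        {"label": "external", "id": "run-1"},
    )

    rows = connection.execute(sa.text("SELECT id, label FROM run_info")).fetchall()
    assert [tuple(row) for row in rows] == [("run-1", "external")]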
