diff --git a/docs/book/component-guide/orchestrators/azureml.md b/docs/book/component-guide/orchestrators/azureml.md index e47b4d8e9f2..0cce7d75b0a 100644 --- a/docs/book/component-guide/orchestrators/azureml.md +++ b/docs/book/component-guide/orchestrators/azureml.md @@ -195,10 +195,10 @@ def example_step() -> int: @pipeline(settings={"orchestrator": azureml_settings}) -def pipeline(): +def my_pipeline(): example_step() -pipeline() +my_pipeline() ``` {% hint style="info" %} @@ -213,10 +213,18 @@ its [JobSchedules](https://learn.microsoft.com/en-us/azure/machine-learning/how- Both cron expression and intervals are supported. ```python +from zenml import pipeline from zenml.config.schedule import Schedule +@pipeline +def my_pipeline(): + ... + # Run a pipeline every 5th minute -pipeline.run(schedule=Schedule(cron_expression="*/5 * * * *")) +my_pipeline = my_pipeline.with_options( + schedule=Schedule(cron_expression="*/5 * * * *") +) +my_pipeline() ``` Once you run the pipeline with a schedule, you can find the schedule and diff --git a/docs/book/component-guide/orchestrators/sagemaker.md b/docs/book/component-guide/orchestrators/sagemaker.md index 64643339347..f7178ca0e8b 100644 --- a/docs/book/component-guide/orchestrators/sagemaker.md +++ b/docs/book/component-guide/orchestrators/sagemaker.md @@ -22,7 +22,7 @@ You should use the Sagemaker orchestrator if: ## How it works -The ZenML Sagemaker orchestrator works with [Sagemaker Pipelines](https://aws.amazon.com/sagemaker/pipelines), which can be used to construct machine learning pipelines. Under the hood, for each ZenML pipeline step, it creates a SageMaker `PipelineStep`, which contains a Sagemaker Processing job. Currently, other step types are not supported. +The ZenML Sagemaker orchestrator works with [Sagemaker Pipelines](https://aws.amazon.com/sagemaker/pipelines), which can be used to construct machine learning pipelines. 
Under the hood, for each ZenML pipeline step, it creates a SageMaker `PipelineStep`, which contains a Sagemaker Processing or Training job. ## How to deploy it @@ -54,12 +54,13 @@ zenml integration install aws s3 * A [remote container registry](../container-registries/container-registries.md) as part of your stack. * An IAM role or user with [an `AmazonSageMakerFullAccess` managed policy](https://docs.aws.amazon.com/sagemaker/latest/dg/security-iam-awsmanpol.html) applied to it as well as `sagemaker.amazonaws.com` added as a Principal Service. Full details on these permissions can be found [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) or use the ZenML recipe (when available) which will set up the necessary permissions for you. * The local client (whoever is running the pipeline) will also have to have the necessary permissions or roles to be able to launch Sagemaker jobs. (This would be covered by the `AmazonSageMakerFullAccess` policy suggested above.) +* If you want to use schedules, you also need to set up the correct roles, permissions and policies covered [here](#required-iam-permissions-for-schedules). There are three ways you can authenticate your orchestrator and link it to the IAM role you have created: {% tabs %} {% tab title="Authentication via Service Connector" %} -The recommended way to authenticate your SageMaker orchestrator is by registering an [AWS Service Connector](../../how-to/infrastructure-deployment/auth-management/aws-service-connector.md) and connecting it to your SageMaker orchestrator: +The recommended way to authenticate your SageMaker orchestrator is by registering an [AWS Service Connector](../../how-to/infrastructure-deployment/auth-management/aws-service-connector.md) and connecting it to your SageMaker orchestrator. 
If you plan to use scheduled pipelines, ensure the credentials used by the service connector have the necessary EventBridge and IAM permissions listed in the [Required IAM Permissions for schedules](#required-iam-permissions-for-schedules) section: ```shell zenml service-connector register --type aws -i @@ -72,7 +73,7 @@ zenml stack register -o ... --set {% endtab %} {% tab title="Explicit Authentication" %} -Instead of creating a service connector, you can also configure your AWS authentication credentials directly in the orchestrator: +Instead of creating a service connector, you can also configure your AWS authentication credentials directly in the orchestrator. If you plan to use scheduled pipelines, ensure these credentials have the necessary EventBridge and IAM permissions listed in the [Required IAM Permissions for schedules](#required-iam-permissions-for-schedules) section: ```shell zenml orchestrator register \ @@ -88,7 +89,7 @@ See the [`SagemakerOrchestratorConfig` SDK Docs](https://sdkdocs.zenml.io/latest {% endtab %} {% tab title="Implicit Authentication" %} -If you neither connect your orchestrator to a service connector nor configure credentials explicitly, ZenML will try to implicitly authenticate to AWS via the `default` profile in your local [AWS configuration file](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html). +If you neither connect your orchestrator to a service connector nor configure credentials explicitly, ZenML will try to implicitly authenticate to AWS via the `default` profile in your local [AWS configuration file](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html). 
If you plan to use scheduled pipelines, ensure this profile has the necessary EventBridge and IAM permissions listed in the [Required IAM Permissions for schedules](#required-iam-permissions-for-schedules) section: ```shell zenml orchestrator register \ @@ -114,7 +115,9 @@ If all went well, you should now see the following output: ``` Steps can take 5-15 minutes to start running when using the Sagemaker Orchestrator. -Your orchestrator 'sagemaker' is running remotely. Note that the pipeline run will only show up on the ZenML dashboard once the first step has started executing on the remote infrastructure. +Your orchestrator 'sagemaker' is running remotely. Note that the pipeline run +will only show up on the ZenML dashboard once the first step has started +executing on the remote infrastructure. ``` {% hint style="warning" %} @@ -153,10 +156,6 @@ Alternatively, for a more detailed view of log messages during SageMaker pipelin ![SageMaker CloudWatch Logs](../../.gitbook/assets/sagemaker-cloudwatch-logs.png) -### Run pipelines on a schedule - -The ZenML Sagemaker orchestrator doesn't currently support running pipelines on a schedule. We maintain a public roadmap for ZenML, which you can find [here](https://zenml.io/roadmap). We welcome community contributions (see more [here](https://github.com/zenml-io/zenml/blob/main/CONTRIBUTING.md)) so if you want to enable scheduling for Sagemaker, please [do let us know](https://zenml.io/slack)! - ### Configuration at pipeline or step level When running your ZenML pipeline with the Sagemaker orchestrator, the configuration set when configuring the orchestrator as a ZenML component will be used by default. However, it is possible to provide additional configuration at the pipeline or step level. This allows you to run whole pipelines or individual steps with alternative configurations. For example, this allows you to run the training process with a heavier, GPU-enabled instance type, while running other steps with lighter instances. 
@@ -170,19 +169,23 @@ Additional configuration for the Sagemaker orchestrator can be passed via `Sagem * `base_job_name` * `env` -For example, settings can be provided in the following way: +For example, settings can be provided and applied in the following way: ```python +from zenml import step +from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import ( + SagemakerOrchestratorSettings +) + sagemaker_orchestrator_settings = SagemakerOrchestratorSettings( instance_type="ml.m5.large", volume_size_in_gb=30, ) -``` -They can then be applied to a step as follows: -```python @step(settings={"orchestrator": sagemaker_orchestrator_settings}) +def my_step() -> None: + pass ``` For example, if your ZenML component is configured to use `ml.c5.xlarge` with 400GB additional storage by default, all steps will use it except for the step above, which will use `ml.t3.medium` (for Processing Steps) or `ml.m5.xlarge` (for Training Steps) with 30GB additional storage. See the next section for details on how ZenML decides which Sagemaker Step type to use. 
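The override behavior described in the hunk above (component-level defaults apply everywhere unless a step carries its own settings) can be sketched with plain dictionaries. This is only an illustration and not ZenML's actual settings resolution: the helper name `resolve_settings` and the key-by-key merge rule are assumptions, and the values are taken from the settings example and the component defaults mentioned in the text.

```python
from typing import Any, Dict


def resolve_settings(
    component_defaults: Dict[str, Any],
    step_overrides: Dict[str, Any],
) -> Dict[str, Any]:
    """Hypothetical helper: step-level values win over component defaults."""
    # Keys absent from the overrides fall back to the component defaults.
    return {**component_defaults, **step_overrides}


# Defaults configured on the orchestrator component (values from the text).
component_defaults = {"instance_type": "ml.c5.xlarge", "volume_size_in_gb": 400}

# Settings attached to a single step (values from the settings example).
step_overrides = {"instance_type": "ml.m5.large", "volume_size_in_gb": 30}

effective_for_step = resolve_settings(component_defaults, step_overrides)
effective_for_other_steps = resolve_settings(component_defaults, {})

print(effective_for_step)
print(effective_for_other_steps)
```

Under this assumed merge rule, only the decorated step sees the smaller instance and volume; every other step keeps the component configuration unchanged.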
@@ -198,6 +201,8 @@ For more information and a full list of configurable attributes of the Sagemaker To enable Warm Pools, use the [`SagemakerOrchestratorSettings`](https://sdkdocs.zenml.io/latest/integration_code_docs/integrations-aws/#zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor.SagemakerOrchestratorSettings) class: ```python +from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import SagemakerOrchestratorSettings + sagemaker_orchestrator_settings = SagemakerOrchestratorSettings( keep_alive_period_in_seconds = 300, # 5 minutes, default value ) @@ -208,6 +213,8 @@ This configuration keeps instances warm for 5 minutes after each job completes, If you prefer not to use Warm Pools, you can explicitly disable them: ```python +from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import SagemakerOrchestratorSettings + sagemaker_orchestrator_settings = SagemakerOrchestratorSettings( keep_alive_period_in_seconds = None, ) @@ -216,6 +223,8 @@ sagemaker_orchestrator_settings = SagemakerOrchestratorSettings( By default, the SageMaker orchestrator uses Training Steps where possible, which can offer performance benefits and better integration with SageMaker's training capabilities. 
To disable this behavior: ```python +from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import SagemakerOrchestratorSettings + sagemaker_orchestrator_settings = SagemakerOrchestratorSettings( use_training_step = False ) @@ -236,6 +245,10 @@ Note that data import and export can be used jointly with `processor_args` for m A simple example of importing data from S3 to the Sagemaker job is as follows: ```python +from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import ( + SagemakerOrchestratorSettings +) + sagemaker_orchestrator_settings = SagemakerOrchestratorSettings( input_data_s3_mode="File", input_data_s3_uri="s3://some-bucket-name/folder" @@ -247,6 +260,10 @@ In this case, data will be available at `/opt/ml/processing/input/data` within t It is also possible to split your input over channels. This can be useful if the dataset is already split in S3, or maybe even located in different buckets. ```python +from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import ( + SagemakerOrchestratorSettings +) + sagemaker_orchestrator_settings = SagemakerOrchestratorSettings( input_data_s3_mode="File", input_data_s3_uri={ @@ -268,6 +285,10 @@ Data from within the job (e.g. 
produced by the training process, or when preproc In the simple case, data in `/opt/ml/processing/output/data` will be copied to S3 at the end of a job: ```python +from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import ( + SagemakerOrchestratorSettings +) + sagemaker_orchestrator_settings = SagemakerOrchestratorSettings( output_data_s3_mode="EndOfJob", output_data_s3_uri="s3://some-results-bucket-name/results" @@ -277,6 +298,10 @@ sagemaker_orchestrator_settings = SagemakerOrchestratorSettings( In a more complex case, data in `/opt/ml/processing/output/data/metadata` and `/opt/ml/processing/output/data/checkpoints` will be written away continuously: ```python +from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import ( + SagemakerOrchestratorSettings +) + sagemaker_orchestrator_settings = SagemakerOrchestratorSettings( output_data_s3_mode="Continuous", output_data_s3_uri={ @@ -296,7 +321,9 @@ The SageMaker orchestrator allows you to add tags to your pipeline executions an ```python from zenml import pipeline, step -from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import SagemakerOrchestratorSettings +from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import ( + SagemakerOrchestratorSettings +) # Define settings for the pipeline pipeline_settings = SagemakerOrchestratorSettings( @@ -339,4 +366,179 @@ This approach allows for more granular tagging, giving you flexibility in how yo Note that if you wish to use this orchestrator to run steps on a GPU, you will need to follow [the instructions on this page](../../how-to/pipeline-development/training-with-gpus/README.md) to ensure that it works. It requires adding some extra settings customization and is essential to enable CUDA for the GPU to give its full acceleration. -
+### Scheduling Pipelines + +The SageMaker orchestrator supports running pipelines on a schedule using +SageMaker's native scheduling capabilities. You can configure schedules in +three ways: + +* Using a cron expression +* Using a fixed interval +* Running once at a specific time + +```python +from datetime import datetime, timedelta + +from zenml import pipeline +from zenml.config.schedule import Schedule + +# Using a cron expression (runs daily at 2 AM UTC) +@pipeline +def my_scheduled_pipeline(): + # Your pipeline steps here + pass + +my_scheduled_pipeline.with_options( + schedule=Schedule(cron_expression="0 2 * * *") +)() + +# Using an interval (runs every 2 hours) +@pipeline +def my_interval_pipeline(): + # Your pipeline steps here + pass + +my_interval_pipeline.with_options( + schedule=Schedule( + start_time=datetime.now(), + interval_second=timedelta(hours=2) + ) +)() + +# Running once at a specific time +@pipeline +def my_one_time_pipeline(): + # Your pipeline steps here + pass + +my_one_time_pipeline.with_options( + schedule=Schedule(run_once_start_time=datetime(2024, 12, 31, 23, 59)) +)() +``` + +When you deploy a scheduled pipeline, ZenML will: + +1. Create a SageMaker Pipeline Schedule with the specified configuration +2. Configure the pipeline as the target for the schedule +3. Enable automatic execution based on the schedule + +{% hint style="info" %} +If you run the same pipeline with a schedule multiple times, the existing +schedule will **not** be updated with the new settings. Rather, ZenML will +create a new SageMaker pipeline and attach a new schedule to it. You +must manually delete the old pipeline and its attached schedule using the +AWS CLI or API (`aws scheduler delete-schedule `). 
See details +here: [SageMaker Pipeline Schedules](https://docs.aws.amazon.com/sagemaker/latest/dg/pipeline-eventbridge.html) +{% endhint %} + +#### Required IAM Permissions for schedules + +When using scheduled pipelines, you need to ensure your IAM role has the +correct permissions and trust relationships. You can set this up by either +defining an explicit `scheduler_role` in your orchestrator configuration or +adjusting the role that you are already using on the client side to manage +Sagemaker pipelines. + +```bash +# When registering the orchestrator +zenml orchestrator register sagemaker-orchestrator \ + --flavor=sagemaker \ + --scheduler_role=arn:aws:iam::123456789012:role/my-scheduler-role + +# Or updating an existing orchestrator +zenml orchestrator update sagemaker-orchestrator \ + --scheduler_role=arn:aws:iam::123456789012:role/my-scheduler-role +``` + +{% hint style="info" %} +The IAM role that you are using on the client side can come from multiple +sources depending on how you configured your orchestrator, such as explicit +credentials, a service connector or implicit authentication. + +If you are using a service connector, keep in mind that this only works with +authentication methods that involve IAM roles (IAM role, implicit +authentication). +{% endhint %} + +Defining a separate `scheduler_role` is particularly useful when: + +* You want to use different roles for creating pipelines and scheduling them +* Your organization's security policies require separate roles for different operations +* You need to grant specific permissions only to the scheduling operations + +1. **Trust Relationships** + Your `scheduler_role` (or your client role if you did not configure + a `scheduler_role`) must be assumable by the EventBridge Scheduler + service: + ```json + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": { + "AWS": "", + "Service": [ + "scheduler.amazonaws.com" + ] + }, + "Action": "sts:AssumeRole" + } + ] + } + ``` + +2. 
**Required IAM Permissions for the client role** + + In addition to permissions needed to manage pipelines, the role on the +client side also needs the following permissions to create schedules on +EventBridge: + + ```json + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "scheduler:ListSchedules", + "scheduler:GetSchedule", + "scheduler:CreateSchedule", + "scheduler:UpdateSchedule", + "scheduler:DeleteSchedule" + ], + "Resource": "*" + }, + { + "Effect": "Allow", + "Action": "iam:PassRole", + "Resource": "arn:aws:iam::*:role/*", + "Condition": { + "StringLike": { + "iam:PassedToService": "scheduler.amazonaws.com" + } + } + } + ] + } + ``` + + Or you can use the `AmazonEventBridgeSchedulerFullAccess` managed policy. + + These permissions enable: + + * Creation and management of Pipeline Schedules + * Setting up trust relationships between services + * Managing IAM policies required for the scheduled execution + * Cleanup of resources when schedules are removed + + Without these permissions, the scheduling functionality will fail. Make +sure to configure them before attempting to use scheduled pipelines. + +3. **Required IAM Permissions for the `scheduler_role`** + + The `scheduler_role` requires the same permissions as the client role (that +would run the pipeline in a non-scheduled case) to launch and manage Sagemaker +jobs. This would be covered by the `AmazonSageMakerFullAccess` permission. + +
\ No newline at end of file diff --git a/docs/book/component-guide/orchestrators/vertex.md b/docs/book/component-guide/orchestrators/vertex.md index 210d34f931c..8045306a0bb 100644 --- a/docs/book/component-guide/orchestrators/vertex.md +++ b/docs/book/component-guide/orchestrators/vertex.md @@ -184,7 +184,7 @@ For any runs executed on Vertex, you can get the URL to the Vertex UI in Python from zenml.client import Client pipeline_run = Client().get_pipeline_run("") -orchestrator_url = pipeline_run.run_metadata["orchestrator_url"].value +orchestrator_url = pipeline_run.run_metadata["orchestrator_url"] ``` ### Run pipelines on a schedule @@ -194,24 +194,37 @@ The Vertex Pipelines orchestrator supports running pipelines on a schedule using **How to schedule a pipeline** ```python +from datetime import datetime, timedelta + +from zenml import pipeline from zenml.config.schedule import Schedule +@pipeline +def first_pipeline(): + ... + # Run a pipeline every 5th minute -pipeline_instance.run( +first_pipeline = first_pipeline.with_options( schedule=Schedule( cron_expression="*/5 * * * *" ) ) +first_pipeline() + +@pipeline +def second_pipeline(): + ... # Run a pipeline every hour # starting in one day from now and ending in three days from now -pipeline_instance.run( +second_pipeline = second_pipeline.with_options( schedule=Schedule( - cron_expression="0 * * * *" - start_time=datetime.datetime.now() + datetime.timedelta(days=1), - end_time=datetime.datetime.now() + datetime.timedelta(days=3), + cron_expression="0 * * * *", + start_time=datetime.now() + timedelta(days=1), + end_time=datetime.now() + timedelta(days=3), ) ) +second_pipeline() ``` {% hint style="warning" %} @@ -233,23 +246,32 @@ In order to cancel a scheduled Vertex pipeline, you need to manually delete the For additional configuration of the Vertex orchestrator, you can pass `VertexOrchestratorSettings` which allows you to configure labels for your Vertex Pipeline jobs or specify which GPU to use. 
```python -from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import VertexOrchestratorSettings -from kubernetes.client.models import V1Toleration - -vertex_settings = VertexOrchestratorSettings( - labels={"key": "value"} +from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import ( + VertexOrchestratorSettings ) + +vertex_settings = VertexOrchestratorSettings(labels={"key": "value"}) ``` If your pipelines steps have certain hardware requirements, you can specify them as `ResourceSettings`: ```python +from zenml.config import ResourceSettings + resource_settings = ResourceSettings(cpu_count=8, memory="16GB") ``` -To run your pipeline (or some steps of it) on a GPU, you will need to set both a node selector -and the gpu count as follows: +To run your pipeline (or some steps of it) on a GPU, you will need to set both +a node selector and the GPU count as follows: + ```python +from zenml import step, pipeline + +from zenml.config import ResourceSettings +from zenml.integrations.gcp.flavors.vertex_orchestrator_flavor import ( + VertexOrchestratorSettings +) + vertex_settings = VertexOrchestratorSettings( pod_settings={ "node_selectors": { @@ -258,33 +280,30 @@ vertex_settings = VertexOrchestratorSettings( } ) resource_settings = ResourceSettings(gpu_count=1) -``` -You can find available accelerator types [here](https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus). - -These settings can then be specified on either pipeline-level or step-level: -```python -# Either specify on pipeline-level -@pipeline( +# Either specify settings on step-level +@step( settings={ "orchestrator": vertex_settings, "resources": resource_settings, } ) -def my_pipeline(): +def my_step(): ... -# OR specify settings on step-level -@step( +# OR specify on pipeline-level +@pipeline( settings={ "orchestrator": vertex_settings, "resources": resource_settings, } ) -def my_step(): +def my_pipeline(): ... 
``` +You can find available accelerator types [here](https://cloud.google.com/vertex-ai/docs/training/configure-compute#specifying_gpus). + Check out the [SDK docs](https://sdkdocs.zenml.io/latest/integration\_code\_docs/integrations-gcp/#zenml.integrations.gcp.flavors.vertex\_orchestrator\_flavor.VertexOrchestratorSettings) for a full list of available attributes and [this docs page](../../how-to/pipeline-development/use-configuration-files/runtime-configuration.md) for more information on how to specify settings. For more information and a full list of configurable attributes of the Vertex orchestrator, check out the [SDK Docs](https://sdkdocs.zenml.io/latest/integration\_code\_docs/integrations-gcp/#zenml.integrations.gcp.orchestrators.vertex\_orchestrator.VertexOrchestrator) . diff --git a/docs/book/how-to/pipeline-development/build-pipelines/schedule-a-pipeline.md b/docs/book/how-to/pipeline-development/build-pipelines/schedule-a-pipeline.md index be725e386fc..f922339393e 100644 --- a/docs/book/how-to/pipeline-development/build-pipelines/schedule-a-pipeline.md +++ b/docs/book/how-to/pipeline-development/build-pipelines/schedule-a-pipeline.md @@ -18,7 +18,7 @@ Schedules don't work for all orchestrators. 
Here is a list of all supported orch | [KubernetesOrchestrator](../../../component-guide/orchestrators/kubernetes.md) | ✅ | | [LocalOrchestrator](../../../component-guide/orchestrators/local.md) | ⛔️ | | [LocalDockerOrchestrator](../../../component-guide/orchestrators/local-docker.md) | ⛔️ | -| [SagemakerOrchestrator](../../../component-guide/orchestrators/sagemaker.md) | ⛔️ | +| [SagemakerOrchestrator](../../../component-guide/orchestrators/sagemaker.md) | ✅ | | [SkypilotAWSOrchestrator](../../../component-guide/orchestrators/skypilot-vm.md) | ⛔️ | | [SkypilotAzureOrchestrator](../../../component-guide/orchestrators/skypilot-vm.md) | ⛔️ | | [SkypilotGCPOrchestrator](../../../component-guide/orchestrators/skypilot-vm.md) | ⛔️ | diff --git a/docs/mocked_libs.json b/docs/mocked_libs.json index aa72411b38a..44a610aff08 100644 --- a/docs/mocked_libs.json +++ b/docs/mocked_libs.json @@ -191,6 +191,7 @@ "sagemaker.workflow.execution_variables", "sagemaker.workflow.pipeline", "sagemaker.workflow.steps", + "sagemaker.workflow.triggers", "scipy", "scipy.sparse", "sklearn", diff --git a/pyproject.toml b/pyproject.toml index b02f19c49f9..6602d78313b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,7 +116,7 @@ azure-mgmt-resource = { version = ">=21.0.0", optional = true } s3fs = { version = ">=2022.11.0", optional = true } # Optional dependencies for the Sagemaker orchestrator -sagemaker = { version = ">=2.117.0", optional = true } +sagemaker = { version = ">=2.199.0", optional = true } # Optional dependencies for the GCS artifact store gcsfs = { version = ">=2022.11.0", optional = true } diff --git a/scripts/summarize_docs.py b/scripts/summarize_docs.py index cf759821f1d..19602602050 100644 --- a/scripts/summarize_docs.py +++ b/scripts/summarize_docs.py @@ -13,29 +13,32 @@ # permissions and limitations under the License. 
import os import re -import json -from openai import OpenAI from pathlib import Path -from typing import List, Dict + +from openai import OpenAI # Initialize OpenAI client -client = OpenAI(api_key=os.getenv('OPENAI_API_KEY')) +client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + def extract_content_blocks(md_content: str) -> str: """Extracts content blocks while preserving order and marking code blocks.""" - parts = re.split(r'(```[\s\S]*?```)', md_content) - + parts = re.split(r"(```[\s\S]*?```)", md_content) + processed_content = "" for part in parts: - if part.startswith('```'): - processed_content += "\n[CODE_BLOCK_START]\n" + part + "\n[CODE_BLOCK_END]\n" + if part.startswith("```"): + processed_content += ( + "\n[CODE_BLOCK_START]\n" + part + "\n[CODE_BLOCK_END]\n" + ) else: - cleaned_text = re.sub(r'\s+', ' ', part).strip() + cleaned_text = re.sub(r"\s+", " ", part).strip() if cleaned_text: processed_content += "\n" + cleaned_text + "\n" - + return processed_content + def summarize_content(content: str, file_path: str) -> str: """Summarizes content using OpenAI API.""" try: @@ -44,7 +47,7 @@ def summarize_content(content: str, file_path: str) -> str: messages=[ { "role": "system", - "content": "You are a technical documentation summarizer." + "content": "You are a technical documentation summarizer.", }, { "role": "user", @@ -53,21 +56,22 @@ def summarize_content(content: str, file_path: str) -> str: Make it concise but ensure NO critical information is lost and some details that you think are important are kept. 
Make the code shorter where possible keeping only the most important parts while preserving syntax and accuracy: - {content}""" - } + {content}""", + }, ], temperature=0.3, - max_tokens=2000 + max_tokens=2000, ) return response.choices[0].message.content except Exception as e: print(f"Error summarizing {file_path}: {e}") return "" + def main(): docs_dir = "docs/book" output_file = "summarized_docs.txt" - + # Get markdown files exclude_files = ["toc.md"] md_files = list(Path(docs_dir).rglob("*.md")) @@ -77,21 +81,22 @@ def main(): with open(output_file, "w", encoding="utf-8") as out_f: for file_path in md_files: try: - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: content = f.read() - + processed_content = extract_content_blocks(content) summary = summarize_content(processed_content, str(file_path)) - + if summary: out_f.write(f"=== File: {file_path} ===\n\n") out_f.write(summary) - out_f.write("\n\n" + "="*50 + "\n\n") - + out_f.write("\n\n" + "=" * 50 + "\n\n") + print(f"Processed: {file_path}") - + except Exception as e: print(f"Error processing {file_path}: {e}") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/upload_to_huggingface.py b/scripts/upload_to_huggingface.py index eb605a309a2..7a6cd154f75 100644 --- a/scripts/upload_to_huggingface.py +++ b/scripts/upload_to_huggingface.py @@ -1,25 +1,28 @@ -from huggingface_hub import HfApi import os +from huggingface_hub import HfApi + + def upload_to_huggingface(): api = HfApi(token=os.environ["HF_TOKEN"]) - + # Upload OpenAI summary api.upload_file( path_or_fileobj="summarized_docs.txt", path_in_repo="how-to-guides.txt", repo_id="zenml/llms.txt", - repo_type="dataset" + repo_type="dataset", ) - + # Upload repomix outputs for filename in ["component-guide.txt", "basics.txt", "llms-full.txt"]: api.upload_file( path_or_fileobj=f"repomix-outputs/{filename}", path_in_repo=filename, repo_id="zenml/llms.txt", - 
repo_type="dataset" + repo_type="dataset", ) + if __name__ == "__main__": - upload_to_huggingface() \ No newline at end of file + upload_to_huggingface() diff --git a/src/zenml/enums.py b/src/zenml/enums.py index 0469048f3d9..e8d15001a28 100644 --- a/src/zenml/enums.py +++ b/src/zenml/enums.py @@ -376,6 +376,7 @@ class MetadataResourceTypes(StrEnum): STEP_RUN = "step_run" ARTIFACT_VERSION = "artifact_version" MODEL_VERSION = "model_version" + SCHEDULE = "schedule" class DatabaseBackupStrategy(StrEnum): diff --git a/src/zenml/integrations/aws/__init__.py b/src/zenml/integrations/aws/__init__.py index c18c90f4deb..f7c5abcc2d9 100644 --- a/src/zenml/integrations/aws/__init__.py +++ b/src/zenml/integrations/aws/__init__.py @@ -35,12 +35,13 @@ S3_RESOURCE_TYPE = "s3-bucket" AWS_IMAGE_BUILDER_FLAVOR = "aws" + class AWSIntegration(Integration): """Definition of AWS integration for ZenML.""" NAME = AWS REQUIREMENTS = [ - "sagemaker>=2.117.0", + "sagemaker>=2.199.0", "kubernetes", "aws-profile-manager", ] diff --git a/src/zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py b/src/zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py index 2898f24cc67..a9efa79a42e 100644 --- a/src/zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +++ b/src/zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py @@ -85,6 +85,7 @@ class SagemakerOrchestratorSettings(BaseSettings): to the container is configured with input_data_s3_mode. Two possible input types: - str: S3 location where training data is saved. + - Dict[str, str]: (ChannelName, S3Location) which represent - Dict[str, str]: (ChannelName, S3Location) which represent channels (e.g. training, validation, testing) where specific parts of the data are saved in S3. @@ -184,6 +185,10 @@ class SagemakerOrchestratorConfig( Attributes: execution_role: The IAM role ARN to use for the pipeline. 
+ scheduler_role: The ARN of the IAM role that will be assumed by + the EventBridge service to launch Sagemaker pipelines + (For more details regarding the required permissions, please check: + https://docs.zenml.io/stack-components/orchestrators/sagemaker#required-iam-permissions-for-schedules) aws_access_key_id: The AWS access key ID to use to authenticate to AWS. If not provided, the value from the default AWS config will be used. aws_secret_access_key: The AWS secret access key to use to authenticate @@ -203,6 +208,7 @@ class SagemakerOrchestratorConfig( """ execution_role: str + scheduler_role: Optional[str] = None aws_access_key_id: Optional[str] = SecretField(default=None) aws_secret_access_key: Optional[str] = SecretField(default=None) aws_profile: Optional[str] = None @@ -232,6 +238,15 @@ def is_synchronous(self) -> bool: """ return self.synchronous + @property + def is_schedulable(self) -> bool: + """Whether the orchestrator is schedulable or not. + + Returns: + Whether the orchestrator is schedulable or not. 
+ """ + return True + class SagemakerOrchestratorFlavor(BaseOrchestratorFlavor): """Flavor for the Sagemaker orchestrator.""" diff --git a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py index f832647a97e..52ba430b082 100644 --- a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +++ b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py @@ -15,6 +15,7 @@ import os import re +from datetime import datetime, timezone from typing import ( TYPE_CHECKING, Any, @@ -35,14 +36,20 @@ from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.pipeline import Pipeline from sagemaker.workflow.steps import ProcessingStep, TrainingStep +from sagemaker.workflow.triggers import PipelineSchedule +from zenml.client import Client from zenml.config.base_settings import BaseSettings from zenml.constants import ( METADATA_ORCHESTRATOR_LOGS_URL, METADATA_ORCHESTRATOR_RUN_ID, METADATA_ORCHESTRATOR_URL, ) -from zenml.enums import ExecutionStatus, StackComponentType +from zenml.enums import ( + ExecutionStatus, + MetadataResourceTypes, + StackComponentType, +) from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import ( SagemakerOrchestratorConfig, SagemakerOrchestratorSettings, @@ -69,6 +76,36 @@ logger = get_logger(__name__) +def dissect_schedule_arn( + schedule_arn: str, +) -> Tuple[Optional[str], Optional[str]]: + """Extracts the region and the name from an EventBridge schedule ARN. + + Args: + schedule_arn: The ARN of the EventBridge schedule. + + Returns: + Region Name, Schedule Name (including the group name) + + Raises: + ValueError: If the input is not a properly formatted ARN. 
+ """ + # Split the ARN into parts + arn_parts = schedule_arn.split(":") + + # Validate ARN structure + if len(arn_parts) < 6 or not arn_parts[5].startswith("schedule/"): + raise ValueError("Invalid EventBridge schedule ARN format.") + + # Extract the region + region = arn_parts[3] + + # Extract the group name and schedule name + name = arn_parts[5].split("schedule/")[1] + + return region, name + + def dissect_pipeline_execution_arn( pipeline_execution_arn: str, ) -> Tuple[Optional[str], Optional[str], Optional[str]]: @@ -237,21 +274,15 @@ def prepare_or_run_pipeline( environment. Raises: - RuntimeError: If a connector is used that does not return a - `boto3.Session` object. + RuntimeError: If there is an error creating or scheduling the + pipeline. TypeError: If the network_config passed is not compatible with the AWS SageMaker NetworkConfig class. + ValueError: If the schedule is not valid. Yields: A dictionary of metadata related to the pipeline run. """ - if deployment.schedule: - logger.warning( - "The Sagemaker Orchestrator currently does not support the " - "use of schedules. The `schedule` will be ignored " - "and the pipeline will be run immediately." - ) - # sagemaker requires pipelineName to use alphanum and hyphens only unsanitized_orchestrator_run_name = get_orchestrator_run_name( pipeline_name=deployment.pipeline_configuration.name @@ -459,7 +490,7 @@ def prepare_or_run_pipeline( sagemaker_steps.append(sagemaker_step) - # construct the pipeline from the sagemaker_steps + # Create the pipeline pipeline = Pipeline( name=orchestrator_run_name, steps=sagemaker_steps, @@ -479,39 +510,207 @@ def prepare_or_run_pipeline( if settings.pipeline_tags else None, ) - execution = pipeline.start() - logger.warning( - "Steps can take 5-15 minutes to start running " - "when using the Sagemaker Orchestrator." 
- ) - # Yield metadata based on the generated execution object - yield from self.compute_metadata( - execution=execution, settings=settings - ) + # Handle scheduling if specified + if deployment.schedule: + if settings.synchronous: + logger.warning( + "The 'synchronous' setting is ignored for scheduled " + "pipelines since they run independently of the " + "deployment process." + ) - # mainly for testing purposes, we wait for the pipeline to finish - if settings.synchronous: - logger.info( - "Executing synchronously. Waiting for pipeline to finish... \n" - "At this point you can `Ctrl-C` out without cancelling the " - "execution." + schedule_name = orchestrator_run_name + next_execution = None + + # Create PipelineSchedule based on schedule type + if deployment.schedule.cron_expression: + cron_exp = self._validate_cron_expression( + deployment.schedule.cron_expression + ) + schedule = PipelineSchedule( + name=schedule_name, + cron=cron_exp, + start_date=deployment.schedule.start_time, + enabled=True, + ) + elif deployment.schedule.interval_second: + # This is necessary because SageMaker's PipelineSchedule rate + # expressions require minutes as the minimum time unit. + # Even if a user specifies an interval of less than 60 seconds, + # it will be rounded up to 1 minute. 
+ minutes = max( + 1, + int( + deployment.schedule.interval_second.total_seconds() + / 60 + ), + ) + schedule = PipelineSchedule( + name=schedule_name, + rate=(minutes, "minutes"), + start_date=deployment.schedule.start_time, + enabled=True, + ) + next_execution = ( + deployment.schedule.start_time + or datetime.now(timezone.utc) + ) + deployment.schedule.interval_second + else: + # One-time schedule + execution_time = ( + deployment.schedule.run_once_start_time + or deployment.schedule.start_time + ) + if not execution_time: + raise ValueError( + "A start time must be specified for one-time " + "schedule execution" + ) + schedule = PipelineSchedule( + name=schedule_name, + at=execution_time.astimezone(timezone.utc), + enabled=True, + ) + next_execution = execution_time + + # Get the current role ARN if not explicitly configured + if self.config.scheduler_role is None: + logger.info( + "No scheduler_role configured. Trying to extract it from " + "the client side authentication." + ) + sts = session.boto_session.client("sts") + try: + scheduler_role_arn = sts.get_caller_identity()["Arn"] + # If this is a user ARN, try to get the role ARN + if ":user/" in scheduler_role_arn: + logger.warning( + f"Using IAM user credentials " + f"({scheduler_role_arn}). For production " + "environments, it's recommended to use IAM roles " + "instead." + ) + # If this is an assumed role, extract the role ARN + elif ":assumed-role/" in scheduler_role_arn: + # Convert assumed-role ARN format to role ARN format + # From: arn:aws:sts::123456789012:assumed-role/role-name/session-name + # To: arn:aws:iam::123456789012:role/role-name + scheduler_role_arn = re.sub( + r"arn:aws:sts::(\d+):assumed-role/([^/]+)/.*", + r"arn:aws:iam::\1:role/\2", + scheduler_role_arn, + ) + elif ":role/" not in scheduler_role_arn: + raise RuntimeError( + f"Unexpected credential type " + f"({scheduler_role_arn}). Please use IAM " + f"roles for SageMaker pipeline scheduling." 
+ ) + else: + raise RuntimeError( + "The ARN of the caller identity " + f"`{scheduler_role_arn}` does not " + "include a user or a proper role." + ) + except Exception: + raise RuntimeError( + "Failed to get current role ARN. This means that " + "your client-side credentials are not configured " + "correctly to schedule SageMaker pipelines. For more " + "information, please check: " + "https://docs.zenml.io/stack-components/orchestrators/sagemaker#required-iam-permissions-for-schedules" + ) + else: + scheduler_role_arn = self.config.scheduler_role + + # Attach schedule to pipeline + triggers = pipeline.put_triggers( + triggers=[schedule], + role_arn=scheduler_role_arn, ) + logger.info(f"The schedule ARN is: {triggers[0]}") + try: - execution.wait( - delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS + from zenml.models import RunMetadataResource + + schedule_metadata = self.generate_schedule_metadata( + schedule_arn=triggers[0] ) - logger.info("Pipeline completed successfully.") - except WaiterError: - raise RuntimeError( - "Timed out while waiting for pipeline execution to " - "finish. For long-running pipelines we recommend " - "configuring your orchestrator for asynchronous execution. 
" - "The following command does this for you: \n" - f"`zenml orchestrator update {self.name} " - f"--synchronous=False`" + + Client().create_run_metadata( + metadata=schedule_metadata, # type: ignore[arg-type] + resources=[ + RunMetadataResource( + id=deployment.schedule.id, + type=MetadataResourceTypes.SCHEDULE, + ) + ], + ) + except Exception as e: + logger.debug( + "There was an error attaching metadata to the " + f"schedule: {e}" ) + logger.info( + f"Successfully scheduled pipeline with name: {schedule_name}\n" + + ( + f"First execution will occur at: " + f"{next_execution.strftime('%Y-%m-%d %H:%M:%S UTC')}" + if next_execution + else f"Using cron expression: " + f"{deployment.schedule.cron_expression}" + ) + + ( + f" (and every {minutes} minutes after)" + if deployment.schedule.interval_second + else "" + ) + ) + logger.info( + "\n\nIn order to cancel the schedule, you can execute " + "the following command:\n" + ) + logger.info( + f"`aws scheduler delete-schedule --name {schedule_name}`" + ) + else: + # Execute the pipeline immediately if no schedule is specified + execution = pipeline.start() + logger.warning( + "Steps can take 5-15 minutes to start running " + "when using the Sagemaker Orchestrator." + ) + + # Yield metadata based on the generated execution object + yield from self.compute_metadata( + execution_arn=execution.arn, settings=settings + ) + + # mainly for testing purposes, we wait for the pipeline to finish + if settings.synchronous: + logger.info( + "Executing synchronously. Waiting for pipeline to " + "finish... \n" + "At this point you can `Ctrl-C` out without cancelling the " + "execution." + ) + try: + execution.wait( + delay=POLLING_DELAY, max_attempts=MAX_POLLING_ATTEMPTS + ) + logger.info("Pipeline completed successfully.") + except WaiterError: + raise RuntimeError( + "Timed out while waiting for pipeline execution to " + "finish. 
For long-running pipelines we recommend " + "configuring your orchestrator for asynchronous " + "execution. The following command does this for you: \n" + f"`zenml orchestrator update {self.name} " + f"--synchronous=False`" + ) + def get_pipeline_run_metadata( self, run_id: UUID ) -> Dict[str, "MetadataType"]: @@ -523,10 +722,22 @@ def get_pipeline_run_metadata( Returns: A dictionary of metadata. """ - pipeline_execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID] - run_metadata: Dict[str, "MetadataType"] = { - "pipeline_execution_arn": pipeline_execution_arn, - } + from zenml import get_step_context + + execution_arn = os.environ[ENV_ZENML_SAGEMAKER_RUN_ID] + + run_metadata: Dict[str, "MetadataType"] = {} + + settings = cast( + SagemakerOrchestratorSettings, + self.get_settings(get_step_context().pipeline_run), + ) + + for metadata in self.compute_metadata( + execution_arn=execution_arn, + settings=settings, + ): + run_metadata.update(metadata) return run_metadata @@ -588,56 +799,57 @@ def fetch_status(self, run: "PipelineRunResponse") -> ExecutionStatus: def compute_metadata( self, - execution: Any, + execution_arn: str, settings: SagemakerOrchestratorSettings, ) -> Iterator[Dict[str, MetadataType]]: """Generate run metadata based on the generated Sagemaker Execution. Args: - execution: The corresponding _PipelineExecution object. + execution_arn: The ARN of the pipeline execution. settings: The Sagemaker orchestrator settings. Yields: A dictionary of metadata related to the pipeline run. 
""" - # Metadata - metadata: Dict[str, MetadataType] = {} - # Orchestrator Run ID - if run_id := self._compute_orchestrator_run_id(execution): - metadata[METADATA_ORCHESTRATOR_RUN_ID] = run_id + metadata: Dict[str, MetadataType] = { + "pipeline_execution_arn": execution_arn, + METADATA_ORCHESTRATOR_RUN_ID: execution_arn, + } # URL to the Sagemaker's pipeline view - if orchestrator_url := self._compute_orchestrator_url(execution): + if orchestrator_url := self._compute_orchestrator_url( + execution_arn=execution_arn + ): metadata[METADATA_ORCHESTRATOR_URL] = Uri(orchestrator_url) # URL to the corresponding CloudWatch page if logs_url := self._compute_orchestrator_logs_url( - execution, settings + execution_arn=execution_arn, settings=settings ): metadata[METADATA_ORCHESTRATOR_LOGS_URL] = Uri(logs_url) yield metadata - @staticmethod def _compute_orchestrator_url( - pipeline_execution: Any, + self, + execution_arn: Any, ) -> Optional[str]: """Generate the Orchestrator Dashboard URL upon pipeline execution. Args: - pipeline_execution: The corresponding _PipelineExecution object. + execution_arn: The ARN of the pipeline execution. Returns: the URL to the dashboard view in SageMaker. """ try: region_name, pipeline_name, execution_id = ( - dissect_pipeline_execution_arn(pipeline_execution.arn) + dissect_pipeline_execution_arn(execution_arn) ) # Get the Sagemaker session - session = pipeline_execution.sagemaker_session + session = self._get_sagemaker_session() # List the Studio domains and get the Studio Domain ID domains_response = session.sagemaker_client.list_domains() @@ -657,13 +869,13 @@ def _compute_orchestrator_url( @staticmethod def _compute_orchestrator_logs_url( - pipeline_execution: Any, + execution_arn: Any, settings: SagemakerOrchestratorSettings, ) -> Optional[str]: """Generate the CloudWatch URL upon pipeline execution. Args: - pipeline_execution: The corresponding _PipelineExecution object. + execution_arn: The ARN of the pipeline execution. 
settings: The Sagemaker orchestrator settings. Returns: @@ -671,7 +883,7 @@ def _compute_orchestrator_logs_url( """ try: region_name, _, execution_id = dissect_pipeline_execution_arn( - pipeline_execution.arn + execution_arn ) use_training_jobs = True @@ -693,22 +905,48 @@ def _compute_orchestrator_logs_url( return None @staticmethod - def _compute_orchestrator_run_id( - pipeline_execution: Any, - ) -> Optional[str]: - """Fetch the Orchestrator Run ID upon pipeline execution. + def generate_schedule_metadata(schedule_arn: str) -> Dict[str, str]: + """Generates metadata for a ZenML schedule. Args: - pipeline_execution: The corresponding _PipelineExecution object. + schedule_arn: The trigger ARN that is generated on the AWS side. Returns: - the Execution ID of the run in SageMaker. + a dictionary containing metadata related to the schedule. """ - try: - return str(pipeline_execution.arn) + region, name = dissect_schedule_arn(schedule_arn=schedule_arn) - except Exception as e: - logger.warning( - f"There was an issue while extracting the pipeline run ID: {e}" + return { + "trigger_url": ( + f"https://{region}.console.aws.amazon.com/scheduler/home" + f"?region={region}#schedules/{name}" + ), + } + + @staticmethod + def _validate_cron_expression(cron_expression: str) -> str: + """Validates and formats a cron expression for SageMaker schedules. + + Args: + cron_expression: The cron expression to validate + + Returns: + The formatted cron expression + + Raises: + ValueError: If the cron expression is invalid + """ + # Strip any "cron(" prefix if it exists + cron_exp = cron_expression.replace("cron(", "").replace(")", "") + + # Split into components + parts = cron_exp.split() + if len(parts) not in [6, 7]: # AWS cron requires 6 or 7 fields + raise ValueError( + f"Invalid cron expression: {cron_expression}. AWS cron " + "expressions must have 6 or 7 fields: minute hour day-of-month " + "month day-of-week year(optional). Example: '15 10 ? 
* 6L " + "2022-2023'" ) - return None + + return cron_exp diff --git a/src/zenml/integrations/huggingface/__init__.py b/src/zenml/integrations/huggingface/__init__.py index f1e1721bbb5..e11241afdd6 100644 --- a/src/zenml/integrations/huggingface/__init__.py +++ b/src/zenml/integrations/huggingface/__init__.py @@ -47,16 +47,11 @@ def get_requirements(cls, target_os: Optional[str] = None) -> List[str]: A list of requirements. """ requirements = [ - "datasets", + "datasets>=2.16.0", "huggingface_hub>0.19.0", "accelerate", "bitsandbytes>=0.41.3", "peft", - # temporary fix for CI issue similar to: - # - https://github.com/huggingface/datasets/issues/6737 - # - https://github.com/huggingface/datasets/issues/6697 - # TODO try relaxing it back going forward - "fsspec<=2023.12.0", "transformers", ] diff --git a/src/zenml/models/v2/core/schedule.py b/src/zenml/models/v2/core/schedule.py index 0e7dc01c421..cc77ab2cbd1 100644 --- a/src/zenml/models/v2/core/schedule.py +++ b/src/zenml/models/v2/core/schedule.py @@ -14,13 +14,14 @@ """Models representing schedules.""" import datetime -from typing import Optional, Union +from typing import Dict, Optional, Union from uuid import UUID from pydantic import Field, model_validator from zenml.constants import STR_FIELD_MAX_LENGTH from zenml.logger import get_logger +from zenml.metadata.metadata_types import MetadataType from zenml.models.v2.base.base import BaseUpdate from zenml.models.v2.base.scoped import ( WorkspaceScopedFilter, @@ -136,6 +137,11 @@ class ScheduleResponseMetadata(WorkspaceScopedResponseMetadata): orchestrator_id: Optional[UUID] pipeline_id: Optional[UUID] + run_metadata: Dict[str, MetadataType] = Field( + title="Metadata associated with this schedule.", + default={}, + ) + class ScheduleResponseResources(WorkspaceScopedResponseResources): """Class for all resource models associated with the schedule entity.""" @@ -272,6 +278,15 @@ def pipeline_id(self) -> Optional[UUID]: """ return self.get_metadata().pipeline_id + 
@property + def run_metadata(self) -> Dict[str, MetadataType]: + """The `run_metadata` property. + + Returns: + the value of the property. + """ + return self.get_metadata().run_metadata + # ------------------ Filter Model ------------------ diff --git a/src/zenml/stack/stack_component.py b/src/zenml/stack/stack_component.py index fc073b6310e..84f24d7af84 100644 --- a/src/zenml/stack/stack_component.py +++ b/src/zenml/stack/stack_component.py @@ -29,7 +29,11 @@ from zenml.enums import StackComponentType from zenml.exceptions import AuthorizationException from zenml.logger import get_logger -from zenml.models import ServiceConnectorRequirements, StepRunResponse +from zenml.models import ( + PipelineRunResponse, + ServiceConnectorRequirements, + StepRunResponse, +) from zenml.utils import ( pydantic_utils, secret_utils, @@ -496,6 +500,7 @@ def get_settings( "StepRunInfo", "PipelineDeploymentBase", "PipelineDeploymentResponse", + "PipelineRunResponse", ], ) -> "BaseSettings": """Gets settings for this stack component. 
@@ -527,7 +532,10 @@ def get_settings( all_settings = ( container.config.settings - if isinstance(container, (Step, StepRunResponse, StepRunInfo)) + if isinstance( + container, + (Step, StepRunResponse, StepRunInfo, PipelineRunResponse), + ) else container.pipeline_configuration.settings ) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 32af09d2a46..41b33787e03 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1009,6 +1009,8 @@ def create_run_metadata( verify_models.append(zen_store().get_artifact_version(resource.id)) elif resource.type == MetadataResourceTypes.MODEL_VERSION: verify_models.append(zen_store().get_model_version(resource.id)) + elif resource.type == MetadataResourceTypes.SCHEDULE: + verify_models.append(zen_store().get_schedule(resource.id)) else: raise RuntimeError(f"Unknown resource type: {resource.type}") diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index b35e4e88982..4283b4f2c8d 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -269,6 +269,13 @@ def fetch_metadata_collection(self) -> Dict[str, List[RunMetadataEntry]]: for k, v in step_metadata.items(): metadata_collection[f"{s.name}::{k}"] = v + # Fetch the metadata related to the schedule of this run + if self.deployment is not None: + if schedule := self.deployment.schedule: + schedule_metadata = schedule.fetch_metadata_collection() + for k, v in schedule_metadata.items(): + metadata_collection[f"schedule:{k}"] = v + return metadata_collection def to_model( diff --git a/src/zenml/zen_stores/schemas/schedule_schema.py b/src/zenml/zen_stores/schemas/schedule_schema.py index 5a56765bec4..577e62f34be 100644 --- a/src/zenml/zen_stores/schemas/schedule_schema.py +++ 
b/src/zenml/zen_stores/schemas/schedule_schema.py @@ -14,11 +14,12 @@ """SQL Model Implementations for Pipeline Schedules.""" from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, List, Optional from uuid import UUID from sqlmodel import Field, Relationship +from zenml.enums import MetadataResourceTypes from zenml.models import ( ScheduleRequest, ScheduleResponse, @@ -31,15 +32,19 @@ from zenml.zen_stores.schemas.pipeline_schemas import PipelineSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.user_schemas import UserSchema +from zenml.zen_stores.schemas.utils import RunMetadataInterface from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema if TYPE_CHECKING: from zenml.zen_stores.schemas.pipeline_deployment_schemas import ( PipelineDeploymentSchema, ) + from zenml.zen_stores.schemas.run_metadata_schemas import ( + RunMetadataSchema, + ) -class ScheduleSchema(NamedSchema, table=True): +class ScheduleSchema(NamedSchema, RunMetadataInterface, table=True): """SQL Model for schedules.""" __tablename__ = "schedule" @@ -89,6 +94,15 @@ class ScheduleSchema(NamedSchema, table=True): back_populates="schedules" ) + run_metadata: List["RunMetadataSchema"] = Relationship( + sa_relationship_kwargs=dict( + secondary="run_metadata_resource", + primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.SCHEDULE.value}', foreign(RunMetadataResourceSchema.resource_id)==ScheduleSchema.id)", + secondaryjoin="RunMetadataSchema.id==foreign(RunMetadataResourceSchema.run_metadata_id)", + overlaps="run_metadata", + ), + ) + active: bool cron_expression: Optional[str] = Field(nullable=True) start_time: Optional[datetime] = Field(nullable=True) @@ -196,6 +210,7 @@ def to_model( workspace=self.workspace.to_model(), pipeline_id=self.pipeline_id, orchestrator_id=self.orchestrator_id, + 
run_metadata=self.fetch_metadata(), ) return ScheduleResponse( diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 3c87f39e190..3e2ce22d981 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -100,9 +100,11 @@ ModelVersionPipelineRunRequest, ModelVersionRequest, ModelVersionUpdate, + PipelineRequest, PipelineRunFilter, PipelineRunResponse, RunMetadataResource, + ScheduleRequest, ServiceAccountFilter, ServiceAccountRequest, ServiceAccountUpdate, @@ -5444,6 +5446,55 @@ def test_metadata_full_cycle_with_cascade_deletion( pr if type_ == MetadataResourceTypes.PIPELINE_RUN else sr ) + elif type_ == MetadataResourceTypes.SCHEDULE: + step_name = sample_name("foo") + new_pipeline = client.zen_store.create_pipeline( + pipeline=PipelineRequest( + name="foo", + user=client.active_user.id, + workspace=client.active_workspace.id, + ) + ) + resource = client.zen_store.create_schedule( + ScheduleRequest( + name="foo", + cron_expression="*/5 * * * *", + user=client.active_user.id, + workspace=client.active_workspace.id, + orchestrator_id=client.active_stack.orchestrator.id, + active=False, + pipeline_id=new_pipeline.id, + ) + ) + deployment = client.zen_store.create_deployment( + PipelineDeploymentRequest( + user=client.active_user.id, + workspace=client.active_workspace.id, + run_name_template=sample_name("foo"), + pipeline_configuration=PipelineConfiguration( + name=sample_name("foo") + ), + stack=client.active_stack.id, + client_version="0.1.0", + server_version="0.1.0", + step_configurations={ + step_name: Step( + spec=StepSpec( + source=Source( + module="acme.foo", + type=SourceType.INTERNAL, + ), + upstream_steps=[], + ), + config=StepConfiguration(name=step_name), + ) + }, + schedule=resource.id, + ) + ) + else: + raise ValueError("Unknown/untested MetadataResourceType.") + 
client.zen_store.create_run_metadata( RunMetadataRequest( user=client.active_user.id, @@ -5465,6 +5516,10 @@ def test_metadata_full_cycle_with_cascade_deletion( rm = client.zen_store.get_run_step(resource.id, True).run_metadata assert rm["foo"] == "bar" + elif type_ == MetadataResourceTypes.SCHEDULE: + rm = client.zen_store.get_schedule(resource.id, True).run_metadata + assert rm["foo"] == "bar" + if type_ == MetadataResourceTypes.ARTIFACT_VERSION: client.zen_store.delete_artifact_version(resource.id) client.zen_store.delete_artifact(artifact.id) @@ -5476,6 +5531,10 @@ def test_metadata_full_cycle_with_cascade_deletion( ): client.zen_store.delete_run(pr.id) client.zen_store.delete_deployment(deployment.id) + elif type_ == MetadataResourceTypes.SCHEDULE: + client.zen_store.delete_deployment(deployment.id) + client.zen_store.delete_schedule(resource.id) + client.zen_store.delete_pipeline(new_pipeline.id) client.zen_store.delete_stack_component(sc.id) diff --git a/tests/integration/integrations/aws/orchestrators/test_sagemaker_orchestrator.py b/tests/integration/integrations/aws/orchestrators/test_sagemaker_orchestrator.py index 788dcb05063..523e7059801 100644 --- a/tests/integration/integrations/aws/orchestrators/test_sagemaker_orchestrator.py +++ b/tests/integration/integrations/aws/orchestrators/test_sagemaker_orchestrator.py @@ -12,7 +12,6 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. - from zenml.enums import StackComponentType from zenml.integrations.aws.flavors import SagemakerOrchestratorFlavor