diff --git a/sdk/python/kfp/compiler/compiler_test.py b/sdk/python/kfp/compiler/compiler_test.py index 2433f09bc6d1..6ac26568c892 100644 --- a/sdk/python/kfp/compiler/compiler_test.py +++ b/sdk/python/kfp/compiler/compiler_test.py @@ -3847,6 +3847,42 @@ def outer(): foo_platform_set_bar_feature(task, 12) +class TestPipelineSemaphoreMutex(unittest.TestCase): + + def test_pipeline_with_semaphore_and_mutex(self): + from kfp import compiler + from kfp import dsl + from kfp.dsl.pipeline_config import PipelineConfig + + config = PipelineConfig() + config.set_semaphore_key('semaphore') + config.set_mutex_name('mutex') + + @dsl.pipeline(pipeline_config=config) + def my_pipeline(): + task = comp() + + with tempfile.TemporaryDirectory() as tempdir: + output_yaml = os.path.join(tempdir, 'pipeline.yaml') + compiler.Compiler().compile( + pipeline_func=my_pipeline, package_path=output_yaml) + + with open(output_yaml, 'r') as f: + pipeline_docs = list(yaml.safe_load_all(f)) + + pipeline_spec = None + for doc in pipeline_docs: + if 'platforms' in doc: + pipeline_spec = doc + break + + if pipeline_spec: + kubernetes_spec = pipeline_spec['platforms']['kubernetes'][ + 'pipelineConfig'] + assert kubernetes_spec['semaphoreKey'] == 'semaphore' + assert kubernetes_spec['mutexName'] == 'mutex' + + class ExtractInputOutputDescription(unittest.TestCase): def test_no_descriptions(self): diff --git a/sdk/python/kfp/dsl/pipeline_config.py b/sdk/python/kfp/dsl/pipeline_config.py index 8a730548d8b8..b1d2f86a15f3 100644 --- a/sdk/python/kfp/dsl/pipeline_config.py +++ b/sdk/python/kfp/dsl/pipeline_config.py @@ -24,8 +24,16 @@ def __init__(self): def set_semaphore_key(self, semaphore_key: str): """Set the name of the semaphore to control pipeline concurrency. + The semaphore is configured via a ConfigMap. By default, the ConfigMap is + named "semaphore-config", but this name can be specified through the APIServer + deployment manifests using an environment variable named SEMAPHORE_CONFIGMAP_NAME. + If the environment variable is not specified, the default name "semaphore-config" + is used. The semaphore key is provided through the pipeline configuration. + If a pipeline has a semaphore, the backend maps the semaphore to the ConfigMap + using the key provided by the user. + Args: - semaphore_key (str): Name of the semaphore. + semaphore_key (str): The key used to map to the ConfigMap. """ self.semaphore_key = semaphore_key.strip()