diff --git a/src/hera/workflows/workflow_template.py b/src/hera/workflows/workflow_template.py index 651d434c7..3ed063daf 100644 --- a/src/hera/workflows/workflow_template.py +++ b/src/hera/workflows/workflow_template.py @@ -153,6 +153,7 @@ def from_file(cls, yaml_file: Union[Path, str]) -> ModelMapperMixin: def _get_as_workflow(self, generate_name: Optional[str]) -> Workflow: workflow = cast(Workflow, Workflow.from_dict(self.to_dict())) workflow.kind = "Workflow" + workflow.workflows_service = self.workflows_service # bind this workflow to the same service if generate_name is not None: workflow.generate_name = generate_name diff --git a/tests/test_unit/test_workflow_template.py b/tests/test_unit/test_workflow_template.py index 47b496de1..03b4b3a23 100644 --- a/tests/test_unit/test_workflow_template.py +++ b/tests/test_unit/test_workflow_template.py @@ -48,11 +48,11 @@ def test_workflow_template_create(global_config_fixture): def test_workflow_template_create_as_workflow(): - with patch.object(WorkflowsService, "create_workflow", return_value=MagicMock()) as create_workflow: + with patch.object(WorkflowsService, "create_workflow", return_value=MagicMock(), autospec=True) as create_workflow: # We have to patch the function at the class level because create_as_workflow copies the workflows service # from the WorkflowTemplate to the *separate* Workflow object. - ws = WorkflowsService(namespace="my-namespace") + ws = WorkflowsService(namespace="my-namespace", host="https://localhost:2746") # Note we set the name to None here, otherwise the workflow will take the name from the returned object create_workflow.return_value.metadata.name = None @@ -75,17 +75,22 @@ def test_workflow_template_create_as_workflow(): wt.create_as_workflow() # THEN - wt.workflows_service.create_workflow.assert_called_once_with( - WorkflowCreateRequest(workflow=expected_workflow.build()), - namespace="my-namespace", - ) + create_workflow.assert_called_once() + args, kwargs = create_workflow.call_args + assert kwargs == {"namespace": "my-namespace"} + wf_ws, req = args + assert req == WorkflowCreateRequest(workflow=expected_workflow.build()) + assert wf_ws == ws def test_workflow_template_get_as_workflow(): + ws = WorkflowsService(host="https://localhost:2746") + # GIVEN with WorkflowTemplate( name="my-wt", namespace="my-namespace", + workflows_service=ws, ) as wt: pass @@ -97,6 +102,7 @@ def test_workflow_template_get_as_workflow(): assert workflow.kind == "Workflow" assert workflow.name is None assert workflow.generate_name == "my-wt" + assert workflow.workflows_service == ws def test_workflow_template_get_as_workflow_truncator():