diff --git a/taipy/core/_repository/_filesystem_repository.py b/taipy/core/_repository/_filesystem_repository.py
index 721c569b36..13be55aaa6 100644
--- a/taipy/core/_repository/_filesystem_repository.py
+++ b/taipy/core/_repository/_filesystem_repository.py
@@ -191,10 +191,14 @@ def __filter_files_by_config_and_owner_id(
         return None
 
     def __match_file_and_get_entity(self, filepath, config_and_owner_ids, filters):
-        if match := [(c, p) for c, p in config_and_owner_ids if c.id in filepath.name]:
+        if match := [(c, p) for c, p in config_and_owner_ids if (c if isinstance(c, str) else c.id) in filepath.name]:
             for config, owner_id in match:
                 for fil in filters:
-                    fil.update({"config_id": config.id, "owner_id": owner_id})
+                    if isinstance(config, str):
+                        config_id = config
+                    else:
+                        config_id = config.id
+                    fil.update({"config_id": config_id, "owner_id": owner_id})
 
                     if data := self.__filter_by(filepath, filters):
                         return config, owner_id, self.__file_content_to_entity(data)
diff --git a/taipy/core/data/_data_manager.py b/taipy/core/data/_data_manager.py
index e65d0ec779..a170d70067 100644
--- a/taipy/core/data/_data_manager.py
+++ b/taipy/core/data/_data_manager.py
@@ -22,7 +22,7 @@
 from ..cycle.cycle_id import CycleId
 from ..exceptions.exceptions import InvalidDataNodeType
 from ..notification import Event, EventEntityType, EventOperation, Notifier, _make_event
-from ..reason import NotGlobalScope, ReasonCollection, WrongConfigType
+from ..reason import EntityDoesNotExist, NotGlobalScope, ReasonCollection, WrongConfigType
 from ..scenario.scenario_id import ScenarioId
 from ..sequence.sequence_id import SequenceId
 from ._data_fs_repository import _DataFSRepository
@@ -37,6 +37,17 @@ class _DataManager(_Manager[DataNode], _VersionMixin):
     _EVENT_ENTITY_TYPE = EventEntityType.DATA_NODE
     _repository: _DataFSRepository
 
+    @classmethod
+    def _get_owner_id(
+        cls, scope, cycle_id, scenario_id
+    ) -> Union[Optional[SequenceId], Optional[ScenarioId], Optional[CycleId]]:
+        if scope == Scope.SCENARIO:
+            return scenario_id
+        elif scope == Scope.CYCLE:
+            return cycle_id
+        else:
+            return None
+
     @classmethod
     def _bulk_get_or_create(
         cls,
@@ -48,13 +59,7 @@ def _bulk_get_or_create(
         dn_configs_and_owner_id = []
         for dn_config in data_node_configs:
             scope = dn_config.scope
-            owner_id: Union[Optional[SequenceId], Optional[ScenarioId], Optional[CycleId]]
-            if scope == Scope.SCENARIO:
-                owner_id = scenario_id
-            elif scope == Scope.CYCLE:
-                owner_id = cycle_id
-            else:
-                owner_id = None
+            owner_id = cls._get_owner_id(scope, cycle_id, scenario_id)
             dn_configs_and_owner_id.append((dn_config, owner_id))
 
         data_nodes = cls._repository._get_by_configs_and_owner_ids(
@@ -174,3 +179,39 @@ def _get_by_config_id(cls, config_id: str, version_number: Optional[str] = None)
         for fil in filters:
             fil.update({"config_id": config_id})
         return cls._repository._load_all(filters)
+
+    @classmethod
+    def _duplicate(
+        cls, dn: DataNode, cycle_id: Optional[CycleId] = None, scenario_id: Optional[ScenarioId] = None
+    ) -> DataNode:
+        data_nodes = cls._repository._get_by_configs_and_owner_ids(
+            [(dn.config_id, cls._get_owner_id(dn.scope, cycle_id, scenario_id))], cls._build_filters_with_version(None)
+        )
+
+        if existing_dn := data_nodes.get((dn.config_id, dn.owner_id)):
+            return existing_dn
+        else:
+            duplicated_dn = cls._get(dn)
+
+            duplicated_dn.id = duplicated_dn._new_id(duplicated_dn._config_id)
+            duplicated_dn._owner_id = cls._get_owner_id(duplicated_dn._scope, cycle_id, scenario_id)
+            duplicated_dn._parent_ids = set()
+
+            duplicated_dn._duplicate_data()
+
+            cls._set(duplicated_dn)
+            return duplicated_dn
+
+    @classmethod
+    def _can_duplicate(cls, dn: DataNode) -> ReasonCollection:
+        reason_collector = ReasonCollection()
+
+        if isinstance(dn, DataNode):
+            dn_id = dn.id
+        else:
+            dn_id = dn
+
+        if not cls._repository._exists(dn_id):
+            reason_collector._add_reason(dn_id, EntityDoesNotExist(dn_id))
+
+        return reason_collector
diff --git a/taipy/core/data/_file_datanode_mixin.py b/taipy/core/data/_file_datanode_mixin.py
index ff87146756..8744a3325c 100644
--- a/taipy/core/data/_file_datanode_mixin.py
+++ b/taipy/core/data/_file_datanode_mixin.py
@@ -42,6 +42,7 @@ class _FileDataNodeMixin:
     _PATH_KEY = "path"
     _DEFAULT_PATH_KEY = "default_path"
     _IS_GENERATED_KEY = "is_generated"
+    __TAIPY_DUPLICATED_PREFIX = "TAIPY_DUPLICATED"
 
     __logger = _TaipyLogger._get_logger()
 
@@ -109,12 +110,14 @@ def _get_downloadable_path(self) -> str:
 
         return ""
 
-    def _upload(self,
-                path: str,
-                upload_checker: Optional[Callable[[str, Any], bool]] = None,
-                editor_id: Optional[str] = None,
-                comment: Optional[str] = None,
-                **kwargs: Any) -> ReasonCollection:
+    def _upload(
+        self,
+        path: str,
+        upload_checker: Optional[Callable[[str, Any], bool]] = None,
+        editor_id: Optional[str] = None,
+        comment: Optional[str] = None,
+        **kwargs: Any,
+    ) -> ReasonCollection:
         """Upload a file data to the data node.
 
         Arguments:
@@ -136,11 +139,15 @@ def _upload(self,
         from ._data_manager_factory import _DataManagerFactory
 
         reasons = ReasonCollection()
-        if (editor_id
-            and self.edit_in_progress  # type: ignore[attr-defined]
-            and self.editor_id != editor_id  # type: ignore[attr-defined]
-            and (not self.editor_expiration_date  # type: ignore[attr-defined]
-                 or self.editor_expiration_date > datetime.now())):  # type: ignore[attr-defined]
+        if (
+            editor_id
+            and self.edit_in_progress  # type: ignore[attr-defined]
+            and self.editor_id != editor_id  # type: ignore[attr-defined]
+            and (
+                not self.editor_expiration_date  # type: ignore[attr-defined]
+                or self.editor_expiration_date > datetime.now()
+            )
+        ):  # type: ignore[attr-defined]
             reasons._add_reason(self.id, DataNodeEditInProgress(self.id))  # type: ignore[attr-defined]
             return reasons
 
@@ -148,8 +155,7 @@ def _upload(self,
         try:
             upload_data = self._read_from_path(str(up_path))
         except Exception as err:
-            self.__logger.error(f"Error uploading `{up_path.name}` to data "
-                                f"node `{self.id}`:")  # type: ignore[attr-defined]
+            self.__logger.error(f"Error uploading `{up_path.name}` to data " f"node `{self.id}`:")  # type: ignore[attr-defined]
             self.__logger.error(f"Error: {err}")
             reasons._add_reason(self.id, UploadFileCanNotBeRead(up_path.name, self.id))  # type: ignore[attr-defined]
             return reasons
@@ -161,7 +167,8 @@ def _upload(self,
                 self.__logger.error(
                     f"Error with the upload checker `{upload_checker.__name__}` "
                     f"while checking `{up_path.name}` file for upload to the data "
-                    f"node `{self.id}`:")  # type: ignore[attr-defined]
+                    f"node `{self.id}`:"
+                )  # type: ignore[attr-defined]
                 self.__logger.error(f"Error: {err}")
                 can_upload = False
 
@@ -171,9 +178,12 @@ def _upload(self,
 
         shutil.copy(up_path, self.path)
 
-        self.track_edit(timestamp=datetime.now(),  # type: ignore[attr-defined]
-                        editor_id=editor_id,
-                        comment=comment, **kwargs)
+        self.track_edit(
+            timestamp=datetime.now(),  # type: ignore[attr-defined]
+            editor_id=editor_id,
+            comment=comment,
+            **kwargs,
+        )
         self.unlock_edit()  # type: ignore[attr-defined]
 
         _DataManagerFactory._build_manager()._set(self)  # type: ignore[arg-type]
@@ -212,3 +222,23 @@ def _migrate_path(self, storage_type, old_path) -> str:
         if os.path.exists(old_path):
             shutil.move(old_path, new_path)
         return new_path
+
+    def _duplicate_data_file(self, id: str) -> Optional[str]:
+        if os.path.exists(self.path):
+            folder_path, base_name = os.path.split(self.path)
+
+            if base_name.startswith(self.__TAIPY_DUPLICATED_PREFIX):
+                base_name = "".join(base_name.split("_")[5:])
+            new_base_path = os.path.join(folder_path, f"{self.__TAIPY_DUPLICATED_PREFIX}_{id}_{base_name}")
+
+            if os.path.isdir(self.path):
+                shutil.copytree(self.path, new_base_path)
+            else:
+                shutil.copy(self.path, new_base_path)
+
+            if hasattr(self._properties, "_entity_owner"):  # type: ignore[attr-defined]
+                del self._properties._entity_owner  # type: ignore[attr-defined]
+            self._properties[self._PATH_KEY] = new_base_path  # type: ignore[attr-defined]
+
+            return new_base_path
+        return ""
diff --git a/taipy/core/data/csv.py b/taipy/core/data/csv.py
index 083215bc4e..6c18a83366 100644
--- a/taipy/core/data/csv.py
+++ b/taipy/core/data/csv.py
@@ -192,3 +192,6 @@ def _write(self, data: Any, columns: Optional[List[str]] = None):
             encoding=properties[self.__ENCODING_KEY],
             header=properties[self._HAS_HEADER_PROPERTY],
         )
+
+    def _duplicate_data(self):
+        return self._duplicate_data_file(self.id)
diff --git a/taipy/core/data/data_node.py b/taipy/core/data/data_node.py
index 08e8b2e1da..97a93be2d3 100644
--- a/taipy/core/data/data_node.py
+++ b/taipy/core/data/data_node.py
@@ -433,22 +433,27 @@ def append(self, data, editor_id: Optional[str] = None, comment: Optional[str] =
                 corresponding to this write.
         """
         from ._data_manager_factory import _DataManagerFactory
-        if (editor_id
+
+        if (
+            editor_id
             and self.edit_in_progress
             and self.editor_id != editor_id
-            and (not self.editor_expiration_date or self.editor_expiration_date > datetime.now())):
+            and (not self.editor_expiration_date or self.editor_expiration_date > datetime.now())
+        ):
             raise DataNodeIsBeingEdited(self.id, self.editor_id)
         self._append(data)
         self.track_edit(editor_id=editor_id, comment=comment, **kwargs)
         self.unlock_edit()
         _DataManagerFactory._build_manager()._set(self)
 
-    def write(self,
-              data,
-              job_id: Optional[JobId] = None,
-              editor_id: Optional[str] = None,
-              comment: Optional[str] = None,
-              **kwargs: Any):
+    def write(
+        self,
+        data,
+        job_id: Optional[JobId] = None,
+        editor_id: Optional[str] = None,
+        comment: Optional[str] = None,
+        **kwargs: Any,
+    ):
         """Write some data to this data node. once the data is written,
         the data node is unlocked and the edit is tracked.
 
@@ -461,10 +466,12 @@ def write(self,
             **kwargs (Any): Extra information to attach to the edit document
                 corresponding to this write.
""" - if (editor_id + if ( + editor_id and self.edit_in_progress and self.editor_id != editor_id - and (not self.editor_expiration_date or self.editor_expiration_date > datetime.now())): + and (not self.editor_expiration_date or self.editor_expiration_date > datetime.now()) + ): raise DataNodeIsBeingEdited(self.id, self.editor_id) self._write(data) self.track_edit(job_id=job_id, editor_id=editor_id, comment=comment, **kwargs) @@ -473,12 +480,14 @@ def write(self, _DataManagerFactory._build_manager()._set(self) - def track_edit(self, - job_id: Optional[str] = None, - editor_id: Optional[str] = None, - timestamp: Optional[datetime] = None, - comment: Optional[str] = None, - **options: Any): + def track_edit( + self, + job_id: Optional[str] = None, + editor_id: Optional[str] = None, + timestamp: Optional[datetime] = None, + comment: Optional[str] = None, + **options: Any, + ): """Creates and adds a new entry in the edits attribute without writing the data. Arguments: @@ -627,15 +636,15 @@ def _get_rank(self, scenario_config_id: str) -> int: If the data node config is not part of the scenario config, 0xfffc is returned as an infinite rank. """ if not scenario_config_id: - return 0xfffb + return 0xFFFB dn_config = Config.data_nodes.get(self._config_id, None) if not dn_config: self._logger.error(f"Data node config `{self.config_id}` for data node `{self.id}` is not found.") - return 0xfffd + return 0xFFFD if not dn_config._ranks: self._logger.error(f"Data node config `{self.config_id}` for data node `{self.id}` has no rank.") - return 0xfffe - return dn_config._ranks.get(scenario_config_id, 0xfffc) + return 0xFFFE + return dn_config._ranks.get(scenario_config_id, 0xFFFC) @abstractmethod def _read(self): @@ -676,6 +685,9 @@ def _get_last_modified_datetime(cls, path: Optional[str] = None) -> Optional[dat return last_modified_datetime + def _duplicate_data(self): + raise NotImplementedError + @staticmethod def _class_map(): def all_subclasses(cls): diff --git a/taipy/core/data/excel.py b/taipy/core/data/excel.py index 3e39c1160f..baec87a64b 100644 --- a/taipy/core/data/excel.py +++ b/taipy/core/data/excel.py @@ -339,3 +339,6 @@ def _write(self, data: Any): self._write_excel_with_single_sheet( data.to_excel, self._path, index=False, header=properties[self._HAS_HEADER_PROPERTY] or None ) + + def _duplicate_data(self): + return self._duplicate_data_file(self.id) diff --git a/taipy/core/data/json.py b/taipy/core/data/json.py index c18ab8d7b1..2e14fbdda8 100644 --- a/taipy/core/data/json.py +++ b/taipy/core/data/json.py @@ -158,6 +158,9 @@ def _write(self, data: Any): with open(self._path, "w", encoding=self.properties[self.__ENCODING_KEY]) as f: # type: ignore json.dump(data, f, indent=4, cls=self._encoder) + def _duplicate_data(self): + return self._duplicate_data_file(self.id) + class _DefaultJSONEncoder(json.JSONEncoder): def default(self, o): diff --git a/taipy/core/data/parquet.py b/taipy/core/data/parquet.py index 7c526b35d8..dcba52bf2a 100644 --- a/taipy/core/data/parquet.py +++ b/taipy/core/data/parquet.py @@ -249,3 +249,6 @@ def _append(self, data: Any): def _write(self, data: Any): self._write_with_kwargs(data) + + def _duplicate_data(self): + return self._duplicate_data_file(self.id) diff --git a/taipy/core/data/pickle.py b/taipy/core/data/pickle.py index b86e82d6c7..4b370a6221 100644 --- a/taipy/core/data/pickle.py +++ b/taipy/core/data/pickle.py @@ -108,3 +108,6 @@ def _read_from_path(self, path: Optional[str] = None, **read_kwargs) -> Any: def _write(self, data): with open(self._path, 
"wb") as pf: pickle.dump(data, pf) + + def _duplicate_data(self): + return self._duplicate_data_file(self.id) diff --git a/taipy/core/scenario/_scenario_manager.py b/taipy/core/scenario/_scenario_manager.py index b8d71c25d9..e1603a2fdf 100644 --- a/taipy/core/scenario/_scenario_manager.py +++ b/taipy/core/scenario/_scenario_manager.py @@ -521,3 +521,87 @@ def _get_by_config_id(cls, config_id: str, version_number: Optional[str] = None) for fil in filters: fil.update({"config_id": config_id}) return cls._repository._load_all(filters) + + @classmethod + def _duplicate( + cls, scenario: Scenario, creation_date: Optional[datetime] = None, name: Optional[str] = None + ) -> Scenario: + """ + Duplicate a scenario. + + Arguments: + scenario (Scenario): The scenario to duplicate. + + Returns: + Scenario: The duplicated scenario. + """ + creation_date = creation_date or datetime.now() + duplicated_scenario = cls._get(scenario) + duplicated_scenario.id = duplicated_scenario._new_id(duplicated_scenario.config_id) + + frequency = cls.__get_config(scenario).frequency + cycle = _CycleManagerFactory._build_manager()._get_or_create(frequency, creation_date) if frequency else None + cycle_id = cycle.id if cycle else None + + # Duplicate tasks and data nodes and sequences + _task_manager = _TaskManagerFactory._build_manager() + _data_manager = _DataManagerFactory._build_manager() + + duplicated_tasks = set() + task_ids_to_duplicated_tasks_dict = {} + for task in duplicated_scenario.tasks.values(): + duplicated_task = _task_manager._duplicate(task, cycle_id, duplicated_scenario.id) + duplicated_tasks.add(duplicated_task) + task_ids_to_duplicated_tasks_dict[task.id] = duplicated_task + duplicated_scenario._tasks = duplicated_tasks + + duplicated_sequences = {} + for sequence_name, sequence_values in duplicated_scenario._sequences.items(): + sequence_values[duplicated_scenario._SEQUENCE_TASKS_KEY] = { + task_ids_to_duplicated_tasks_dict[sequence_task_id] + for sequence_task_id in sequence_values[duplicated_scenario._SEQUENCE_TASKS_KEY] + } + duplicated_sequences[sequence_name] = sequence_values + duplicated_scenario._sequences = duplicated_sequences + + # Duplicate additional data nodes + duplicated_additional_data_nodes = set() + for data_node in duplicated_scenario.additional_data_nodes.values(): + duplicated_additional_data_nodes.add(_data_manager._duplicate(data_node, None, duplicated_scenario.id)) + duplicated_scenario._additional_data_nodes = duplicated_additional_data_nodes + + for task in duplicated_tasks: + if duplicated_scenario.id not in task._parent_ids: + task._parent_ids.update([duplicated_scenario.id]) + _task_manager._set(task) + + for dn in duplicated_additional_data_nodes: + if duplicated_scenario.id not in dn._parent_ids: + dn._parent_ids.update([duplicated_scenario.id]) + _data_manager._set(dn) + + if name: + if hasattr(duplicated_scenario._properties, "_entity_owner"): + del duplicated_scenario._properties._entity_owner + duplicated_scenario._properties["name"] = name + duplicated_scenario._cycle = cycle + duplicated_scenario._creation_date = creation_date + duplicated_scenario._primary_scenario = len(cls._get_all_by_cycle(cycle)) == 0 if cycle else False + + cls._set(duplicated_scenario) + + return duplicated_scenario + + @classmethod + def _can_duplicate(cls, scenario: Optional[Scenario]) -> ReasonCollection: + reason_collector = ReasonCollection() + + if isinstance(scenario, Scenario): + scenario_id = scenario.id + else: + scenario_id = str(scenario) # type: ignore + + if not 
+        if not cls._repository._exists(scenario_id):
+            reason_collector._add_reason(scenario_id, EntityDoesNotExist(scenario_id))
+
+        return reason_collector
diff --git a/taipy/core/taipy.py b/taipy/core/taipy.py
index 784cb8f314..e76aa6a098 100644
--- a/taipy/core/taipy.py
+++ b/taipy/core/taipy.py
@@ -1070,3 +1070,35 @@ def get_entities_by_config_id(
     if entities := _DataManagerFactory._build_manager()._get_by_config_id(config_id):
         return entities
     return entities
+
+
+def can_duplicate(entity: Optional[Scenario] = None) -> ReasonCollection:
+    """Indicate if a scenario can be duplicated.
+
+    Returns:
+        True if the given scenario can be duplicated. False otherwise.
+    """
+    return _ScenarioManagerFactory._build_manager()._can_duplicate(entity)
+
+
+def duplicate_scenario(
+    scenario: Scenario, creation_date: Optional[datetime] = None, name: Optional[str] = None
+) -> Scenario:
+    """Duplicate an existing scenario and return a new scenario.
+
+    This function duplicates the provided scenario, optionally setting a new creation date and name.
+
+    If the scenario belongs to a cycle, the cycle (corresponding to the creation_date and the configuration
+    frequency attribute) is created if it does not exist yet.
+
+    Arguments:
+        scenario (Scenario): The scenario to duplicate.
+        creation_date (Optional[datetime.datetime]): The creation date of the new scenario.
+            If None, the current date and time is used.
+        name (Optional[str]): The displayable name of the new scenario.
+
+    Returns:
+        Scenario: The newly duplicated scenario.
+    """
+
+    return _ScenarioManagerFactory._build_manager()._duplicate(scenario, creation_date, name)
diff --git a/taipy/core/task/_task_manager.py b/taipy/core/task/_task_manager.py
index 4336b8069f..e79634a649 100644
--- a/taipy/core/task/_task_manager.py
+++ b/taipy/core/task/_task_manager.py
@@ -57,6 +57,17 @@ def _set(cls, task: Task) -> None:
             cls.__save_data_nodes(task.output.values())
         super()._set(task)
 
+    @classmethod
+    def _get_owner_id(
+        cls, scope, cycle_id, scenario_id
+    ) -> Union[Optional[SequenceId], Optional[ScenarioId], Optional[CycleId]]:
+        if scope == Scope.SCENARIO:
+            return scenario_id
+        elif scope == Scope.CYCLE:
+            return cycle_id
+        else:
+            return None
+
     @classmethod
     def _bulk_get_or_create(
         cls,
@@ -79,13 +90,7 @@ def _bulk_get_or_create(
             ]
             task_config_data_nodes = [data_nodes[dn_config] for dn_config in task_dn_configs]
             scope = min(dn.scope for dn in task_config_data_nodes) if len(task_config_data_nodes) != 0 else Scope.GLOBAL
-            owner_id: Union[Optional[SequenceId], Optional[ScenarioId], Optional[CycleId]]
-            if scope == Scope.SCENARIO:
-                owner_id = scenario_id
-            elif scope == Scope.CYCLE:
-                owner_id = cycle_id
-            else:
-                owner_id = None
+            owner_id = cls._get_owner_id(scope, cycle_id, scenario_id)
 
             tasks_configs_and_owner_id.append((task_config, owner_id))
 
@@ -226,3 +231,52 @@ def _get_by_config_id(cls, config_id: str, version_number: Optional[str] = None)
         for fil in filters:
             fil.update({"config_id": config_id})
         return cls._repository._load_all(filters)
+
+    @classmethod
+    def _duplicate(
+        cls, task: Task, cycle_id: Optional[CycleId] = None, scenario_id: Optional[ScenarioId] = None
+    ) -> Task:
+        data_manager = _DataManagerFactory._build_manager()
+
+        duplicated_task = cls._get(task)
+
+        inputs = [data_manager._duplicate(i, cycle_id, scenario_id) for i in duplicated_task.input.values()]
+        outputs = [data_manager._duplicate(o, cycle_id, scenario_id) for o in duplicated_task.output.values()]
+
+        scope = min(dn.scope for dn in (inputs + outputs)) if (len(inputs) + len(outputs)) != 0 else Scope.GLOBAL
+        owner_id = cls._get_owner_id(scope, cycle_id, scenario_id)
+
+        tasks_by_config = cls._repository._get_by_configs_and_owner_ids(  # type: ignore
+            [(task.config_id, owner_id)], cls._build_filters_with_version(None)
+        )
+
+        if existing_task := tasks_by_config.get((task.config_id, owner_id)):
+            return existing_task
+
+        duplicated_task.id = duplicated_task._new_id(duplicated_task.config_id)
+        duplicated_task._parent_ids = set()
+        duplicated_task._owner_id = owner_id
+
+        duplicated_task._input = {i.config_id: i for i in inputs}
+        duplicated_task._output = {o.config_id: o for o in outputs}
+
+        for dn in set(inputs + outputs):
+            dn._parent_ids.update([duplicated_task.id])
+            data_manager._set(dn)
+
+        cls._set(duplicated_task)
+        return duplicated_task
+
+    @classmethod
+    def _can_duplicate(cls, task: Task) -> ReasonCollection:
+        reason_collector = ReasonCollection()
+
+        if isinstance(task, Task):
+            task_id = task.id
+        else:
+            task_id = task
+
+        if not cls._repository._exists(task_id):
+            reason_collector._add_reason(task_id, EntityDoesNotExist(task_id))
+
+        return reason_collector
diff --git a/taipy/core/task/task.py b/taipy/core/task/task.py
index ecedf8ae4b..0bb04742f9 100644
--- a/taipy/core/task/task.py
+++ b/taipy/core/task/task.py
@@ -116,7 +116,7 @@ def __init__(
         skippable: bool = False,
     ) -> None:
         self._config_id = _validate_id(config_id)
-        self.id = id or TaskId(self.__ID_SEPARATOR.join([self._ID_PREFIX, self.config_id, str(uuid.uuid4())]))
+        self.id = id or self._new_id(config_id)
         self._owner_id = owner_id
         self._parent_ids = parent_ids or set()
         self._input = {dn.config_id: dn for dn in input or []}
@@ -127,6 +127,11 @@ def __init__(
         self._properties = _Properties(self, **properties)
         self._init_done = True
 
+    @staticmethod
+    def _new_id(config_id: str) -> TaskId:
+        """Generate a unique task identifier."""
+        return TaskId(Task.__ID_SEPARATOR.join([Task._ID_PREFIX, config_id, str(uuid.uuid4())]))
+
     def __hash__(self) -> int:
         return hash(self.id)
 
diff --git a/tests/core/data/test_csv_data_node.py b/tests/core/data/test_csv_data_node.py
index dcd5f56cc1..03b89c3280 100644
--- a/tests/core/data/test_csv_data_node.py
+++ b/tests/core/data/test_csv_data_node.py
@@ -10,6 +10,7 @@
 # specific language governing permissions and limitations under the License.
 
 import dataclasses
+import filecmp
 import os
 import pathlib
 import re
@@ -429,3 +430,29 @@ def check_data_is_positive(upload_path, upload_data):
 
         # The upload should succeed when check_data_is_positive() return True
         assert dn._upload(new_csv_path, upload_checker=check_data_is_positive)
+
+    def test_duplicate_data_file(self):
+        path = os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample/example.csv")
+        dn = CSVDataNode("foo", Scope.SCENARIO, properties={"path": path, "exposed_type": "pandas"})
+        _DataManager._set(dn)
+
+        read_data = dn.read()
+        assert read_data is not None
+
+        old_path = dn.path
+        new_file_path = str(dn._duplicate_data())
+        assert filecmp.cmp(path, new_file_path)
+
+        old_dn_id = dn.id
+        old_dn = _DataManager._get(old_dn_id)
+        assert old_dn.path == old_path
+
+        dn.id = dn._new_id("foo")
+        dn.path = new_file_path
+        new_file_path_2 = str(dn._duplicate_data())
+        assert len(new_file_path_2.split("TAIPY_DUPLICATED")) == 2
+        os.unlink(new_file_path)
+        os.unlink(new_file_path_2)
+
+        old_dn = _DataManager._get(old_dn_id)
+        assert old_dn.path == old_path
diff --git a/tests/core/data/test_data_manager.py b/tests/core/data/test_data_manager.py
index 7316f6d498..ef081760a0 100644
--- a/tests/core/data/test_data_manager.py
+++ b/tests/core/data/test_data_manager.py
@@ -24,7 +24,7 @@
 from taipy.core.data.in_memory import InMemoryDataNode
 from taipy.core.data.pickle import PickleDataNode
 from taipy.core.exceptions.exceptions import InvalidDataNodeType, ModelNotFound
-from taipy.core.reason import NotGlobalScope, WrongConfigType
+from taipy.core.reason import EntityDoesNotExist, NotGlobalScope, WrongConfigType
 from tests.core.utils.named_temporary_file import NamedTemporaryFile
 
@@ -731,3 +731,48 @@ def test_get_data_nodes_by_config_id_in_multiple_versions_environment(self):
 
         assert len(_DataManager._get_by_config_id(dn_config_1.id)) == 3
         assert len(_DataManager._get_by_config_id(dn_config_2.id)) == 2
+
+    def test_duplicate_data_node_with_different_owner_id(self):
+        csv_path_inp = os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample/example.csv")
+        dn_config = Config.configure_csv_data_node("dn_csv_in_1", default_path=csv_path_inp)
+        dn = _DataManager._create_and_set(dn_config, None, None)
+
+        assert len(_DataManager._get_all()) == 1
+
+        new_dn = _DataManager._duplicate(dn, scenario_id="new_scenario_owner_id")
+
+        assert dn.id != new_dn.id
+        assert len(_DataManager._get_all()) == 2
+        assert dn.properties["path"] != new_dn.properties["path"]
+        assert os.path.exists(str(new_dn.properties["path"]))
+        os.remove(str(new_dn.properties["path"]))
+
+    def test_duplicate_data_node_with_same_owner_id(self):
+        csv_path_inp = os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample/example.csv")
+        dn_config = Config.configure_csv_data_node("dn_csv_in_1", default_path=csv_path_inp)
+        dn = _DataManager._create_and_set(dn_config, None, None)
+
+        old_dn_id = dn.id
+
+        assert len(_DataManager._get_all()) == 1
+
+        new_dn = _DataManager._duplicate(dn)
+        old_dn = _DataManager._get(old_dn_id)
+
+        assert old_dn.id == new_dn.id
+        assert len(_DataManager._get_all()) == 1
+
+    def test_duplicate_data_node(self):
+        dn_config = Config.configure_pickle_data_node("dn", scope=Scope.SCENARIO)
+        data = _DataManager._create_and_set(dn_config, None, None)
+
+        reasons = _DataManager._can_duplicate(data)
+        assert bool(reasons)
+        assert reasons._reasons == {}
+
+        reasons = _DataManager._can_duplicate("1")
+        assert not bool(reasons)
+        assert reasons._reasons["1"] == {EntityDoesNotExist(1)}
+        assert str(list(reasons._reasons["1"])[0]) == "Entity 1 does not exist in the repository"
+        with pytest.raises(AttributeError):
+            _DataManager._duplicate("1")
diff --git a/tests/core/data/test_excel_data_node.py b/tests/core/data/test_excel_data_node.py
index 0a262a8e90..fe2a486773 100644
--- a/tests/core/data/test_excel_data_node.py
+++ b/tests/core/data/test_excel_data_node.py
@@ -9,6 +9,7 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
 
+import filecmp
 import os
 import pathlib
 import re
@@ -652,3 +653,29 @@ def check_data_is_positive(upload_path, upload_data):
 
         # The upload should succeed when check_data_is_positive() return True
         assert dn._upload(new_excel_path, upload_checker=check_data_is_positive)
+
+    def test_duplicate_data_file(self):
+        path = os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample/example.xlsx")
+        dn = ExcelDataNode("foo", Scope.SCENARIO, properties={"default_path": path})
+        _DataManager._set(dn)
+
+        read_data = dn.read()
+        assert read_data is not None
+        old_path = dn.path
+
+        new_file_path = str(dn._duplicate_data())
+        assert filecmp.cmp(path, new_file_path)
+
+        old_dn_id = dn.id
+        old_dn = _DataManager._get(old_dn_id)
+        assert old_dn.path == old_path
+
+        dn.id = dn._new_id("foo")
+        dn.path = new_file_path
+        new_file_path_2 = str(dn._duplicate_data())
+        assert len(new_file_path_2.split("TAIPY_DUPLICATED")) == 2
+        os.unlink(new_file_path)
+        os.unlink(new_file_path_2)
+
+        old_dn = _DataManager._get(old_dn_id)
+        assert old_dn.path == old_path
diff --git a/tests/core/data/test_json_data_node.py b/tests/core/data/test_json_data_node.py
index 05b2b76b02..cfad9b896b 100644
--- a/tests/core/data/test_json_data_node.py
+++ b/tests/core/data/test_json_data_node.py
@@ -10,6 +10,7 @@
 # specific language governing permissions and limitations under the License.
 
 import datetime
+import filecmp
 import json
 import os
 import pathlib
@@ -492,3 +493,29 @@ def check_data_keys(upload_path, upload_data):
 
         # The upload should succeed when check_data_keys() return True
         assert dn._upload(json_file, upload_checker=check_data_keys)
+
+    def test_duplicate_data_file(self):
+        path = os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample/json/example_dict.json")
+        dn = JSONDataNode("foo", Scope.SCENARIO, properties={"path": path})
+        _DataManager._set(dn)
+
+        read_data = dn.read()
+        assert read_data is not None
+        old_path = dn.path
+
+        new_file_path = str(dn._duplicate_data())
+        assert filecmp.cmp(path, new_file_path)
+
+        old_dn_id = dn.id
+        old_dn = _DataManager._get(old_dn_id)
+        assert old_dn.path == old_path
+
+        dn.id = dn._new_id("foo")
+        dn.path = new_file_path
+        new_file_path_2 = str(dn._duplicate_data())
+        assert len(new_file_path_2.split("TAIPY_DUPLICATED")) == 2
+        os.unlink(new_file_path)
+        os.unlink(new_file_path_2)
+
+        old_dn = _DataManager._get(old_dn_id)
+        assert old_dn.path == old_path
diff --git a/tests/core/data/test_parquet_data_node.py b/tests/core/data/test_parquet_data_node.py
index 1fc224dfa1..b6f76208f0 100644
--- a/tests/core/data/test_parquet_data_node.py
+++ b/tests/core/data/test_parquet_data_node.py
@@ -9,9 +9,11 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
 
+import filecmp
 import os
 import pathlib
 import re
+import shutil
 import uuid
 from datetime import datetime, timedelta
 from importlib import util
@@ -402,3 +404,29 @@ def check_data_is_positive(upload_path, upload_data):
 
         # The upload should succeed when check_data_is_positive() return True
         assert dn._upload(new_parquet_path, upload_checker=check_data_is_positive)
+
+    def test_duplicate_data_file(self):
+        path = os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample/parquet_example")
+        dn = ParquetDataNode("foo", Scope.SCENARIO, properties={"path": path})
+        _DataManager._set(dn)
+
+        read_data = dn.read()
+        assert read_data is not None
+
+        old_path = dn.path
+        new_file_path = str(dn._duplicate_data())
+        assert filecmp.dircmp(path, new_file_path)
+
+        old_dn_id = dn.id
+        old_dn = _DataManager._get(old_dn_id)
+        assert old_dn.path == old_path
+
+        dn.id = dn._new_id("foo")
+        dn.path = new_file_path
+        new_file_path_2 = str(dn._duplicate_data())
+        assert len(new_file_path_2.split("TAIPY_DUPLICATED")) == 2
+        shutil.rmtree(new_file_path)
+        shutil.rmtree(new_file_path_2)
+
+        old_dn = _DataManager._get(old_dn_id)
+        assert old_dn.path == old_path
diff --git a/tests/core/data/test_pickle_data_node.py b/tests/core/data/test_pickle_data_node.py
index 05deccf0cf..e592681789 100644
--- a/tests/core/data/test_pickle_data_node.py
+++ b/tests/core/data/test_pickle_data_node.py
@@ -9,6 +9,7 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
 
+import filecmp
 import os
 import pathlib
 import pickle
@@ -305,3 +306,29 @@ def check_data_column(upload_path, upload_data):
 
         # The upload should succeed when check_data_column() return True
         assert dn._upload(pickle_file_path, upload_checker=check_data_column)
+
+    def test_duplicate_data_file(self):
+        path = os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample/example.p")
+        dn = PickleDataNode("foo", Scope.SCENARIO, properties={"default_path": path})
+        _DataManager._set(dn)
+
+        read_data = dn.read()
+        assert read_data is not None
+
+        old_path = dn.path
+        new_file_path = str(dn._duplicate_data())
+        assert filecmp.cmp(path, new_file_path)
+
+        old_dn_id = dn.id
+        old_dn = _DataManager._get(old_dn_id)
+        assert old_dn.path == old_path
+
+        dn.id = dn._new_id("foo")
+        dn.path = new_file_path
+        new_file_path_2 = str(dn._duplicate_data())
+        assert len(new_file_path_2.split("TAIPY_DUPLICATED")) == 2
+        os.unlink(new_file_path)
+        os.unlink(new_file_path_2)
+
+        old_dn = _DataManager._get(old_dn_id)
+        assert old_dn.path == old_path
diff --git a/tests/core/scenario/test_scenario_manager.py b/tests/core/scenario/test_scenario_manager.py
index 6441a32e31..1bffd21880 100644
--- a/tests/core/scenario/test_scenario_manager.py
+++ b/tests/core/scenario/test_scenario_manager.py
@@ -41,7 +41,7 @@
     UnauthorizedTagError,
 )
 from taipy.core.job._job_manager import _JobManager
-from taipy.core.reason import WrongConfigType
+from taipy.core.reason import EntityDoesNotExist, WrongConfigType
 from taipy.core.scenario._scenario_manager import _ScenarioManager
 from taipy.core.scenario._scenario_manager_factory import _ScenarioManagerFactory
 from taipy.core.scenario.scenario import Scenario
@@ -1553,3 +1553,279 @@ def test_filter_scenarios_by_creation_datetime():
     )
     assert len(filtered_scenarios) == 1
     assert [s_1_1] == filtered_scenarios
+
+
+def test_can_duplicate_scenario():
+    dn_config = Config.configure_pickle_data_node("dn", scope=Scope.SCENARIO)
+    task_config = Config.configure_task("task_1", print, [dn_config])
+    scenario_config = Config.configure_scenario("scenario_1", [task_config])
+    scenario = _ScenarioManager._create(scenario_config)
+
+    reasons = _ScenarioManager._can_duplicate(scenario)
+    assert bool(reasons)
+    assert reasons._reasons == {}
+
+    reasons = _ScenarioManager._can_duplicate("1")
+    assert not bool(reasons)
+    assert reasons._reasons["1"] == {EntityDoesNotExist(1)}
+    assert str(list(reasons._reasons["1"])[0]) == "Entity 1 does not exist in the repository"
+    with pytest.raises(AttributeError):
+        _ScenarioManager._duplicate("1")
+
+
+def test_duplicate_scenario():
+    dn_config_1 = Config.configure_pickle_data_node("dn_1", scope=Scope.SCENARIO)
+    dn_config_2 = Config.configure_pickle_data_node("dn_2", scope=Scope.SCENARIO)
+    additional_dn_config_1 = Config.configure_data_node("additional_dn_1", scope=Scope.SCENARIO)
+    task_config_1 = Config.configure_task("task_1", print, [dn_config_1], [dn_config_2])
+    scenario_config_1 = Config.configure_scenario("scenario_1", [task_config_1], [additional_dn_config_1])
+    scenario = _ScenarioManager._create(scenario_config_1)
+
+    assert len(_ScenarioManager._get_all()) == 1
+    assert len(_DataManager._get_all()) == 3
+    assert len(_TaskManager._get_all()) == 1
+
+    new_scenario = _ScenarioManager._duplicate(scenario, name="New Scenario")
+
+    assert scenario.id != new_scenario.id
+    assert new_scenario.name == "New Scenario"
+    assert len(_ScenarioManager._get_all()) == 2
+    assert len(_DataManager._get_all()) == 6
+    assert len(_TaskManager._get_all()) == 2
+
+    assert all(scenario.id in t.parent_ids for t in scenario.tasks.values())
+    assert all(scenario.id == t.owner_id for t in scenario.tasks.values())
+    assert all(scenario.id in dn.parent_ids for dn in scenario.additional_data_nodes.values())
+    assert all(scenario.id == dn.owner_id for dn in scenario.data_nodes.values())
+
+    assert all(new_scenario.id in t.parent_ids for t in new_scenario.tasks.values())
+    assert all(new_scenario.id == t.owner_id for t in new_scenario.tasks.values())
+    assert all(new_scenario.id in dn.parent_ids for dn in new_scenario.additional_data_nodes.values())
+    assert all(new_scenario.id == dn.owner_id for dn in new_scenario.data_nodes.values())
+
+
+def test_duplicate_scenario_with_single_GLOBAL_dn_scope():
+    dn_config_1 = Config.configure_pickle_data_node("dn_1", scope=Scope.SCENARIO)
+    dn_config_2 = Config.configure_pickle_data_node("dn_2", scope=Scope.GLOBAL)
+    additional_dn_config_1 = Config.configure_data_node("additional_dn_1", scope=Scope.SCENARIO)
+    task_config_1 = Config.configure_task("task_1", print, [dn_config_1], [dn_config_2])
+    scenario_config_1 = Config.configure_scenario("scenario_1", [task_config_1], [additional_dn_config_1])
+    scenario = _ScenarioManager._create(scenario_config_1)
+
+    assert len(_ScenarioManager._get_all()) == 1
+    assert len(_DataManager._get_all()) == 3
+    assert len(_TaskManager._get_all()) == 1
+
+    new_scenario = _ScenarioManager._duplicate(scenario)
+
+    assert scenario.id != new_scenario.id
+    assert len(_ScenarioManager._get_all()) == 2
+    assert len(_DataManager._get_all()) == 5
+    assert len(_TaskManager._get_all()) == 2
+
+    assert all(scenario.id in t.parent_ids for t in scenario.tasks.values())
+    assert all(scenario.id == t.owner_id for t in scenario.tasks.values())
+    assert all(scenario.id in dn.parent_ids for dn in scenario.additional_data_nodes.values())
+    assert all((scenario.id == dn.owner_id or dn.owner_id is None) for dn in scenario.data_nodes.values())
+
+    assert all(new_scenario.id in t.parent_ids for t in new_scenario.tasks.values())
+    assert all(new_scenario.id == t.owner_id for t in new_scenario.tasks.values())
+    assert all(new_scenario.id in dn.parent_ids for dn in new_scenario.additional_data_nodes.values())
+    assert all((new_scenario.id == dn.owner_id or dn.owner_id is None) for dn in new_scenario.data_nodes.values())
+
+
+def test_duplicate_scenario_with_all_GLOBAL_dn_scope():
+    dn_config_1 = Config.configure_pickle_data_node("dn_1", scope=Scope.GLOBAL)
+    dn_config_2 = Config.configure_pickle_data_node("dn_2", scope=Scope.GLOBAL)
+    additional_dn_config_1 = Config.configure_data_node("additional_dn_1", scope=Scope.SCENARIO)
+    task_config_1 = Config.configure_task("task_1", print, [dn_config_1], [dn_config_2])
+    scenario_config_1 = Config.configure_scenario("scenario_1", [task_config_1], [additional_dn_config_1])
+    scenario = _ScenarioManager._create(scenario_config_1)
+
+    assert len(_ScenarioManager._get_all()) == 1
+    assert len(_DataManager._get_all()) == 3
+    assert len(_TaskManager._get_all()) == 1
+
+    new_scenario = _ScenarioManager._duplicate(scenario)
+
+    assert scenario.id != new_scenario.id
+    assert len(_ScenarioManager._get_all()) == 2
+    assert len(_DataManager._get_all()) == 4
+    assert len(_TaskManager._get_all()) == 1
+
+    assert all(scenario.id in t.parent_ids for t in scenario.tasks.values())
+    assert all(t.owner_id is None for t in scenario.tasks.values())
+    assert all(scenario.id in dn.parent_ids for dn in scenario.additional_data_nodes.values())
+    assert all((scenario.id == dn.owner_id or dn.owner_id is None) for dn in scenario.data_nodes.values())
+
+    assert all(new_scenario.id in t.parent_ids for t in new_scenario.tasks.values())
+    assert all(t.owner_id is None for t in new_scenario.tasks.values())
+    assert all(new_scenario.id in dn.parent_ids for dn in new_scenario.additional_data_nodes.values())
+    assert all((new_scenario.id == dn.owner_id or dn.owner_id is None) for dn in new_scenario.data_nodes.values())
+
+
+def test_duplicate_scenario_with_single_CYCLE_dn_scope():
+    dn_config_1 = Config.configure_pickle_data_node("dn_1", scope=Scope.SCENARIO)
+    dn_config_2 = Config.configure_pickle_data_node("dn_2", scope=Scope.CYCLE)
+    additional_dn_config_1 = Config.configure_data_node("additional_dn_1", scope=Scope.SCENARIO)
+    task_config_1 = Config.configure_task("task_1", print, [dn_config_1], [dn_config_2])
+    scenario_config_1 = Config.configure_scenario("scenario_1", [task_config_1], [additional_dn_config_1])
+    scenario = _ScenarioManager._create(scenario_config_1)
+
+    assert len(_ScenarioManager._get_all()) == 1
+    assert len(_DataManager._get_all()) == 3
+    assert len(_TaskManager._get_all()) == 1
+
+    new_scenario = _ScenarioManager._duplicate(scenario)
+
+    assert scenario.id != new_scenario.id
+    assert len(_ScenarioManager._get_all()) == 2
+    assert len(_DataManager._get_all()) == 5
+    assert len(_TaskManager._get_all()) == 2
+
+    assert all(scenario.id in t.parent_ids for t in scenario.tasks.values())
+    assert all(scenario.id == t.owner_id for t in scenario.tasks.values())
+    assert all(scenario.id in dn.parent_ids for dn in scenario.additional_data_nodes.values())
+    assert all((scenario.id == dn.owner_id or dn.owner_id is None) for dn in scenario.data_nodes.values())
+
+    assert all(new_scenario.id in t.parent_ids for t in new_scenario.tasks.values())
+    assert all(new_scenario.id == t.owner_id for t in new_scenario.tasks.values())
+    assert all(new_scenario.id in dn.parent_ids for dn in new_scenario.additional_data_nodes.values())
+    assert all((new_scenario.id == dn.owner_id or dn.owner_id is None) for dn in new_scenario.data_nodes.values())
+
+
+def test_duplicate_scenario_with_all_CYCLE_dn_scope():
+    dn_config_1 = Config.configure_pickle_data_node("dn_1", scope=Scope.CYCLE)
+    dn_config_2 = Config.configure_pickle_data_node("dn_2", scope=Scope.CYCLE)
+    additional_dn_config_1 = Config.configure_data_node("additional_dn_1", scope=Scope.SCENARIO)
+    task_config_1 = Config.configure_task("task_1", print, [dn_config_1], [dn_config_2])
+    scenario_config_1 = Config.configure_scenario("scenario_1", [task_config_1], [additional_dn_config_1])
+    scenario = _ScenarioManager._create(scenario_config_1)
+
+    assert len(_ScenarioManager._get_all()) == 1
+    assert len(_DataManager._get_all()) == 3
+    assert len(_TaskManager._get_all()) == 1
+
+    new_scenario = _ScenarioManager._duplicate(scenario)
+
+    assert scenario.id != new_scenario.id
+    assert len(_ScenarioManager._get_all()) == 2
+    assert len(_DataManager._get_all()) == 4
+    assert len(_TaskManager._get_all()) == 1
+
+    assert all(scenario.id in t.parent_ids for t in scenario.tasks.values())
+    assert all(t.owner_id is None for t in scenario.tasks.values())
+    assert all(scenario.id in dn.parent_ids for dn in scenario.additional_data_nodes.values())
+    assert all((scenario.id == dn.owner_id or dn.owner_id is None) for dn in scenario.data_nodes.values())
+
+    assert all(new_scenario.id in t.parent_ids for t in new_scenario.tasks.values())
+    assert all(t.owner_id is None for t in new_scenario.tasks.values())
+    assert all(new_scenario.id in dn.parent_ids for dn in new_scenario.additional_data_nodes.values())
+    assert all((new_scenario.id == dn.owner_id or dn.owner_id is None) for dn in new_scenario.data_nodes.values())
+
+
+def test_duplicate_scenario_with_same_cycle():
+    dn_config_1 = Config.configure_pickle_data_node("dn_1", scope=Scope.SCENARIO)
+    dn_config_2 = Config.configure_pickle_data_node("dn_2", scope=Scope.SCENARIO)
+    additional_dn_config_1 = Config.configure_data_node("additional_dn_1", scope=Scope.SCENARIO)
+    task_config_1 = Config.configure_task("task_1", print, [dn_config_1], [dn_config_2])
+    scenario_config_1 = Config.configure_scenario(
+        "scenario_1", [task_config_1], [additional_dn_config_1], frequency=Frequency.YEARLY
+    )
+    scenario = _ScenarioManager._create(scenario_config_1)
+
+    assert len(_CycleManager._get_all()) == 1
+    assert len(_ScenarioManager._get_all()) == 1
+    assert len(_DataManager._get_all()) == 3
+    assert len(_TaskManager._get_all()) == 1
+
+    new_scenario = _ScenarioManager._duplicate(scenario)
+
+    assert scenario.id != new_scenario.id
+    assert len(_CycleManager._get_all()) == 1
+    assert len(_ScenarioManager._get_all()) == 2
+    assert len(_DataManager._get_all()) == 6
+    assert len(_TaskManager._get_all()) == 2
+
+    assert all(scenario.id in t.parent_ids for t in scenario.tasks.values())
+    assert all(scenario.id == t.owner_id for t in scenario.tasks.values())
+    assert all(scenario.id in dn.parent_ids for dn in scenario.additional_data_nodes.values())
+    assert all(scenario.id == dn.owner_id for dn in scenario.data_nodes.values())
+
+    assert all(new_scenario.id in t.parent_ids for t in new_scenario.tasks.values())
+    assert all(new_scenario.id == t.owner_id for t in new_scenario.tasks.values())
+    assert all(new_scenario.id in dn.parent_ids for dn in new_scenario.additional_data_nodes.values())
+    assert all(new_scenario.id == dn.owner_id for dn in new_scenario.data_nodes.values())
+
+    assert new_scenario.cycle == scenario.cycle
+
+
+def test_duplicate_scenario_with_separate_cycle():
+    dn_config_1 = Config.configure_pickle_data_node("dn_1", scope=Scope.SCENARIO)
+    dn_config_2 = Config.configure_pickle_data_node("dn_2", scope=Scope.SCENARIO)
+    additional_dn_config_1 = Config.configure_data_node("additional_dn_1", scope=Scope.SCENARIO)
+    task_config_1 = Config.configure_task("task_1", print, [dn_config_1], [dn_config_2])
+    scenario_config_1 = Config.configure_scenario(
+        "scenario_1", [task_config_1], [additional_dn_config_1], frequency=Frequency.DAILY
+    )
+    scenario = _ScenarioManager._create(scenario_config_1, datetime.now() - timedelta(days=1))
+
+    assert len(_CycleManager._get_all()) == 1
+    assert len(_ScenarioManager._get_all()) == 1
+    assert len(_DataManager._get_all()) == 3
+    assert len(_TaskManager._get_all()) == 1
+
+    new_scenario = _ScenarioManager._duplicate(scenario, datetime.now() + timedelta(days=1))
+
+    assert scenario.id != new_scenario.id
+    assert len(_CycleManager._get_all()) == 2
+    assert len(_ScenarioManager._get_all()) == 2
+    assert len(_DataManager._get_all()) == 6
+    assert len(_TaskManager._get_all()) == 2
+
+    assert all(scenario.id in t.parent_ids for t in scenario.tasks.values())
+    assert all(scenario.id == t.owner_id for t in scenario.tasks.values())
+    assert all(scenario.id in dn.parent_ids for dn in scenario.additional_data_nodes.values())
+    assert all(scenario.id == dn.owner_id for dn in scenario.data_nodes.values())
+
+    assert all(new_scenario.id in t.parent_ids for t in new_scenario.tasks.values())
+    assert all(new_scenario.id == t.owner_id for t in new_scenario.tasks.values())
+    assert all(new_scenario.id in dn.parent_ids for dn in new_scenario.additional_data_nodes.values())
+    assert all(new_scenario.id == dn.owner_id for dn in new_scenario.data_nodes.values())
+
+
+def test_duplicate_scenario_with_sequences():
+    dn_config_1 = Config.configure_pickle_data_node("dn_1", scope=Scope.SCENARIO)
+    dn_config_2 = Config.configure_pickle_data_node("dn_2", scope=Scope.SCENARIO)
+    dn_config_3 = Config.configure_pickle_data_node("dn_3", scope=Scope.SCENARIO)
+    task_config_1 = Config.configure_task("task_1", print, [dn_config_1], [dn_config_2])
+    task_config_2 = Config.configure_task("task_2", print, [dn_config_2], [dn_config_3])
+
+    scenario_config_1 = Config.configure_scenario("scenario_1", [task_config_1, task_config_2])
+    scenario = _ScenarioManager._create(scenario_config_1)
+
+    tasks_dict = scenario.tasks
+    scenario.add_sequence("seq_1", [tasks_dict["task_1"], tasks_dict["task_2"]], {"some_properties_1": "some_values_1"})
+    scenario.add_sequence("seq_2", [tasks_dict["task_2"]], {"some_properties_2": "some_values_2"})
+
+    assert len(_ScenarioManager._get_all()) == 1
+    assert len(_DataManager._get_all()) == 3
+    assert len(_TaskManager._get_all()) == 2
+
+    duplicated_scenario = _ScenarioManager._duplicate(scenario, name="New Scenario")
+
+    assert scenario.id != duplicated_scenario.id
+    assert duplicated_scenario.name == "New Scenario"
+    assert len(_ScenarioManager._get_all()) == 2
+    assert len(_DataManager._get_all()) == 7
+    assert len(_TaskManager._get_all()) == 4
+
+    duplicated_tasks = duplicated_scenario.tasks
+    tasks_duplicated_tasks_ids = {
+        task.id: duplicated_tasks[task_config].id for task_config, task in scenario.tasks.items()
+    }
+
+    assert duplicated_scenario.sequences["seq_1"].id != scenario.sequences["seq_1"].id
+    assert {task.id for task in duplicated_scenario.sequences["seq_1"].tasks.values()} == {
+        tasks_duplicated_tasks_ids[task.id] for task in scenario.sequences["seq_1"].tasks.values()
scenario.sequences["seq_1"].tasks.values() + } diff --git a/tests/core/task/test_task_manager.py b/tests/core/task/test_task_manager.py index 55d98bd875..4aff99ea30 100644 --- a/tests/core/task/test_task_manager.py +++ b/tests/core/task/test_task_manager.py @@ -22,6 +22,7 @@ from taipy.core.data._data_manager import _DataManager from taipy.core.data.in_memory import InMemoryDataNode from taipy.core.exceptions.exceptions import ModelNotFound, NonExistingTask +from taipy.core.reason import EntityDoesNotExist from taipy.core.task._task_manager import _TaskManager from taipy.core.task._task_manager_factory import _TaskManagerFactory from taipy.core.task.task import Task @@ -483,3 +484,68 @@ def test_get_scenarios_by_config_id_in_multiple_versions_environment(): def _create_task_from_config(task_config, *args, **kwargs): return _TaskManager._bulk_get_or_create([task_config], *args, **kwargs)[0] + + +def test_duplicate_task_wit_different_owner_id(): + dn_input_config_1 = Config.configure_pickle_data_node("my_input_1", scope=Scope.SCENARIO, default_data="testing") + dn_output_config_1 = Config.configure_pickle_data_node("my_output_1", scope=Scope.SCENARIO) + task_config_1 = Config.configure_task("task_config_1", print, dn_input_config_1, dn_output_config_1) + task = _create_task_from_config(task_config_1) + + task_id = task.id + + assert len(_TaskManager._get_all()) == 1 + assert len(_DataManager._get_all()) == 2 + + new_task = _TaskManager._duplicate(task, scenario_id="scenario_id") + + assert task.id != new_task.id + assert len(_TaskManager._get_all()) == 2 + assert len(_DataManager._get_all()) == 4 + + assert all(task_id in dn.parent_ids for dn in task.data_nodes.values()) + assert all(dn.owner_id is None for dn in task.data_nodes.values()) + + assert all(new_task.id in dn.parent_ids for dn in new_task.data_nodes.values()) + assert all(dn.owner_id == "scenario_id" for dn in new_task.data_nodes.values()) + + +def test_duplicate_task_wit_same_owner_id(): + dn_input_config_1 = Config.configure_pickle_data_node("my_input_1", scope=Scope.SCENARIO, default_data="testing") + dn_output_config_1 = Config.configure_pickle_data_node("my_output_1", scope=Scope.SCENARIO) + task_config_1 = Config.configure_task("task_config_1", print, dn_input_config_1, dn_output_config_1) + task = _create_task_from_config(task_config_1) + + task_id = task.id + + assert len(_TaskManager._get_all()) == 1 + assert len(_DataManager._get_all()) == 2 + + new_task = _TaskManager._duplicate(task) + + assert task.id == new_task.id + assert len(_TaskManager._get_all()) == 1 + assert len(_DataManager._get_all()) == 2 + + assert all(task_id in dn.parent_ids for dn in task.data_nodes.values()) + assert all(dn.owner_id is None for dn in task.data_nodes.values()) + + assert all(new_task.id in dn.parent_ids for dn in new_task.data_nodes.values()) + assert all(dn.owner_id is None for dn in new_task.data_nodes.values()) + + +def test_duplicate_task(): + dn_config = Config.configure_pickle_data_node("dn", scope=Scope.SCENARIO) + task_config = Config.configure_task("task_1", print, [dn_config]) + task = _TaskManager._bulk_get_or_create([task_config])[0] + + reasons = _TaskManager._can_duplicate(task) + assert bool(reasons) + assert reasons._reasons == {} + + reasons = _TaskManager._can_duplicate("1") + assert not bool(reasons) + assert reasons._reasons["1"] == {EntityDoesNotExist(1)} + assert str(list(reasons._reasons["1"])[0]) == "Entity 1 does not exist in the repository" + with pytest.raises(AttributeError): + _TaskManager._duplicate("1") 
diff --git a/tests/core/test_taipy.py b/tests/core/test_taipy.py
index afcea08d40..5f967d98c6 100644
--- a/tests/core/test_taipy.py
+++ b/tests/core/test_taipy.py
@@ -869,3 +869,29 @@ def test_get_entities_by_config_id_in_multiple_versions_environment(self):
         assert len(tp.get_scenarios()) == 5
         assert len(tp.get_entities_by_config_id(scenario_config_1.id)) == 3
         assert len(tp.get_entities_by_config_id(scenario_config_2.id)) == 2
+
+    def test_can_duplicate(self):
+        dn_config = Config.configure_in_memory_data_node("dn", 10)
+        task_config = Config.configure_task("task", print, [dn_config])
+        scenario_config = Config.configure_scenario("sc", {task_config}, [], Frequency.DAILY)
+
+        scenario = tp.create_scenario(scenario_config)
+        assert tp.can_duplicate(scenario)
+        assert not tp.can_duplicate("1")
+
+    def test_duplicate_scenario(self):
+        dn_config = Config.configure_in_memory_data_node("dn", 10)
+        task_config = Config.configure_task("task", print, [dn_config])
+        scenario_config = Config.configure_scenario("sc", {task_config}, [], Frequency.DAILY)
+
+        scenario = tp.create_scenario(scenario_config)
+
+        with mock.patch("taipy.core.scenario._scenario_manager._ScenarioManager._duplicate") as mck:
+            tp.duplicate_scenario(scenario)
+            mck.assert_called_once_with(scenario, None, None)
+        with mock.patch("taipy.core.scenario._scenario_manager._ScenarioManager._duplicate") as mck:
+            tp.duplicate_scenario(scenario, datetime.datetime(2022, 2, 5))
+            mck.assert_called_once_with(scenario, datetime.datetime(2022, 2, 5), None)
+        with mock.patch("taipy.core.scenario._scenario_manager._ScenarioManager._duplicate") as mck:
+            tp.duplicate_scenario(scenario, datetime.datetime(2022, 2, 5), "displayable_name")
+            mck.assert_called_once_with(scenario, datetime.datetime(2022, 2, 5), "displayable_name")
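
Usage sketch (reviewer note, not part of the patch): the snippet below exercises the new public API added in `taipy/core/taipy.py`. The config names are illustrative, the service bootstrap call is version-dependent, and the rest mirrors `tests/core/test_taipy.py`.

```python
# Hypothetical end-to-end use of the duplication API introduced by this patch.
from datetime import datetime

import taipy as tp
from taipy import Config, Frequency, Scope

# Illustrative config: one pickle data node feeding a print task, daily cycle.
dn_cfg = Config.configure_pickle_data_node("dn", scope=Scope.SCENARIO)
task_cfg = Config.configure_task("task", print, [dn_cfg])
scenario_cfg = Config.configure_scenario("sc", [task_cfg], frequency=Frequency.DAILY)

if __name__ == "__main__":
    tp.Orchestrator().run()  # tp.Core().run() on older Taipy releases

    scenario = tp.create_scenario(scenario_cfg)

    # can_duplicate returns a ReasonCollection, which is truthy when the
    # scenario exists in the repository and can therefore be duplicated.
    if tp.can_duplicate(scenario):
        # SCENARIO-scoped data nodes and tasks are copied; a cycle matching
        # the given creation date is created if it does not exist yet.
        copy = tp.duplicate_scenario(scenario, datetime(2022, 2, 5), "copy of sc")
        assert copy.id != scenario.id
```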