diff --git a/odxtools/database.py b/odxtools/database.py index f6ec1e45..dafc0449 100644 --- a/odxtools/database.py +++ b/odxtools/database.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: MIT from itertools import chain from pathlib import Path -from typing import IO, Any, Dict, List, Optional, OrderedDict +from typing import IO, Any, Dict, List, Optional, OrderedDict, Union from xml.etree import ElementTree from zipfile import ZipFile +from os import PathLike from packaging.version import Version @@ -27,10 +28,7 @@ class Database: into a single PDX file. """ - def __init__(self, - *, - pdx_zip: Optional[ZipFile] = None, - odx_d_file_name: Optional[str] = None) -> None: + def __init__(self) -> None: self.model_version: Optional[Version] = None self.auxiliary_files: OrderedDict[str, IO[bytes]] = OrderedDict() @@ -39,9 +37,14 @@ def __init__(self, self._comparam_subsets = NamedItemList[ComparamSubset]() self._comparam_specs = NamedItemList[ComparamSpec]() - def add_pdx_file(self, pdx_file_name: str) -> None: - pdx_zip = ZipFile(pdx_file_name) - + def add_pdx_file(self, pdx_file: Union[str, "PathLike[Any]", IO[bytes], ZipFile]) -> None: + """Add PDX file to database. + Either pass the path to the file, an IO with the file content or a ZipFile object. + """ + if isinstance(pdx_file, ZipFile): + pdx_zip = pdx_file + else: + pdx_zip = ZipFile(pdx_file) for zip_member in pdx_zip.namelist(): # The name of ODX files can end with .odx, .odx-d, # .odx-c, .odx-cs, .odx-e, .odx-f, .odx-fd, .odx-m, @@ -54,16 +57,16 @@ def add_pdx_file(self, pdx_file_name: str) -> None: elif p.name.lower() != "index.xml": self.add_auxiliary_file(zip_member, pdx_zip.open(zip_member)) - def add_odx_file(self, odx_file_name: str) -> None: + def add_odx_file(self, odx_file_name: Union[str, "PathLike[Any]"]) -> None: self._process_xml_tree(ElementTree.parse(odx_file_name).getroot()) def add_auxiliary_file(self, - aux_file_name: str, + aux_file_name: Union[str, "PathLike[Any]"], aux_file_obj: Optional[IO[bytes]] = None) -> None: if aux_file_obj is None: aux_file_obj = open(aux_file_name, "rb") - self.auxiliary_files[aux_file_name] = aux_file_obj + self.auxiliary_files[str(aux_file_name)] = aux_file_obj def _process_xml_tree(self, root: ElementTree.Element) -> None: dlcs: List[DiagLayerContainer] = []