Skip to content

Commit

Permalink
fix: compare versions using major.minor.patch during domain import (#…
Browse files Browse the repository at this point in the history
…1469)

* fix: compare versions using major.minor.patch during domain import

* Ignore add permission on requirement assessment model

Useless to check, and causes errors. If one has the permission to create
a compliance assessment, they can import requirement assessments.

* Fix m2m id mappings

---------

Co-authored-by: Abderrahmane Smimite <[email protected]>
  • Loading branch information
nas-tabchiche and ab-smith authored Feb 4, 2025
1 parent 1f30b48 commit 8f1e78d
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 6 deletions.
63 changes: 58 additions & 5 deletions backend/core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2244,7 +2244,9 @@ def _process_uploaded_file(self, dump_file: str | Path) -> Any:

with zipfile.ZipFile(dump_file, mode="r") as zipf:
if "data.json" not in zipf.namelist():
logger.error("No data.json file found in uploaded file")
logger.error(
"No data.json file found in uploaded file", files=zipf.namelist()
)
raise ValidationError({"file": "noDataJsonFileFound"})
infolist = zipf.infolist()
directories = list(set([Path(f.filename).parent.name for f in infolist]))
Expand All @@ -2260,9 +2262,48 @@ def _process_uploaded_file(self, dump_file: str | Path) -> Any:
raise
if "objects" not in json_dump:
raise ValidationError("badly formatted json")
if not import_version == VERSION:

# Check backup and local version

VERSION_REGEX = r"^v[0-9]+\.[0-9]+\.[0-9]+"
match = re.match(VERSION_REGEX, import_version)
if match is None:
logger.error(
"Backup malformed: invalid version",
backup_version=import_version,
current_version=VERSION,
)
return Response(
{"error": "errorBackupInvalidVersion"},
status=status.HTTP_400_BAD_REQUEST,
)

import_version = match.group()
current_version = VERSION.split("-")[0]

if current_version.lower() == "dev":
current_version = "v0.0.0"

import_version = [int(num) for num in import_version.lstrip("v").split(".")]
current_version = [
int(num) for num in current_version.lstrip("v").split(".")
]
# All versions are composed of 3 numbers (see git tag)
for i in range(3):
if import_version[i] > current_version[i]:
logger.error(
"Backup version greater than current version",
version=import_version,
)
# Refuse to import the backup and ask to update the instance before importing the backup
return Response(
{"error": "GreaterBackupVersion"},
status=status.HTTP_400_BAD_REQUEST,
)

if not import_version == current_version:
logger.error(
f"Import version {import_version} not compatible with current version {VERSION}"
f"Import version {import_version} not compatible with current version {current_version}"
)
raise ValidationError(
{"file": "importVersionNotCompatibleWithCurrentVersion"}
Expand Down Expand Up @@ -2341,7 +2382,9 @@ def _import_objects(

# check that user has permission to create all objects to import
error_dict = {}
for model in models_map.values():
for model in filter(
lambda x: x not in [RequirementAssessment], models_map.values()
):
if not RoleAssignment.is_access_allowed(
user=user,
perm=Permission.objects.get(
Expand All @@ -2351,6 +2394,10 @@ def _import_objects(
):
error_dict[model._meta.model_name] = "permission_denied"
if error_dict:
logger.error(
"User does not have permission to import objects",
error_dict=error_dict,
)
raise PermissionDenied()

# Validation phase (outside transaction since it doesn't modify database)
Expand Down Expand Up @@ -2601,7 +2648,9 @@ def _process_model_relationships(
def get_mapped_ids(
ids: List[str], link_dump_database_ids: Dict[str, str]
) -> List[str]:
return [link_dump_database_ids.get(id, "") for id in ids]
return [
link_dump_database_ids[id] for id in ids if id in link_dump_database_ids
]

model_name = model._meta.model_name
_fields = fields.copy()
Expand Down Expand Up @@ -2653,6 +2702,7 @@ def get_mapped_ids(
_fields.pop("attachment_hash", None)

case "requirementassessment":
logger.debug("Looking for requirement", urn=_fields.get("requirement"))
_fields["requirement"] = RequirementNode.objects.get(
urn=_fields.get("requirement")
)
Expand Down Expand Up @@ -2812,6 +2862,9 @@ def _set_many_to_many_relations(self, model, obj, many_to_many_map_ids):
match model_name:
case "asset":
if parent_ids := many_to_many_map_ids.get("parent_ids"):
logger.debug(
"Setting parent assets", asset=obj, parent_ids=parent_ids
)
obj.parent_assets.set(Asset.objects.filter(id__in=parent_ids))

case "appliedcontrol":
Expand Down
2 changes: 1 addition & 1 deletion backend/serdes/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def post(self, request, *args, **kwargs):
if backup_version is None:
logger.error("Backup malformed: no version found")
return Response(
{"erroe": "errorBackupNoVersion"},
{"error": "errorBackupNoVersion"},
status=status.HTTP_400_BAD_REQUEST,
)

Expand Down

0 comments on commit 8f1e78d

Please sign in to comment.