Skip to content

Commit

Permalink
Update dataclass s mandatory fields have no default value
Browse files Browse the repository at this point in the history
  • Loading branch information
DevChima committed Feb 19, 2025
1 parent 27d1c37 commit 3329359
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 26 deletions.
51 changes: 25 additions & 26 deletions home/import_assessments.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,20 +372,20 @@ class AssessmentRow:
"""

slug: str
title: str = ""
title: str
question: str
generic_error: str
locale: str
version: str = ""
tags: list[str] = field(default_factory=list)
question_type: str = ""
locale: str = ""
high_result_page: str = ""
high_inflection: str = ""
medium_result_page: str = ""
medium_inflection: str = ""
low_result_page: str = ""
skip_threshold: str = ""
skip_high_result_page: str = ""
generic_error: str = ""
question: str = ""
explainer: str = ""
error: str = ""
min: str = ""
Expand All @@ -400,18 +400,6 @@ class AssessmentRow:
def fields(cls) -> list[str]:
return [field.name for field in fields(cls)]

@classmethod
def check_missing_fields(cls, row: dict[str, str], row_num: int) -> None:
"""
Checks for missing required fields in the row and raises an exception if any is missing.
"""
missing_fields = [field for field in MANDATORY_HEADERS if field not in row]
if missing_fields:
raise ImportAssessmentException(
f"The import file is missing required fields: {', '.join(missing_fields)}",
row_num,
)

@classmethod
def from_flat(cls, row: dict[str, str], row_num: int) -> "AssessmentRow":
"""
Expand All @@ -428,21 +416,32 @@ def from_flat(cls, row: dict[str, str], row_num: int) -> "AssessmentRow":
key: value for key, value in row.items() if value and key in cls.fields()
}

cls.check_missing_fields(row, row_num)

answers = deserialise_list(row.pop("answers", ""))
answer_responses = deserialise_list(row.pop("answer_responses", ""))
if not answer_responses:
answer_responses = [""] * len(answers)

return cls(
tags=deserialise_list(row.pop("tags", "")),
answers=answers,
scores=[float(i) for i in deserialise_list(row.pop("scores", ""))],
answer_semantic_ids=deserialise_list(row.pop("answer_semantic_ids", "")),
answer_responses=answer_responses,
**row,
)
try:
return cls(
tags=deserialise_list(row.pop("tags", "")),
answers=answers,
scores=[float(i) for i in deserialise_list(row.pop("scores", ""))],
answer_semantic_ids=deserialise_list(
row.pop("answer_semantic_ids", "")
),
answer_responses=answer_responses,
**row,
)
except TypeError:
missing_fields = [
field
for field in MANDATORY_HEADERS
if field not in row or row[field] == ""
]
raise ImportAssessmentException(
f"The import file is missing required fields: {', '.join(missing_fields)}",
row_num,
)


def get_content_page_id_from_slug(slug: str, locale: Locale) -> int:
Expand Down
3 changes: 3 additions & 0 deletions home/import_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ def check_empty_rows(rows: list[dict[str, Any]], row_num: int) -> None:
)


# TODO:
# Move to shared code once we're able to work on the contentpage import.
# Contentsets uses pascal case headers
def convert_headers_to_snake_case(headers: list[str], row_num: int) -> dict[str, str]:
"""
Converts a list of headers to snake_case and returns a mapping.
Expand Down

0 comments on commit 3329359

Please sign in to comment.