diff --git a/home/import_assessments.py b/home/import_assessments.py index 4c986e45..aa97ad6c 100644 --- a/home/import_assessments.py +++ b/home/import_assessments.py @@ -12,7 +12,14 @@ from wagtail.coreutils import get_content_languages # type: ignore from wagtail.models import Locale, Page # type: ignore -from home.import_helpers import ImportException, parse_file, validate_using_form +from home.import_helpers import ( + ImportException, + convert_headers_to_snake_case, + validate_using_form, +) +from home.import_helpers import ( + parse_file as helper_parse_file, +) from home.models import Assessment, ContentPage, HomePage # type: ignore AssessmentId = tuple[str, Locale] @@ -106,21 +113,13 @@ def parse_file(self) -> list["AssessmentRow"]: c. Validates that the snake_case headers contain all mandatory headers. d. Transforms each row to use snake_case headers. """ - - row_iterator = parse_file(self.file_content, self.file_type) + row_iterator = helper_parse_file(self.file_content, self.file_type) rows = [row for _, row in row_iterator] - if not rows: - raise ImportAssessmentException( - "The import file is empty or contains no valid rows.", row_num=1 - ) - original_headers = rows[0].keys() - headers_mapping = { - header: self.to_snake_case(header) for header in original_headers - } + headers_mapping = convert_headers_to_snake_case(list(original_headers)) snake_case_headers = list(headers_mapping.values()) - self.validate_headers(snake_case_headers, MANDATORY_HEADERS, row_num=1) + self.validate_headers(snake_case_headers, row_num=1) transformed_rows = [ {headers_mapping[key]: value for key, value in row.items()} for row in rows ] @@ -214,9 +213,7 @@ def create_shadow_assessment_from_row( ) assessment.questions.append(question) - def validate_headers( - self, headers: list[str], MANDATORY_HEADERS: list[str], row_num: int - ) -> None: + def validate_headers(self, headers: list[str], row_num: int) -> None: missing_headers = [ header for header in MANDATORY_HEADERS if header not in headers ] diff --git a/home/import_helpers.py b/home/import_helpers.py index f499baf9..f90b3ce7 100644 --- a/home/import_helpers.py +++ b/home/import_helpers.py @@ -1,5 +1,6 @@ # The error messages are processed and parsed into a list of messages we return to the user import csv +import re from collections.abc import Iterator from datetime import datetime from io import BytesIO, StringIO @@ -165,6 +166,30 @@ def extract_errors(data: dict[str | int, Any] | list[str]) -> dict[str, str]: return error_message +def check_empty_rows(rows: list[dict[str, Any]], row_num: int) -> None: + """ + Checks if the list of rows is empty and raises an exception if true. + """ + if not rows: + raise ImportException( + "The import file is empty or contains no valid rows.", row_num=row_num + ) + + +def convert_headers_to_snake_case(headers: list[str]) -> dict[str, str]: + """ + Converts a list of headers to snake_case and returns a mapping. + """ + return {header: to_snake_case(header) for header in headers} + + +def to_snake_case(s: str) -> str: + """ + Converts string to snake_case. + """ + return re.sub(r"[\W_]+", "_", s).lower().strip("_") + + def fix_rows(rows: Iterator[dict[str | Any, Any]]) -> Iterator[dict[str, str | None]]: """ Fix keys for all rows by lowercasing keys and removing whitespace from keys and values @@ -210,7 +235,11 @@ def parse_file( file_content: bytes, file_type: str ) -> Iterator[tuple[int, dict[str, Any]]]: read_rows = read_xlsx if file_type == "XLSX" else read_csv - return enumerate(fix_rows(read_rows(file_content)), start=2) + rows = list(fix_rows(read_rows(file_content))) + + check_empty_rows(rows, row_num=1) + + return enumerate(rows, start=2) def read_csv(file_content: bytes) -> Iterator[dict[str, Any]]: diff --git a/home/tests/test_assessment_import_export.py b/home/tests/test_assessment_import_export.py index ddc441d5..4c7568e4 100644 --- a/home/tests/test_assessment_import_export.py +++ b/home/tests/test_assessment_import_export.py @@ -706,9 +706,11 @@ def test_empty_rows(self, csv_impexp: ImportExport) -> None: Importing an empty CSV should return an error that the import file has no rows. """ - with pytest.raises(ImportAssessmentException) as e: + with pytest.raises(ImportException) as e: csv_impexp.import_file("empty.csv") - assert e.value.message == "The import file is empty or contains no valid rows." + assert e.value.message == [ + "The import file is empty or contains no valid rows." + ] assert e.value.row_num == 1