diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b1d59bb..dc5121b 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -25,16 +25,16 @@ jobs:
         pip install torch --index-url https://download.pytorch.org/whl/cpu
         pip install ./
         pip list
-    - name: download model
+    - name: download models
      run: |
        python -m spacy download en_core_web_md
-        pip install https://huggingface.co/kormilitzin/en_core_med7_lg/resolve/main/en_core_med7_lg-any-py3-none-any.whl
+        pip install -r requirements.txt
    - name: run pytest
      run: pytest ./tests/*
-    - name: install ruff
+    - name: Install ruff
      run: pip install ruff
-    - name: ruff format
+    - name: Lint with ruff
      run: |
-        ruff format
+        ruff check --output-format=github .
        ruff check --fix
-    continue-on-error: true
+      continue-on-error: true
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e218059
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+https://huggingface.co/kormilitzin/en_core_med7_lg/resolve/main/en_core_med7_lg-any-py3-none-any.whl
diff --git a/src/miade/annotators.py b/src/miade/annotators.py
index ac6a3d1..1b7e1e5 100644
--- a/src/miade/annotators.py
+++ b/src/miade/annotators.py
@@ -82,6 +82,22 @@ def load_lookup_data(filename: str, is_package_data: bool = False, as_dict: bool
     return pd.read_csv(lookup_data).drop_duplicates()


+def load_regex_paragraph_mappings(data: Dict) -> Dict:
+    """Convert a {heading name: regex} mapping into a {ParagraphType: regex} lookup."""
+    regex_lookup = {}
+
+    for paragraph, regex in data.items():
+        paragraph_enum = None
+        try:
+            paragraph_enum = ParagraphType(paragraph)
+        except ValueError as e:
+            log.warning(e)
+
+        if paragraph_enum is not None:
+            regex_lookup[paragraph_enum] = regex
+
+    return regex_lookup
+
+
 def load_allergy_type_combinations(filename: str, is_package_data: bool = False) -> Dict:
     """
     Load allergy type combinations from a CSV file and return a dictionary.
@@ -197,6 +213,9 @@ def __init__(self, cat: CAT, config: AnnotatorConfig = None):
         if self.config.negation_detection == "negex":
             self._add_negex_pipeline()

+        self._set_lookup_data_path()
+        self._load_paragraph_regex()
+
         # TODO make paragraph processing params configurable
         self.structured_prob_lists = {
             ParagraphType.prob: Relevance.PRESENT,
@@ -217,6 +236,43 @@ def _add_negex_pipeline(self) -> None:
         self.cat.pipe.spacy_nlp.enable_pipe("sentencizer")
         self.cat.pipe.spacy_nlp.add_pipe("negex")

+    def _set_lookup_data_path(self) -> None:
+        """
+        Sets the lookup data path based on the configuration.
+
+        If the `lookup_data_path` is not specified in the configuration, the default path "./data/" is used
+        and `use_package_data` is set to True. Otherwise, the specified `lookup_data_path` is used and
+        `use_package_data` is set to False.
+
+        Raises:
+            RuntimeError: If the specified `lookup_data_path` does not exist.
+        """
+        if self.config.lookup_data_path is None:
+            self.lookup_data_path = "./data/"
+            self.use_package_data = True
+            log.info("Loading preconfigured lookup data")
+        else:
+            self.lookup_data_path = self.config.lookup_data_path
+            self.use_package_data = False
+            log.info(f"Loading lookup data from {self.lookup_data_path}")
+            if not os.path.isdir(self.lookup_data_path):
+                raise RuntimeError(f"No lookup data configured: {self.lookup_data_path} does not exist!")
+
+    def _load_paragraph_regex(self) -> None:
+        """
+        Loads the paragraph regex mappings from a CSV file and initializes the paragraph_regex attribute.
+
+        This method loads the paragraph regex mappings from a CSV file located at the lookup data path specified
+        in the config. If unspecified, it loads the default packaged regex lookup for paragraph headings.
+
+        Returns:
+            None
+        """
+        data = load_lookup_data(
+            self.lookup_data_path + "regex_para_chunk.csv", is_package_data=self.use_package_data, as_dict=True
+        )
+        self.paragraph_regex = load_regex_paragraph_mappings(data)
+
     @property
     @abstractmethod
     def concept_types(self):
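
For orientation, a minimal sketch of the mapping `_load_paragraph_regex` builds. The keys and patterns below are illustrative assumptions, not the actual contents of the packaged `regex_para_chunk.csv`:

```python
from miade.paragraph import ParagraphType

# Hypothetical as_dict=True output of load_lookup_data for regex_para_chunk.csv:
# one entry per paragraph heading type, each holding a heading-matching regex.
data = {"prob": r"^problem", "med": r"^medication", "not_a_type": r"^foo"}

regex_lookup = {}
for paragraph, regex in data.items():
    try:
        regex_lookup[ParagraphType(paragraph)] = regex
    except ValueError as e:
        print(e)  # unknown paragraph types are logged and skipped

# regex_lookup now maps ParagraphType.prob / ParagraphType.med to their patterns
```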
+ """ + # Check there is a numbered list + if len(note.numbered_lists) == 0: + return concepts + + # Get the global list ranges of all numbered lists in a note + global_list_ranges = [ + (numbered_list.list_start, numbered_list.list_end) for numbered_list in note.numbered_lists + ] + + # Flatten the list items from all numbered lists into a single list and sort them + list_items = [item for numbered_list in note.numbered_lists for item in numbered_list.items] + list_items.sort(key=lambda x: x.start) + + # Sort the concepts by their start index + concepts.sort(key=lambda x: x.start) + + filtered_concepts = [] + concept_idx, item_idx = 0, 0 + + # Iterate through concepts and list items simultaneously + while concept_idx < len(concepts) and item_idx < len(list_items): + concept = concepts[concept_idx] + item = list_items[item_idx] + + # Check if the concept is within the global range of any list + if any(start <= concept.start < end for start, end in global_list_ranges): + # Check for partial or full overlap between concept and list item + if ( + concept.start >= item.start and concept.end <= item.end + ): # or (concept.start < item.end and concept.end > item.start) + # Concept overlaps with or is within the current list item + filtered_concepts.append(concept) + concept_idx += 1 # Move to the next concept + elif concept.end <= item.start: + # If the concept ends before the item starts, move to the next concept + concept_idx += 1 + else: + # Otherwise, move to the next list item + item_idx += 1 + else: + # If concept is not within a numbered list range, skip and return it + filtered_concepts.append(concept) + concept_idx += 1 + + # After iterating, check if there are remaining concepts after the last list item that might not have been added + while concept_idx < len(concepts): + concept = concepts[concept_idx] + if concept.start >= global_list_ranges[-1][1]: + filtered_concepts.append(concept) + concept_idx += 1 + + return filtered_concepts + + @staticmethod + def deduplicate(concepts: List[Concept], record_concepts: Optional[List[Concept]] = None) -> List[Concept]: """ Removes duplicate concepts from the extracted concepts list by strict ID matching. @@ -459,9 +560,40 @@ def pipeline(self) -> List[str]: Get the list of processing steps in the annotation pipeline. Returns: - ["preprocessor", "medcat", "paragrapher", "postprocessor", "deduplicator"] + ["preprocessor", "medcat", "list_cleaner", "paragrapher", "postprocessor", "deduplicator"] """ - return ["preprocessor", "medcat", "paragrapher", "postprocessor", "deduplicator"] + return ["preprocessor", "medcat", "list_cleaner", "paragrapher", "postprocessor", "deduplicator"] + + def run_pipeline(self, note: Note, record_concepts: Optional[List[Concept]] = None) -> List[Concept]: + """ + Runs the annotation pipeline on a given note and returns the extracted problems concepts. + + Args: + note (Note): The input note to process. + record_concepts (Optional[List[Concept]]): The list of concepts from existing EHR records. + + Returns: + List[Concept]: The extracted concepts from the note. + """ + # TODO: not the best way to do this - make this more extensible!! 
@@ -470,26 +602,17 @@ def _load_problems_lookup_data(self) -> None:

         Raises:
             RuntimeError: If the lookup data directory does not exist.
         """
-        if self.config.lookup_data_path is None:
-            data_path = "./data/"
-            is_package_data = True
-            log.info("Loading preconfigured lookup data for ProblemsAnnotator")
-        else:
-            data_path = self.config.lookup_data_path
-            is_package_data = False
-            log.info(f"Loading lookup data from {data_path} for ProblemsAnnotator")
-            if not os.path.isdir(data_path):
-                raise RuntimeError(f"No lookup data configured: {data_path} does not exist!")
-
-        self.negated_lookup = load_lookup_data(data_path + "negated.csv", is_package_data=is_package_data, as_dict=True)
+        self.negated_lookup = load_lookup_data(
+            self.lookup_data_path + "negated.csv", is_package_data=self.use_package_data, as_dict=True
+        )
         self.historic_lookup = load_lookup_data(
-            data_path + "historic.csv", is_package_data=is_package_data, as_dict=True
+            self.lookup_data_path + "historic.csv", is_package_data=self.use_package_data, as_dict=True
         )
         self.suspected_lookup = load_lookup_data(
-            data_path + "suspected.csv", is_package_data=is_package_data, as_dict=True
+            self.lookup_data_path + "suspected.csv", is_package_data=self.use_package_data, as_dict=True
         )
         self.filtering_blacklist = load_lookup_data(
-            data_path + "problem_blacklist.csv", is_package_data=is_package_data, no_header=True
+            self.lookup_data_path + "problem_blacklist.csv", is_package_data=self.use_package_data, no_header=True
         )

     def _process_meta_annotations(self, concept: Concept) -> Optional[Concept]:
@@ -629,13 +752,19 @@ def process_paragraphs(self, note: Note, concepts: List[Concept]) -> List[Concep
         The filtered list of concepts.
         """
         prob_concepts_in_structured_sections: List[Concept] = []
-
-        for paragraph in note.paragraphs:
-            for concept in concepts:
-                if concept.start >= paragraph.start and concept.end <= paragraph.end:
-                    # log.debug(f"({concept.name} | {concept.id}) is in {paragraph.type}")
-                    if concept.meta:
-                        self._process_meta_ann_by_paragraph(concept, paragraph, prob_concepts_in_structured_sections)
+        if note.paragraphs:
+            # Use a list comprehension to flatten the loop and conditionals
+            concepts_in_paragraphs = [
+                (concept, paragraph)
+                for paragraph in note.paragraphs
+                for concept in concepts
+                if concept.start >= paragraph.start and concept.end <= paragraph.end and concept.meta
+            ]
+            # Process each concept and paragraph pair
+            for concept, paragraph in concepts_in_paragraphs:
+                self._process_meta_ann_by_paragraph(concept, paragraph, prob_concepts_in_structured_sections)
+        else:
+            log.warning("Unable to run paragrapher pipeline: did you add preprocessor to the pipeline?")

         # if more than set no. concepts in prob or imp or pmh sections, return only those and ignore all other concepts
         if len(prob_concepts_in_structured_sections) > self.config.structured_list_limit:
@@ -713,11 +842,12 @@ def pipeline(self) -> List[str]:
         The annotators are executed in the order they appear in the list.

         Returns:
-            ["preprocessor", "medcat", "paragrapher", "postprocessor", "dosage_extractor", "vtm_converter", "deduplicator"]
+            ["preprocessor", "medcat", "list_cleaner", "paragrapher", "postprocessor", "dosage_extractor", "vtm_converter", "deduplicator"]
         """
         return [
             "preprocessor",
             "medcat",
+            "list_cleaner",
             "paragrapher",
             "postprocessor",
             "dosage_extractor",
@@ -726,14 +856,17 @@ def pipeline(self) -> List[str]:
         ]

     def run_pipeline(
-        self, note: Note, record_concepts: List[Concept], dosage_extractor: Optional[DosageExtractor]
+        self,
+        note: Note,
+        record_concepts: Optional[List[Concept]] = None,
+        dosage_extractor: Optional[DosageExtractor] = None,
     ) -> List[Concept]:
         """
         Runs the annotation pipeline on the given note.

         Args:
             note (Note): The input note to run the pipeline on.
-            record_concepts (List[Concept]): The list of previously recorded concepts.
+            record_concepts (Optional[List[Concept]]): The list of previously recorded concepts.
             dosage_extractor (Optional[DosageExtractor]): The dosage extractor function.

         Returns:
@@ -744,19 +877,23 @@ def run_pipeline(
         for pipe in self.pipeline:
             if pipe not in self.config.disable:
                 if pipe == "preprocessor":
-                    note = self.preprocess(note)
+                    note = self.preprocess(note=note)
                 elif pipe == "medcat":
-                    concepts = self.get_concepts(note)
+                    concepts = self.get_concepts(note=note)
+                elif pipe == "list_cleaner":
+                    concepts = self.filter_concepts_in_numbered_list(concepts=concepts, note=note)
                 elif pipe == "paragrapher":
-                    concepts = self.process_paragraphs(note, concepts)
+                    concepts = self.process_paragraphs(note=note, concepts=concepts)
                 elif pipe == "postprocessor":
-                    concepts = self.postprocess(concepts, note)
+                    concepts = self.postprocess(concepts=concepts, note=note)
                 elif pipe == "deduplicator":
-                    concepts = self.deduplicate(concepts, record_concepts)
+                    concepts = self.deduplicate(concepts=concepts, record_concepts=record_concepts)
                 elif pipe == "vtm_converter":
-                    concepts = self.convert_VTM_to_VMP_or_text(concepts)
+                    concepts = self.convert_VTM_to_VMP_or_text(concepts=concepts)
                 elif pipe == "dosage_extractor" and dosage_extractor is not None:
-                    concepts = self.add_dosages_to_concepts(dosage_extractor, concepts, note)
+                    concepts = self.add_dosages_to_concepts(
+                        dosage_extractor=dosage_extractor, concepts=concepts, note=note
+                    )

         return concepts
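
And the medications/allergies counterpart, sketched under the same assumptions (placeholder model pack, `DosageExtractor` constructed with its defaults):

```python
from miade.annotators import MedsAllergiesAnnotator
from miade.dosageextractor import DosageExtractor
from miade.note import Note

meds_annotator = MedsAllergiesAnnotator(cat)  # `cat` as loaded in the sketch above
dosage_extractor = DosageExtractor()

note = Note("Medications:\n\nParacetamol 500mg oral tablets, two daily")
concepts = meds_annotator.run_pipeline(note, dosage_extractor=dosage_extractor)
```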
""" - if self.config.lookup_data_path is None: - data_path = "./data/" - is_package_data = True - log.info("Loading preconfigured lookup data for MedsAllergiesAnnotator") - else: - data_path = self.config.lookup_data_path - is_package_data = False - log.info(f"Loading lookup data from {data_path} for MedsAllergiesAnnotator") - if not os.path.isdir(data_path): - raise RuntimeError(f"No lookup data configured: {data_path} does not exist!") - self.valid_meds = load_lookup_data( - data_path + "valid_meds.csv", is_package_data=is_package_data, no_header=True + self.lookup_data_path + "valid_meds.csv", is_package_data=self.use_package_data, no_header=True ) self.reactions_subset_lookup = load_lookup_data( - data_path + "reactions_subset.csv", is_package_data=is_package_data, as_dict=True + self.lookup_data_path + "reactions_subset.csv", is_package_data=self.use_package_data, as_dict=True ) self.allergens_subset_lookup = load_lookup_data( - data_path + "allergens_subset.csv", is_package_data=is_package_data, as_dict=True + self.lookup_data_path + "allergens_subset.csv", is_package_data=self.use_package_data, as_dict=True ) self.allergy_type_lookup = load_allergy_type_combinations( - data_path + "allergy_type.csv", is_package_data=is_package_data + self.lookup_data_path + "allergy_type.csv", is_package_data=self.use_package_data + ) + self.vtm_to_vmp_lookup = load_lookup_data( + self.lookup_data_path + "vtm_to_vmp.csv", is_package_data=self.use_package_data ) - self.vtm_to_vmp_lookup = load_lookup_data(data_path + "vtm_to_vmp.csv", is_package_data=is_package_data) self.vtm_to_text_lookup = load_lookup_data( - data_path + "vtm_to_text.csv", is_package_data=is_package_data, as_dict=True + self.lookup_data_path + "vtm_to_text.csv", is_package_data=self.use_package_data, as_dict=True ) def _validate_meds(self, concept) -> bool: diff --git a/src/miade/note.py b/src/miade/note.py index 6ba635e..65c514a 100644 --- a/src/miade/note.py +++ b/src/miade/note.py @@ -1,67 +1,30 @@ import re -import io -import pkgutil import logging -import pandas as pd from typing import List, Optional, Dict -from .paragraph import Paragraph, ParagraphType +from .paragraph import ListItem, NumberedList, Paragraph, ParagraphType log = logging.getLogger(__name__) -def load_regex_config_mappings(filename: str) -> Dict: - """ - Load regex configuration mappings from a file. - - Args: - filename (str): The name of the file containing the regex configuration. - - Returns: - A dictionary mapping paragraph types to their corresponding regex patterns. - """ - regex_config = pkgutil.get_data(__name__, filename) - data = ( - pd.read_csv( - io.BytesIO(regex_config), - index_col=0, - ) - .squeeze("columns") - .T.to_dict() - ) - regex_lookup = {} - - for paragraph, regex in data.items(): - paragraph_enum = None - try: - paragraph_enum = ParagraphType(paragraph) - except ValueError as e: - log.warning(e) - - if paragraph_enum is not None: - regex_lookup[paragraph_enum] = regex - - return regex_lookup - - class Note(object): """ - Represents a note object. + Represents a Note object Attributes: text (str): The text content of the note. raw_text (str): The raw text content of the note. - regex_config (str): The path to the regex configuration file. - paragraphs (Optional[List[Paragraph]]): A list of paragraphs in the note. + paragraphs (Optional[List[Paragraph]]): A list of Paragraph objects representing the paragraphs in the note. 
diff --git a/src/miade/note.py b/src/miade/note.py
index 6ba635e..65c514a 100644
--- a/src/miade/note.py
+++ b/src/miade/note.py
@@ -1,67 +1,30 @@
 import re
-import io
-import pkgutil
 import logging

-import pandas as pd
-
 from typing import List, Optional, Dict

-from .paragraph import Paragraph, ParagraphType
+from .paragraph import ListItem, NumberedList, Paragraph, ParagraphType

 log = logging.getLogger(__name__)


-def load_regex_config_mappings(filename: str) -> Dict:
-    """
-    Load regex configuration mappings from a file.
-
-    Args:
-        filename (str): The name of the file containing the regex configuration.
-
-    Returns:
-        A dictionary mapping paragraph types to their corresponding regex patterns.
-    """
-    regex_config = pkgutil.get_data(__name__, filename)
-    data = (
-        pd.read_csv(
-            io.BytesIO(regex_config),
-            index_col=0,
-        )
-        .squeeze("columns")
-        .T.to_dict()
-    )
-    regex_lookup = {}
-
-    for paragraph, regex in data.items():
-        paragraph_enum = None
-        try:
-            paragraph_enum = ParagraphType(paragraph)
-        except ValueError as e:
-            log.warning(e)
-
-        if paragraph_enum is not None:
-            regex_lookup[paragraph_enum] = regex
-
-    return regex_lookup
-
-
 class Note(object):
     """
-    Represents a note object.
+    Represents a Note object.

     Attributes:
         text (str): The text content of the note.
         raw_text (str): The raw text content of the note.
-        regex_config (str): The path to the regex configuration file.
-        paragraphs (Optional[List[Paragraph]]): A list of paragraphs in the note.
+        paragraphs (Optional[List[Paragraph]]): A list of Paragraph objects representing the paragraphs in the note.
+        numbered_lists (Optional[List[NumberedList]]): A list of NumberedList objects representing the numbered lists in the note.
     """

-    def __init__(self, text: str, regex_config_path: str = "./data/regex_para_chunk.csv"):
+    def __init__(self, text: str):
         self.text = text
         self.raw_text = text
-        self.regex_config = load_regex_config_mappings(regex_config_path)
         self.paragraphs: Optional[List[Paragraph]] = []
+        self.numbered_lists: Optional[List[NumberedList]] = []

     def clean_text(self) -> None:
         """
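
The practical effect of this constructor change, sketched: regex configuration now travels with the annotator rather than with each note.

```python
from miade.note import Note

# Before this diff, every Note loaded its own regex config:
#   note = Note(text, regex_config_path="./data/regex_para_chunk.csv")

# Now a Note is just text; the paragraph regex lookup is injected by the
# annotator (see Annotator._load_paragraph_regex) during preprocessing.
note = Note("Problems:\n\n1. AF\n2. CKD")
```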
@@ -83,14 +46,61 @@ def clean_text(self) -> None:
         # Remove spaces if the entire line (between two line breaks) is just spaces
         self.text = re.sub(r"(?<=\n)\s+(?=\n)", "", self.text)

-    def get_paragraphs(self) -> None:
+    def get_numbered_lists(self):
         """
-        Splits the note into paragraphs.
+        Finds lists of sequentially ordered numbers (with more than one item) that directly follow a newline
+        character, capturing the text following each number up to the next newline.

-        This method splits the text content of the note into paragraphs based on double line breaks.
-        It also assigns a paragraph type to each paragraph based on matching patterns in the heading.
+        The result is stored in `self.numbered_lists`: one NumberedList per valid sequentially ordered list
+        found, each holding its ListItems with their start and end indices. Empty if no such pattern is found.
         """
+        # Regular expression to find numbers followed by any characters until a newline
+        pattern = re.compile(r"(?<=\n)(\d+.*)")
+
+        # Finding all matches
+        matches = pattern.finditer(self.text)
+
+        all_results = []
+        results = []
+        last_num = 0
+        for match in matches:
+            number_text = match.group(1)
+            current_num = int(re.search(r"^\d+", number_text).group())
+
+            # Check if current number is the next in sequence
+            if current_num == last_num + 1:
+                results.append(ListItem(content=number_text, start=match.start(1), end=match.end(1)))
+            else:
+                # If there is a break in the sequence, check if current list has more than one item
+                if len(results) > 1:
+                    numbered_list = NumberedList(items=results, list_start=results[0].start, list_end=results[-1].end)
+                    all_results.append(numbered_list)
+                results = [
+                    ListItem(content=number_text, start=match.start(1), end=match.end(1))
+                ]  # Start new results list with the current match
+            last_num = current_num  # Update last number to the current
+
+        # Add the last sequence if not empty and has more than one item
+        if len(results) > 1:
+            numbered_list = NumberedList(items=results, list_start=results[0].start, list_end=results[-1].end)
+            all_results.append(numbered_list)
+
+        self.numbered_lists = all_results
+
+    def get_paragraphs(self, paragraph_regex: Dict) -> None:
+        """
+        Split the text into paragraphs and assign paragraph types based on regex patterns.
+
+        Args:
+            paragraph_regex (Dict): A dictionary containing paragraph types as keys and regex patterns as values.
+
+        Returns:
+            None
+        """
         paragraphs = re.split(r"\n\n+", self.text)

         start = 0
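
A small worked example of the detection logic above (the expected values are reasoned from the regex, not taken from a test run):

```python
from miade.note import Note

note = Note("Problems:\n1. CCF\n2. IHD\n3. T2DM\n\nno list here")
note.get_numbered_lists()

numbered = note.numbered_lists[0]
print(len(numbered.items))        # 3 sequential items detected
print(numbered.items[0].content)  # "1. CCF"
print(numbered.list_start)        # character offset of "1. CCF" in note.text
```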
@@ -117,12 +127,126 @@
             if heading:
                 heading = heading.lower()
                 # Iterate through the dictionary items and patterns
-                for paragraph_type, pattern in self.regex_config.items():
+                for paragraph_type, pattern in paragraph_regex.items():
                     if re.search(pattern, heading):
                         paragraph.type = paragraph_type
                         break  # Exit the loop if a match is found

             self.paragraphs.append(paragraph)

+    def merge_prose_sections(self) -> None:
+        """
+        Merges consecutive prose sections in the paragraphs list in place.
+
+        Returns:
+            None
+        """
+        is_merge = False
+        all_prose = []
+        prose_section = []
+        prose_indices = []
+
+        for i, paragraph in enumerate(self.paragraphs):
+            if paragraph.type == ParagraphType.prose:
+                if is_merge:
+                    prose_section.append((i, paragraph))
+                else:
+                    prose_section = [(i, paragraph)]
+                    is_merge = True
+            else:
+                if len(prose_section) > 0:
+                    all_prose.append(prose_section)
+                    prose_indices.extend([idx for idx, _ in prose_section])
+                is_merge = False
+
+        if len(prose_section) > 0:
+            all_prose.append(prose_section)
+            prose_indices.extend([idx for idx, _ in prose_section])
+
+        new_paragraphs = self.paragraphs[:]
+
+        for section in all_prose:
+            start = section[0][1].start
+            end = section[-1][1].end
+            new_prose_para = Paragraph(
+                heading=self.text[start:end], body="", type=ParagraphType.prose, start=start, end=end
+            )
+
+            # Replace the first paragraph in the section with the new merged paragraph
+            first_idx = section[0][0]
+            new_paragraphs[first_idx] = new_prose_para
+
+            # Mark other paragraphs in the section for deletion
+            for _, paragraph in section[1:]:
+                index = self.paragraphs.index(paragraph)
+                new_paragraphs[index] = None
+
+        # Remove the None entries from new_paragraphs
+        self.paragraphs = [para for para in new_paragraphs if para is not None]
+
+    def merge_empty_non_prose_with_next_prose(self) -> None:
+        """
+        Checks if a Paragraph has an empty body and a type that is not prose,
+        and merges it with the next Paragraph if the next paragraph is of type prose.
+
+        Returns:
+            None
+        """
+        merged_paragraphs = []
+        skip_next = False
+
+        for i in range(len(self.paragraphs) - 1):
+            if skip_next:
+                # Skip this iteration because the previous iteration already handled merging
+                skip_next = False
+                continue
+
+            current_paragraph = self.paragraphs[i]
+            next_paragraph = self.paragraphs[i + 1]
+
+            # Check if current paragraph has an empty body and is not of type prose
+            if current_paragraph.body == "" and current_paragraph.type != ParagraphType.prose:
+                # Check if the next paragraph is of type prose
+                if next_paragraph.type == ParagraphType.prose:
+                    # Create a new Paragraph that takes its body from the next prose
+                    # paragraph while keeping the current heading and its type
+                    merged_paragraph = Paragraph(
+                        heading=current_paragraph.heading,
+                        body=next_paragraph.heading,
+                        type=current_paragraph.type,
+                        start=current_paragraph.start,
+                        end=next_paragraph.end,
+                    )
+                    merged_paragraphs.append(merged_paragraph)
+                    # Skip the next paragraph since it's already merged
+                    skip_next = True
+                    continue
+
+            # If no merging is done, add the current paragraph to the list
+            merged_paragraphs.append(current_paragraph)
+
+        # Handle the last paragraph if it wasn't merged
+        if not skip_next:
+            merged_paragraphs.append(self.paragraphs[-1])
+
+        # Update the paragraphs list with the merged paragraphs
+        self.paragraphs = merged_paragraphs
+
+    def process(self, lookup_dict: Dict, refine: bool = True):
+        """
+        Process the note by cleaning the text, extracting numbered lists, and getting paragraphs based on
+        a lookup dictionary.
+
+        Args:
+            lookup_dict (Dict): A dictionary mapping paragraph types to their heading regex patterns.
+            refine (bool, optional): Flag indicating whether to refine the processed note - this will merge any
+                consecutive prose paragraphs and then merge any structured paragraph with an empty body into the
+                next prose paragraph (handles a line break between a heading and its body). Defaults to True.
+        """
+        self.clean_text()
+        self.get_numbered_lists()
+        self.get_paragraphs(lookup_dict)
+        if refine:
+            self.merge_prose_sections()
+            self.merge_empty_non_prose_with_next_prose()
+
     def __str__(self):
         return self.text
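
Taken together, `process` is now the single entry point the annotators call; a sketch of the control flow, with an empty lookup standing in for the real `ParagraphType -> regex` mapping:

```python
from miade.note import Note

note = Note("PMH:\n\nHypertension diagnosed 2019.\n\nOtherwise well.")
note.process(lookup_dict={}, refine=True)
# Order of operations:
#   clean_text() -> get_numbered_lists() -> get_paragraphs(lookup_dict)
# and, because refine=True:
#   merge_prose_sections()                   # consecutive prose paragraphs become one
#   merge_empty_non_prose_with_next_prose()  # structured headings with empty bodies
#                                            # absorb the following prose paragraph
```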
diff --git a/src/miade/paragraph.py b/src/miade/paragraph.py
index 0183279..31d1662 100644
--- a/src/miade/paragraph.py
+++ b/src/miade/paragraph.py
@@ -1,4 +1,5 @@
 from enum import Enum
+from typing import List


 class ParagraphType(Enum):
@@ -27,14 +28,58 @@ class Paragraph(object):
     """

     def __init__(self, heading: str, body: str, type: ParagraphType, start: int, end: int):
-        self.heading = heading
-        self.body = body
-        self.type = type
-        self.start = start
-        self.end = end
+        self.heading: str = heading
+        self.body: str = body
+        self.type: ParagraphType = type
+        self.start: int = start
+        self.end: int = end

     def __str__(self):
         return str(self.__dict__)

     def __eq__(self, other):
         return self.type == other.type and self.start == other.start and self.end == other.end
+
+
+class ListItem(object):
+    """
+    Represents an item in a NumberedList.
+
+    Attributes:
+        content (str): The content of the list item.
+        start (int): The starting index of the list item.
+        end (int): The ending index of the list item.
+    """
+
+    def __init__(self, content: str, start: int, end: int) -> None:
+        self.content: str = content
+        self.start: int = start
+        self.end: int = end
+
+    def __str__(self):
+        return str(self.__dict__)
+
+    def __eq__(self, other):
+        return self.start == other.start and self.end == other.end
+
+
+class NumberedList(object):
+    """
+    Represents a numbered list.
+
+    Attributes:
+        items (List[ListItem]): The list of items in the numbered list.
+        list_start (int): The starting character index of the list.
+        list_end (int): The ending character index of the list.
+    """
+
+    def __init__(self, items: List[ListItem], list_start: int, list_end: int) -> None:
+        self.list_start: int = list_start
+        self.list_end: int = list_end
+        self.items: List[ListItem] = items
+
+    def __str__(self):
+        return str(self.__dict__)
+
+    def __eq__(self, other):
+        return self.list_start == other.list_start and self.list_end == other.list_end
diff --git a/src/miade/paragraphsegmenter.py b/src/miade/paragraphsegmenter.py
new file mode 100644
index 0000000..2d24ae0
--- /dev/null
+++ b/src/miade/paragraphsegmenter.py
@@ -0,0 +1,103 @@
+import re
+import io
+import pkgutil
+import logging
+import spacy
+
+import pandas as pd
+
+from spacy.language import Language
+from spacy.tokens import Span, Doc
+
+from typing import Dict
+
+from miade.paragraph import ParagraphType
+
+# TODO: this spacy pipeline doesn't play well so I will probably redo all this as a regular class
+
+
+log = logging.getLogger(__name__)
+
+
+def load_regex_config_mappings(filename: str) -> Dict:
+    """Load the packaged regex CSV and return a {ParagraphType: regex} lookup."""
+    regex_config = pkgutil.get_data(__name__, filename)
+    data = (
+        pd.read_csv(
+            io.BytesIO(regex_config),
+            index_col=0,
+        )
+        .squeeze("columns")
+        .T.to_dict()
+    )
+    regex_lookup = {}
+
+    for paragraph, regex in data.items():
+        paragraph_enum = None
+        try:
+            paragraph_enum = ParagraphType(paragraph)
+        except ValueError as e:
+            log.warning(e)
+
+        if paragraph_enum is not None:
+            regex_lookup[paragraph_enum] = regex
+
+    return regex_lookup
+
+
+@spacy.registry.misc("regex_config.v1")
+def create_patterns_dict():
+    regex_config = load_regex_config_mappings("./data/regex_para_chunk.csv")
+
+    return regex_config
+
+
+@Language.factory(
+    "paragraph_segmenter",
+    default_config={"regex_config": {"@misc": "regex_config.v1"}},
+)
+def create_paragraph_segmenter(nlp: Language, name: str, regex_config: Dict):
+    return ParagraphSegmenter(nlp, regex_config)
+
+
+class ParagraphSegmenter:
+    def __init__(self, nlp: Language, regex_config: Dict):
+        self.regex_config = regex_config
+
+        # Set custom extensions
+        if not Span.has_extension("heading"):
+            Span.set_extension("heading", default=None)
+        if not Span.has_extension("body"):
+            Span.set_extension("body", default=None, force=True)
+        if not Span.has_extension("type"):
+            Span.set_extension("type", default=None, force=True)
+
+    def __call__(self, doc: Doc) -> Doc:
+        paragraphs = re.split(r"\n\n+", doc.text)
+        start = 0
+        new_spans = []
+
+        for text in paragraphs:
+            match = re.search(r"^(.*?)(?:\n|$)([\s\S]*)", text)
+            if match:
+                heading, body = match.group(1), match.group(2)
+            else:
+                heading, body = text, ""
+
+            end = start + len(text)
+            span = doc.char_span(start, end, label="PARAGRAPH")
+            if span is not None:
+                span._.heading = heading
+                span._.body = body
+                span._.type = ParagraphType.prose  # default type
+
+                heading_lower = heading.lower()
+                for paragraph_type, pattern in self.regex_config.items():
+                    if re.search(pattern, heading_lower):
+                        # regex_config is keyed by ParagraphType already,
+                        # so assign the enum member directly
+                        span._.type = paragraph_type
+                        break
+
+                new_spans.append(span)
+            start = end + 2
+
+        doc.spans["paragraphs"] = new_spans
+
+        return doc
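
In case the spaCy route is kept despite the TODO above, a hedged sketch of how the registered factory would be wired up (assuming the packaged regex CSV resolves via `pkgutil` at runtime):

```python
import spacy

import miade.paragraphsegmenter  # noqa: F401 - registers "paragraph_segmenter"

nlp = spacy.blank("en")
nlp.add_pipe("paragraph_segmenter")

doc = nlp("Problems:\nchest pain\n\nPlan:\nreview bloods")
for span in doc.spans["paragraphs"]:
    print(span._.heading, span._.type)
```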
diff --git a/src/miade/utils/annotatorconfig.py b/src/miade/utils/annotatorconfig.py
index 2e1a8b1..a5765fa 100644
--- a/src/miade/utils/annotatorconfig.py
+++ b/src/miade/utils/annotatorconfig.py
@@ -6,5 +6,6 @@ class AnnotatorConfig(BaseModel):
     lookup_data_path: Optional[str] = None
     negation_detection: Optional[str] = "negex"
     structured_list_limit: Optional[int] = 100
+    refine_paragraphs: Optional[bool] = True
     disable: List[str] = []
     add_numbering: bool = False
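
The new flag in context, as an illustrative configuration touching the fields this diff interacts with:

```python
from miade.utils.annotatorconfig import AnnotatorConfig

config = AnnotatorConfig(
    lookup_data_path="./custom_lookups/",  # packaged ./data/ is used when left as None
    refine_paragraphs=False,               # keep the naive paragraph split
    disable=["dosage_extractor"],          # skip pipeline steps by name
)
```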
diff --git a/tests/conftest.py b/tests/conftest.py
index ef6f4f5..07b859b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,9 +1,9 @@
 import pytest
 import pandas as pd

-from typing import List, Dict
+from typing import List, Dict, Tuple
 from pathlib import Path

-from miade.annotators import Annotator
+from miade.annotators import Annotator, MedsAllergiesAnnotator, ProblemsAnnotator
 from miade.dosage import Dosage, Dose, Route
 from miade.note import Note
@@ -109,6 +109,9 @@ def concept_types():
         def pipeline():
             return []

+        def run_pipeline(self):
+            return []
+
         def postprocess(self):
             return super().postprocess()
@@ -118,8 +121,18 @@ def process_paragraphs(self):
     return CustomAnnotator


+@pytest.fixture
+def test_problems_annotator(test_problems_medcat_model):
+    return ProblemsAnnotator(test_problems_medcat_model)
+
+
+@pytest.fixture
+def test_meds_algy_annotator(test_meds_algy_medcat_model):
+    return MedsAllergiesAnnotator(test_meds_algy_medcat_model)
+
+
 @pytest.fixture(scope="function")
-def test_miade_doses() -> (List[Note], pd.DataFrame):
+def test_miade_doses() -> Tuple[List[Note], pd.DataFrame]:
     extracted_doses = pd.read_csv("./tests/examples/common_doses_for_miade.csv")
     return [Note(text=dose) for dose in extracted_doses.dosestring.to_list()], extracted_doses
@@ -214,6 +227,41 @@ def test_clean_and_paragraphing_note() -> Note:
     )


+@pytest.fixture
+def test_numbered_list_note():
+    return Note(
+        text="""Problems:
+
+1. CCF -
+- had echo on 15/6
+- on diuretics
+
+- awaiting pacemaker
+
+2). IHD
+3. Diabetes type 2
+
+HbA1c = 78mmol/L
+4 Gallstones
+
+here is some random prose idk
+
+Medications:
+
+1. Paracetamol -
+- start after this date
+- refill in future
+
+2) Ibuprofen
+3. Metformin
+- for hypertension
+
+Allergies:
+and then here is some extra stuff that doesn't fall within lists, we want to detect things here
+"""
+    )
+
+
 @pytest.fixture(scope="function")
 def test_paragraph_chunking_prob_concepts() -> List[Concept]:
     return [
@@ -462,7 +510,7 @@ def test_problem_list_limit_concepts() -> List[Concept]:


 @pytest.fixture(scope="function")
-def test_duplicate_concepts_record() -> List[Concept]:
+def test_record_concepts() -> List[Concept]:
     return [
         Concept(id="1", name="test1", category=Category.PROBLEM),
         Concept(id="2", name="test2", category=Category.PROBLEM),
diff --git a/tests/test_annotator.py b/tests/test_annotator.py
index 4cd4ea2..cac9a79 100644
--- a/tests/test_annotator.py
+++ b/tests/test_annotator.py
@@ -2,6 +2,7 @@
 from miade.annotators import MedsAllergiesAnnotator, ProblemsAnnotator, calculate_word_distance
 from miade.dosage import Dose, Frequency, Dosage, Route
 from miade.dosageextractor import DosageExtractor
+from miade.paragraph import ListItem, NumberedList


 def test_dosage_text_splitter(test_meds_algy_medcat_model, test_med_concepts, test_med_note):
@@ -63,26 +64,24 @@ def test_calculate_word_distance():


 def test_deduplicate(
-    test_problems_medcat_model,
+    test_problems_annotator,
     test_duplicate_concepts_note,
-    test_duplicate_concepts_record,
+    test_record_concepts,
     test_self_duplicate_concepts_note,
     test_duplicate_vtm_concept_note,
     test_duplicate_vtm_concept_record,
 ):
-    annotator = ProblemsAnnotator(test_problems_medcat_model)
-
-    assert annotator.deduplicate(
-        concepts=test_duplicate_concepts_note, record_concepts=test_duplicate_concepts_record
+    assert test_problems_annotator.deduplicate(
+        concepts=test_duplicate_concepts_note, record_concepts=test_record_concepts
     ) == [
         Concept(id="7", name="test2", category=Category.MEDICATION),
         Concept(id="5", name="test2", category=Category.PROBLEM),
     ]
-    assert annotator.deduplicate(concepts=test_self_duplicate_concepts_note, record_concepts=None) == [
+    assert test_problems_annotator.deduplicate(concepts=test_self_duplicate_concepts_note, record_concepts=None) == [
         Concept(id="1", name="test1", category=Category.PROBLEM),
         Concept(id="2", name="test2", category=Category.MEDICATION),
     ]
-    assert annotator.deduplicate(concepts=test_duplicate_concepts_note, record_concepts=None) == [
+    assert test_problems_annotator.deduplicate(concepts=test_duplicate_concepts_note, record_concepts=None) == [
         Concept(id="1", name="test1", category=Category.PROBLEM),
         Concept(id="2", name="test2", category=Category.PROBLEM),
         Concept(id="3", name="test2", category=Category.PROBLEM),
@@ -91,7 +90,7 @@
         Concept(id="5", name="test2", category=Category.PROBLEM),
         Concept(id="6", name="test2", category=Category.MEDICATION),
     ]
-    assert annotator.deduplicate(concepts=test_duplicate_concepts_note, record_concepts=[]) == [
+    assert test_problems_annotator.deduplicate(concepts=test_duplicate_concepts_note, record_concepts=[]) == [
         Concept(id="1", name="test1", category=Category.PROBLEM),
         Concept(id="2", name="test2", category=Category.PROBLEM),
         Concept(id="3", name="test2", category=Category.PROBLEM),
@@ -101,19 +100,17 @@
         Concept(id="5", name="test2", category=Category.PROBLEM),
         Concept(id="6", name="test2", category=Category.MEDICATION),
     ]
     # test vtm deduplication (string match)
-    assert annotator.deduplicate(
+    assert test_problems_annotator.deduplicate(
         concepts=test_duplicate_vtm_concept_note, record_concepts=test_duplicate_vtm_concept_record
     ) == [
         Concept(id=None, name="vtm1", category=Category.MEDICATION),
         Concept(id=None, name="vtm3", category=Category.MEDICATION),
     ]
-    assert annotator.deduplicate(concepts=[], record_concepts=test_duplicate_concepts_record) == []
-
+    assert test_problems_annotator.deduplicate(concepts=[], record_concepts=test_record_concepts) == []

-def test_meta_annotations(test_problems_medcat_model, test_meta_annotations_concepts):
-    annotator = ProblemsAnnotator(test_problems_medcat_model)
-    assert annotator.postprocess(test_meta_annotations_concepts) == [
+
+def test_meta_annotations(test_problems_annotator, test_meta_annotations_concepts):
+    assert test_problems_annotator.postprocess(test_meta_annotations_concepts) == [
         Concept(id="274826007", name="Nystagmus (negated)", category=Category.PROBLEM),  # negex true, meta ignored
         Concept(id="302064001", name="Lymphangitis (negated)", category=Category.PROBLEM),  # negex true, meta ignored
         Concept(id="431956005", name="Arthritis (suspected)", category=Category.PROBLEM),  # negex false, meta processed
@@ -137,7 +134,7 @@
     test_meta_annotations_concepts[3].negex = True
     test_meta_annotations_concepts[6].negex = True

-    assert annotator.postprocess(test_meta_annotations_concepts) == [
+    assert test_problems_annotator.postprocess(test_meta_annotations_concepts) == [
         Concept(id="274826007", name="Nystagmus (negated)", category=Category.PROBLEM),  # negex true, meta empty
         Concept(
             id="1415005", name="Lymphangitis", category=Category.PROBLEM
         ),  # negex true, meta processed
@@ -147,16 +144,14 @@
     ]


-def test_problems_filtering_list(test_problems_medcat_model, test_filtering_list_concepts):
-    annotator = ProblemsAnnotator(test_problems_medcat_model)
-    assert annotator.postprocess(test_filtering_list_concepts) == [
+def test_problems_filtering_list(test_problems_annotator, test_filtering_list_concepts):
+    assert test_problems_annotator.postprocess(test_filtering_list_concepts) == [
         Concept(id="123", name="real concept", category=Category.PROBLEM),
     ]


-def test_allergy_annotator(test_meds_algy_medcat_model, test_substance_concepts_with_meta_anns, test_meds_allergy_note):
-    annotator = MedsAllergiesAnnotator(test_meds_algy_medcat_model)
-    concepts = annotator.postprocess(test_substance_concepts_with_meta_anns, test_meds_allergy_note)
+def test_allergy_annotator(test_meds_algy_annotator, test_substance_concepts_with_meta_anns, test_meds_allergy_note):
+    concepts = test_meds_algy_annotator.postprocess(test_substance_concepts_with_meta_anns, test_meds_allergy_note)

     # print([concept.__str__() for concept in concepts])
     assert concepts == [
@@ -177,9 +172,8 @@
     assert concepts[2].linked_concepts == []


-def test_vtm_med_conversions(test_meds_algy_medcat_model, test_vtm_concepts):
-    annotator = MedsAllergiesAnnotator(test_meds_algy_medcat_model)
-    concepts = annotator.convert_VTM_to_VMP_or_text(test_vtm_concepts)
+def test_vtm_med_conversions(test_meds_algy_annotator, test_vtm_concepts):
+    concepts = test_meds_algy_annotator.convert_VTM_to_VMP_or_text(test_vtm_concepts)

     # print([concept.__str__() for concept in concepts])
     assert concepts == [
@@ -217,6 +211,33 @@
     )


+def test_filter_concepts_in_numbered_list(test_note, test_problems_annotator):
+    test_note.numbered_lists = [
+        NumberedList(
+            list_start=0,
+            list_end=30,
+            items=[ListItem(content="", start=2, end=5), ListItem(content="", start=20, end=25)],
+        ),
+        NumberedList(list_start=40, list_end=52, items=[ListItem(content="", start=45, end=50)]),
+    ]
+    concepts = [
+        Concept(id="0", name="Concept 0", start=0, end=3),  # Partial overlap with list item - Filter
+        Concept(id="1", name="Concept 1", start=6, end=12),  # In between list items - Filter
+        Concept(id="2", name="Concept 2", start=21, end=25),  # List item - Keep
+        Concept(id="3", name="Concept 3", start=27, end=30),  # After list item - Filter
+        Concept(id="4", name="Concept 4", start=30, end=40),  # Content between lists - Keep
+        Concept(id="5", name="Concept 5", start=45, end=50),  # List item - Keep
+        Concept(id="6", name="Concept 6", start=50, end=60),  # After list item - Filter
+        Concept(id="7", name="Concept 7", start=55, end=70),  # Content outside lists - Keep
+    ]
+    assert test_problems_annotator.filter_concepts_in_numbered_list(concepts, test_note) == [
+        Concept(id="2", name="Concept 2", start=21, end=25),
+        Concept(id="4", name="Concept 4", start=30, end=40),
+        Concept(id="5", name="Concept 5", start=45, end=50),
+        Concept(id="7", name="Concept 7", start=55, end=70),
+    ]
+
+
 def test_annotator_config(test_meds_algy_medcat_model, test_problems_medcat_model, test_config):
     # check that all loads ok if pass in explicit path
     test_config.lookup_data_path = "./src/miade/data/"
name="list item not relevant", start=200, end=210), + Concept(id="correct", name="prose that is not in lists", start=130, end=140), + Concept(id="correct", name="other section", start=280, end=300), + Concept(id="correct", name="other section", start=300, end=378), + ] + test_problems_annotator.preprocess(test_numbered_list_note, refine=True) + assert test_problems_annotator.filter_concepts_in_numbered_list(test_concepts, test_numbered_list_note) == [ + Concept(name="list item", id="correct"), + Concept(name="prose that is not in lists", id="correct"), + Concept(name="other section", id="correct"), + Concept(name="other section", id="correct"), + ] + + def test_prob_paragraph_note( - test_problems_medcat_model, test_clean_and_paragraphing_note, test_paragraph_chunking_prob_concepts + test_problems_annotator, test_clean_and_paragraphing_note, test_paragraph_chunking_prob_concepts ): - annotator = ProblemsAnnotator(test_problems_medcat_model) - annotator.preprocess(test_clean_and_paragraphing_note) + test_problems_annotator.preprocess(test_clean_and_paragraphing_note) - concepts = annotator.process_paragraphs(test_clean_and_paragraphing_note, test_paragraph_chunking_prob_concepts) + concepts = test_problems_annotator.process_paragraphs( + test_clean_and_paragraphing_note, test_paragraph_chunking_prob_concepts + ) # prose assert concepts[0].meta == [ MetaAnnotations(name="presence", value=Presence.CONFIRMED), @@ -66,12 +98,13 @@ def test_prob_paragraph_note( def test_med_paragraph_note( - test_meds_algy_medcat_model, test_clean_and_paragraphing_note, test_paragraph_chunking_med_concepts + test_meds_algy_annotator, test_clean_and_paragraphing_note, test_paragraph_chunking_med_concepts ): - annotator = MedsAllergiesAnnotator(test_meds_algy_medcat_model) - annotator.preprocess(test_clean_and_paragraphing_note) + test_meds_algy_annotator.preprocess(test_clean_and_paragraphing_note) - concepts = annotator.process_paragraphs(test_clean_and_paragraphing_note, test_paragraph_chunking_med_concepts) + concepts = test_meds_algy_annotator.process_paragraphs( + test_clean_and_paragraphing_note, test_paragraph_chunking_med_concepts + ) # pmh assert concepts[0].meta == [ MetaAnnotations(name="substance_category", value=SubstanceCategory.IRRELEVANT), @@ -144,3 +177,76 @@ def test_problem_list_limit( ], ), ] + + +def test_get_numbered_lists_empty_text(): + note = Note("") + note.get_numbered_lists() + assert note.numbered_lists == [] + + +def test_get_numbered_lists_no_lists(): + text = "This is a sample note without any numbered lists." + note = Note(text) + note.get_numbered_lists() + assert note.numbered_lists == [] + + +def test_get_numbered_lists_single_list(): + text = "\n1. First item\n2. Second item\n3. Third item" + note = Note(text) + note.get_numbered_lists() + assert len(note.numbered_lists) == 1 + assert isinstance(note.numbered_lists[0], NumberedList) + assert len(note.numbered_lists[0].items) == 3 + assert isinstance(note.numbered_lists[0].items[0], ListItem) + assert note.numbered_lists[0].items[0].content == "1. First item" + assert note.numbered_lists[0].items[0].start == 1 + assert note.numbered_lists[0].items[0].end == 14 + assert isinstance(note.numbered_lists[0].items[1], ListItem) + assert note.numbered_lists[0].items[1].content == "2. Second item" + assert note.numbered_lists[0].items[1].start == 15 + assert note.numbered_lists[0].items[1].end == 29 + assert isinstance(note.numbered_lists[0].items[2], ListItem) + assert note.numbered_lists[0].items[2].content == "3. 
Third item" + assert note.numbered_lists[0].items[2].start == 30 + assert note.numbered_lists[0].items[2].end == 43 + + +def test_merge_prose_sections(): + note = Note("This is the first paragraph.\n\nThis is the second paragraph.\n\nThis is the third paragraph.") + note.paragraphs = [ + Paragraph(heading="", body="This is the first paragraph.", type=ParagraphType.prose, start=0, end=27), + Paragraph(heading="", body="This is the second paragraph.", type=ParagraphType.prose, start=29, end=58), + Paragraph(heading="", body="This is the third paragraph.", type=ParagraphType.prose, start=60, end=89), + ] + note.merge_prose_sections() + assert len(note.paragraphs) == 1 + assert ( + note.paragraphs[0].heading + == "This is the first paragraph.\n\nThis is the second paragraph.\n\nThis is the third paragraph." + ) + assert note.paragraphs[0].type == ParagraphType.prose + + +def test_merge_empty_non_prose_with_next_prose(): + note = Note("This is the first paragraph.\n\nThis is the second paragraph.\n\nThis is the third paragraph.") + note.paragraphs = [ + Paragraph(heading="Heading 1", body="", type=ParagraphType.prob, start=0, end=14), + Paragraph(heading="This is the first paragraph.", body="", type=ParagraphType.prose, start=16, end=43), + Paragraph(heading="Heading 2", body="", type=ParagraphType.pmh, start=45, end=59), + Paragraph(heading="This is the second paragraph.", body="", type=ParagraphType.prose, start=61, end=90), + Paragraph(heading="Heading 3", body="", type=ParagraphType.med, start=92, end=106), + Paragraph(heading="This is the third paragraph.", body="", type=ParagraphType.prose, start=108, end=137), + ] + note.merge_empty_non_prose_with_next_prose() + assert len(note.paragraphs) == 3 + assert note.paragraphs[0].heading == "Heading 1" + assert note.paragraphs[0].body == "This is the first paragraph." + assert note.paragraphs[0].type == ParagraphType.prob + assert note.paragraphs[1].heading == "Heading 2" + assert note.paragraphs[1].body == "This is the second paragraph." + assert note.paragraphs[1].type == ParagraphType.pmh + assert note.paragraphs[2].heading == "Heading 3" + assert note.paragraphs[2].body == "This is the third paragraph." 
@@ -144,3 +177,76 @@
     ]
+
+
+def test_get_numbered_lists_empty_text():
+    note = Note("")
+    note.get_numbered_lists()
+    assert note.numbered_lists == []
+
+
+def test_get_numbered_lists_no_lists():
+    text = "This is a sample note without any numbered lists."
+    note = Note(text)
+    note.get_numbered_lists()
+    assert note.numbered_lists == []
+
+
+def test_get_numbered_lists_single_list():
+    text = "\n1. First item\n2. Second item\n3. Third item"
+    note = Note(text)
+    note.get_numbered_lists()
+    assert len(note.numbered_lists) == 1
+    assert isinstance(note.numbered_lists[0], NumberedList)
+    assert len(note.numbered_lists[0].items) == 3
+    assert isinstance(note.numbered_lists[0].items[0], ListItem)
+    assert note.numbered_lists[0].items[0].content == "1. First item"
+    assert note.numbered_lists[0].items[0].start == 1
+    assert note.numbered_lists[0].items[0].end == 14
+    assert isinstance(note.numbered_lists[0].items[1], ListItem)
+    assert note.numbered_lists[0].items[1].content == "2. Second item"
+    assert note.numbered_lists[0].items[1].start == 15
+    assert note.numbered_lists[0].items[1].end == 29
+    assert isinstance(note.numbered_lists[0].items[2], ListItem)
+    assert note.numbered_lists[0].items[2].content == "3. Third item"
+    assert note.numbered_lists[0].items[2].start == 30
+    assert note.numbered_lists[0].items[2].end == 43
+
+
+def test_merge_prose_sections():
+    note = Note("This is the first paragraph.\n\nThis is the second paragraph.\n\nThis is the third paragraph.")
+    note.paragraphs = [
+        Paragraph(heading="", body="This is the first paragraph.", type=ParagraphType.prose, start=0, end=27),
+        Paragraph(heading="", body="This is the second paragraph.", type=ParagraphType.prose, start=29, end=58),
+        Paragraph(heading="", body="This is the third paragraph.", type=ParagraphType.prose, start=60, end=89),
+    ]
+    note.merge_prose_sections()
+    assert len(note.paragraphs) == 1
+    assert (
+        note.paragraphs[0].heading
+        == "This is the first paragraph.\n\nThis is the second paragraph.\n\nThis is the third paragraph."
+    )
+    assert note.paragraphs[0].type == ParagraphType.prose
+
+
+def test_merge_empty_non_prose_with_next_prose():
+    note = Note("This is the first paragraph.\n\nThis is the second paragraph.\n\nThis is the third paragraph.")
+    note.paragraphs = [
+        Paragraph(heading="Heading 1", body="", type=ParagraphType.prob, start=0, end=14),
+        Paragraph(heading="This is the first paragraph.", body="", type=ParagraphType.prose, start=16, end=43),
+        Paragraph(heading="Heading 2", body="", type=ParagraphType.pmh, start=45, end=59),
+        Paragraph(heading="This is the second paragraph.", body="", type=ParagraphType.prose, start=61, end=90),
+        Paragraph(heading="Heading 3", body="", type=ParagraphType.med, start=92, end=106),
+        Paragraph(heading="This is the third paragraph.", body="", type=ParagraphType.prose, start=108, end=137),
+    ]
+    note.merge_empty_non_prose_with_next_prose()
+    assert len(note.paragraphs) == 3
+    assert note.paragraphs[0].heading == "Heading 1"
+    assert note.paragraphs[0].body == "This is the first paragraph."
+    assert note.paragraphs[0].type == ParagraphType.prob
+    assert note.paragraphs[1].heading == "Heading 2"
+    assert note.paragraphs[1].body == "This is the second paragraph."
+    assert note.paragraphs[1].type == ParagraphType.pmh
+    assert note.paragraphs[2].heading == "Heading 3"
+    assert note.paragraphs[2].body == "This is the third paragraph."
+    assert note.paragraphs[2].type == ParagraphType.med
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
new file mode 100644
index 0000000..50b9cac
--- /dev/null
+++ b/tests/test_pipeline.py
@@ -0,0 +1,82 @@
+from unittest.mock import Mock, patch
+
+from miade.concept import Category, Concept
+
+
+def test_preprocessor_called(test_note, test_problems_annotator, test_meds_algy_annotator):
+    with patch.object(test_problems_annotator, "preprocess") as mock_preprocess:
+        test_problems_annotator.run_pipeline(test_note)
+        mock_preprocess.assert_called_once()
+
+    with patch.object(test_meds_algy_annotator, "preprocess") as mock_preprocess:
+        test_meds_algy_annotator.run_pipeline(test_note)
+        mock_preprocess.assert_called_once()
+
+
+def test_medcat_called(test_note, test_problems_annotator, test_meds_algy_annotator):
+    with patch.object(test_problems_annotator, "get_concepts") as mock_get_concepts:
+        test_problems_annotator.run_pipeline(test_note)
+        mock_get_concepts.assert_called_once_with(note=test_note)
+
+    with patch.object(test_meds_algy_annotator, "get_concepts") as mock_get_concepts:
+        test_meds_algy_annotator.run_pipeline(test_note)
+        mock_get_concepts.assert_called_once_with(note=test_note)
+
+
+def test_list_cleaner_called(test_note, test_problems_annotator, test_meds_algy_annotator):
+    with patch.object(test_problems_annotator, "filter_concepts_in_numbered_list") as mock_filter:
+        test_problems_annotator.run_pipeline(test_note)
+        mock_filter.assert_called()
+
+    with patch.object(test_meds_algy_annotator, "filter_concepts_in_numbered_list") as mock_filter:
+        test_meds_algy_annotator.run_pipeline(test_note)
+        mock_filter.assert_called()
+
+
+def test_paragrapher_called(test_note, test_problems_annotator, test_meds_algy_annotator):
+    with patch.object(test_problems_annotator, "process_paragraphs") as mock_process_paragraphs:
+        test_problems_annotator.run_pipeline(test_note)
+        mock_process_paragraphs.assert_called()
+
+    with patch.object(test_meds_algy_annotator, "process_paragraphs") as mock_process_paragraphs:
+        test_meds_algy_annotator.run_pipeline(test_note)
+        mock_process_paragraphs.assert_called()
+
+
+def test_postprocessor_called(test_note, test_problems_annotator, test_meds_algy_annotator):
+    with patch.object(test_problems_annotator, "postprocess") as mock_postprocess:
+        test_problems_annotator.run_pipeline(test_note)
+        mock_postprocess.assert_called()
+
+    with patch.object(test_meds_algy_annotator, "postprocess") as mock_postprocess:
+        test_meds_algy_annotator.run_pipeline(test_note)
+        mock_postprocess.assert_called()
+
+
+def test_deduplicator_called(test_note, test_record_concepts, test_problems_annotator, test_meds_algy_annotator):
+    with patch.object(test_problems_annotator, "deduplicate") as mock_deduplicate:
+        test_problems_annotator.run_pipeline(test_note, test_record_concepts)
+        mock_deduplicate.assert_called_with(
+            concepts=[Concept(id="59927004", name="liver failure", category=Category.PROBLEM)],
+            record_concepts=test_record_concepts,
+        )
+
+    with patch.object(test_meds_algy_annotator, "deduplicate") as mock_deduplicate:
+        test_meds_algy_annotator.run_pipeline(test_note, test_record_concepts)
+        mock_deduplicate.assert_called_with(
+            concepts=[Concept(id="322236009", name="paracetamol 500mg oral tablets", category=None)],
+            record_concepts=test_record_concepts,
+        )
+
+
+def test_vtm_converter_called(test_note, test_meds_algy_annotator):
+    with patch.object(test_meds_algy_annotator, "convert_VTM_to_VMP_or_text") as mock_vtm_converter:
+        test_meds_algy_annotator.run_pipeline(test_note)
+        mock_vtm_converter.assert_called()
+
+
+def test_dosage_extractor_called(test_note, test_meds_algy_annotator):
+    with patch.object(test_meds_algy_annotator, "add_dosages_to_concepts") as mock_dosage_extractor:
+        test_meds_algy_annotator.run_pipeline(test_note, dosage_extractor=Mock())
+        mock_dosage_extractor.assert_called()
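
For reference, a minimal end-to-end sketch of the surface these tests exercise with mocks; the model pack path is a placeholder:

```python
from medcat.cat import CAT

from miade.annotators import ProblemsAnnotator
from miade.note import Note

cat = CAT.load_model_pack("problems_modelpack.zip")  # placeholder path
annotator = ProblemsAnnotator(cat)

note = Note("Problems:\n\n1. CCF -\n- had echo on 15/6\n2. IHD")
concepts = annotator.run_pipeline(note, record_concepts=None)
for concept in concepts:
    print(concept.id, concept.name)
```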