diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2b0fe90..3befd09 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,3 +31,7 @@ jobs: pip install https://huggingface.co/kormilitzin/en_core_med7_lg/resolve/main/en_core_med7_lg-any-py3-none-any.whl - name: run pytest run: pytest ./tests/* + - name: install ruff + run: pip install ruff + - name: ruff format + run: ruff format --check . diff --git a/.gitignore b/.gitignore index a423ca1..2c5f129 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,9 @@ __pycache__/ .ipynb_checkpoints/ .idea/ +# Linting +.ruff_cache/ + # Pytest .pytest_cache/ diff --git a/configs/miade_config.yaml b/configs/miade_config.yaml index 7322add..1a8b8aa 100644 --- a/configs/miade_config.yaml +++ b/configs/miade_config.yaml @@ -6,8 +6,10 @@ annotators: meds/allergies: MedsAllergiesAnnotator general: problems: + lookup_data_path: ./lookup_data/ negation_detection: None disable: [] meds/allergies: + lookup_data_path: ./lookup_data/ negation_detection: None disable: [] \ No newline at end of file diff --git a/src/miade/data/allergens_subset.csv b/lookup_data/allergens_subset.csv similarity index 100% rename from src/miade/data/allergens_subset.csv rename to lookup_data/allergens_subset.csv diff --git a/src/miade/data/allergy_type.csv b/lookup_data/allergy_type.csv similarity index 100% rename from src/miade/data/allergy_type.csv rename to lookup_data/allergy_type.csv diff --git a/src/miade/data/historic.csv b/lookup_data/historic.csv similarity index 100% rename from src/miade/data/historic.csv rename to lookup_data/historic.csv diff --git a/src/miade/data/negated.csv b/lookup_data/negated.csv similarity index 100% rename from src/miade/data/negated.csv rename to lookup_data/negated.csv diff --git a/src/miade/data/problem_blacklist.csv b/lookup_data/problem_blacklist.csv similarity index 100% rename from src/miade/data/problem_blacklist.csv rename to lookup_data/problem_blacklist.csv diff --git a/src/miade/data/reactions_subset.csv b/lookup_data/reactions_subset.csv similarity index 100% rename from src/miade/data/reactions_subset.csv rename to lookup_data/reactions_subset.csv diff --git a/src/miade/data/suspected.csv b/lookup_data/suspected.csv similarity index 100% rename from src/miade/data/suspected.csv rename to lookup_data/suspected.csv diff --git a/src/miade/data/valid_meds.csv b/lookup_data/valid_meds.csv similarity index 100% rename from src/miade/data/valid_meds.csv rename to lookup_data/valid_meds.csv diff --git a/src/miade/data/vtm_to_text.csv b/lookup_data/vtm_to_text.csv similarity index 100% rename from src/miade/data/vtm_to_text.csv rename to lookup_data/vtm_to_text.csv diff --git a/src/miade/data/vtm_to_vmp.csv b/lookup_data/vtm_to_vmp.csv similarity index 100% rename from src/miade/data/vtm_to_vmp.csv rename to lookup_data/vtm_to_vmp.csv diff --git a/pyproject.toml b/pyproject.toml index 31bad4a..ffde221 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,3 +56,8 @@ where = ["src"] [tool.setuptools.package-data] miade = ["data/*.csv"] +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +ignore = ["E721"] \ No newline at end of file diff --git a/src/miade/annotators.py b/src/miade/annotators.py index 681a3be..158b1c9 100644 --- a/src/miade/annotators.py +++ b/src/miade/annotators.py @@ -1,7 +1,8 @@ -import io +import os import logging -import pkgutil import re +from enum import Enum + import pandas as pd from typing import List, Optional, Tuple, Dict @@ -15,7 +16,13 @@ from .note import Note from 
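The hunks above add a ruff format check to CI, move the lookup CSVs out of the package into a top-level lookup_data/ directory, and point annotators at it through a new lookup_data_path key in miade_config.yaml. A minimal sketch of reading that per-annotator section, assuming the general/problems nesting shown in the config hunk; load_annotator_config itself is a hypothetical helper, not part of this diff.

import yaml
from typing import List, Optional
from pydantic import BaseModel


class AnnotatorConfig(BaseModel):
    # mirrors the AnnotatorConfig defaults introduced later in this diff
    lookup_data_path: Optional[str] = "./lookup_data/"
    negation_detection: Optional[str] = "negex"
    disable: List[str] = []


def load_annotator_config(path: str, annotator: str = "problems") -> AnnotatorConfig:
    # hypothetical helper: pull one annotator's settings out of miade_config.yaml
    with open(path) as f:
        config = yaml.safe_load(f)
    section = config.get("general", {}).get(annotator, {}) or {}
    return AnnotatorConfig(**section)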
.paragraph import ParagraphType from .dosageextractor import DosageExtractor -from .utils.metaannotationstypes import * +from .utils.metaannotationstypes import ( + Presence, + Relevance, + ReactionPos, + SubstanceCategory, + Severity, +) from .utils.annotatorconfig import AnnotatorConfig log = logging.getLogger(__name__) @@ -23,6 +30,7 @@ # Precompile regular expressions sent_regex = re.compile(r"[^\s][^\n]+") + class AllergenType(Enum): FOOD = "food" DRUG = "drug" @@ -34,39 +42,33 @@ class AllergenType(Enum): def load_lookup_data(filename: str, as_dict: bool = False, no_header: bool = False): - lookup_data = pkgutil.get_data(__name__, filename) if as_dict: return ( pd.read_csv( - io.BytesIO(lookup_data), + filename, index_col=0, - ).squeeze("columns") + ) + .squeeze("columns") .T.to_dict() ) if no_header: - return ( - pd.read_csv( - io.BytesIO(lookup_data), - header=None - ) - ) + return pd.read_csv(filename, header=None) else: - return pd.read_csv(io.BytesIO(lookup_data)).drop_duplicates() + return pd.read_csv(filename).drop_duplicates() def load_allergy_type_combinations(filename: str) -> Dict: - data = pkgutil.get_data(__name__, filename) - df = pd.read_csv(io.BytesIO(data)) + df = pd.read_csv(filename) # Convert 'allergenType' and 'adverseReactionType' columns to lowercase - df['allergenType'] = df['allergenType'].str.lower() - df['adverseReactionType'] = df['adverseReactionType'].str.lower() + df["allergenType"] = df["allergenType"].str.lower() + df["adverseReactionType"] = df["adverseReactionType"].str.lower() # Create a tuple column containing (reaction_id, reaction_name) for each row - df['reaction_id_name'] = list(zip(df['adverseReactionId'], df['adverseReactionName'])) + df["reaction_id_name"] = list(zip(df["adverseReactionId"], df["adverseReactionName"])) # Set (allergenType, adverseReactionType) as the index and convert to dictionary - result_dict = df.set_index(['allergenType', 'adverseReactionType'])['reaction_id_name'].to_dict() + result_dict = df.set_index(["allergenType", "adverseReactionType"])["reaction_id_name"].to_dict() return result_dict @@ -79,10 +81,10 @@ def get_dosage_string(med: Concept, next_med: Optional[Concept], text: str) -> s :param text: (str) whole text :return: (str) dosage text """ - sents = sent_regex.findall(text[med.start: next_med.start] if next_med is not None else text[med.start:]) + sents = sent_regex.findall(text[med.start : next_med.start] if next_med is not None else text[med.start :]) - concept_name = text[med.start: med.end] - next_concept_name = text[next_med.start: next_med.end] if next_med else None + concept_name = text[med.start : med.end] + next_concept_name = text[next_med.start : next_med.end] if next_med else None for sent in sents: if concept_name in sent: @@ -90,7 +92,7 @@ def get_dosage_string(med: Concept, next_med: Optional[Concept], text: str) -> s if next_concept_name not in sent: return sent else: - return text[med.start: next_med.start] + return text[med.start : next_med.start] else: ind = sent.find(concept_name) return sent[ind:] @@ -133,6 +135,7 @@ class Annotator: """ Docstring for Annotator """ + # TODO: Create abstract class methods for easier unit testing def __init__(self, cat: CAT, config: AnnotatorConfig = None): self.cat = cat @@ -159,6 +162,7 @@ def get_concepts(self, note: Note) -> List[Concept]: log.warning(f"Concept skipped: {e}") return concepts + @staticmethod def preprocess(note: Note) -> Note: note.clean_text() @@ -180,81 +184,117 @@ def process_paragraphs(note: Note, concepts: List[Concept]) -> 
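load_lookup_data and load_allergy_type_combinations now read plain filesystem paths instead of going through pkgutil. The sketch below reproduces the allergy-type transformation on a tiny in-memory frame to show the shape of the resulting dictionary; only the column names come from the code above, the two rows and their ids are illustrative.

import pandas as pd

# illustrative stand-in for lookup_data/allergy_type.csv
df = pd.DataFrame(
    {
        "allergenType": ["Drug", "Food"],
        "adverseReactionType": ["Allergy", "Intolerance"],
        "adverseReactionId": [101, 202],
        "adverseReactionName": ["Drug allergy", "Food intolerance"],
    }
)
df["allergenType"] = df["allergenType"].str.lower()
df["adverseReactionType"] = df["adverseReactionType"].str.lower()
df["reaction_id_name"] = list(zip(df["adverseReactionId"], df["adverseReactionName"]))
lookup = df.set_index(["allergenType", "adverseReactionType"])["reaction_id_name"].to_dict()
# lookup[("drug", "allergy")] -> (101, "Drug allergy")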
List[Concept]: # problem is present and allergy is irrelevant for meta in concept.meta: if meta.name == "relevance" and meta.value == Relevance.IRRELEVANT: - log.debug(f"Converted {meta.value} to " - f"{Relevance.PRESENT} for concept ({concept.id} | {concept.name}): " - f"paragraph is {paragraph.type}") + log.debug( + f"Converted {meta.value} to " + f"{Relevance.PRESENT} for concept ({concept.id} | {concept.name}): " + f"paragraph is {paragraph.type}" + ) meta.value = Relevance.PRESENT if meta.name == "substance_category": - log.debug(f"Converted {meta.value} to " - f"{SubstanceCategory.IRRELEVANT} for concept ({concept.id} | {concept.name}): " - f"paragraph is {paragraph.type}") - meta.value = SubstanceCategory.IRRELEVANT + if meta.value != SubstanceCategory.IRRELEVANT: + log.debug( + f"Converted {meta.value} to " + f"{SubstanceCategory.IRRELEVANT} for concept ({concept.id} | {concept.name}): " + f"paragraph is {paragraph.type}" + ) + meta.value = SubstanceCategory.IRRELEVANT elif paragraph.type == ParagraphType.pmh: prob_concepts.append(concept) # problem is historic and allergy is irrelevant for meta in concept.meta: if meta.name == "relevance" and meta.value == Relevance.IRRELEVANT: - log.debug(f"Converted {meta.value} to " - f"{Relevance.HISTORIC} for concept ({concept.id} | {concept.name}): " - f"paragraph is {paragraph.type}") + log.debug( + f"Converted {meta.value} to " + f"{Relevance.HISTORIC} for concept ({concept.id} | {concept.name}): " + f"paragraph is {paragraph.type}" + ) meta.value = Relevance.HISTORIC if meta.name == "substance_category": - log.debug(f"Converted {meta.value} to " - f"{SubstanceCategory.IRRELEVANT} for concept ({concept.id} | {concept.name}): " - f"paragraph is {paragraph.type}") - meta.value = SubstanceCategory.IRRELEVANT + if meta.value != SubstanceCategory.IRRELEVANT: + log.debug( + f"Converted {meta.value} to " + f"{SubstanceCategory.IRRELEVANT} for concept ({concept.id} | {concept.name}): " + f"paragraph is {paragraph.type}" + ) + meta.value = SubstanceCategory.IRRELEVANT elif paragraph.type == ParagraphType.med: # problem is irrelevant and allergy is taking for meta in concept.meta: if meta.name == "relevance": - log.debug(f"Converted {meta.value} to " - f"{Relevance.IRRELEVANT} for concept ({concept.id} | {concept.name}): " - f"paragraph is {paragraph.type}") - meta.value = Relevance.IRRELEVANT + if meta.value != Relevance.IRRELEVANT: + log.debug( + f"Converted {meta.value} to " + f"{Relevance.IRRELEVANT} for concept ({concept.id} | {concept.name}): " + f"paragraph is {paragraph.type}" + ) + meta.value = Relevance.IRRELEVANT if meta.name == "substance_category" and meta.value == SubstanceCategory.IRRELEVANT: - log.debug(f"Converted {meta.value} to " - f"{SubstanceCategory.TAKING} for concept ({concept.id} | {concept.name}): " - f"paragraph is {paragraph.type}") - meta.value = SubstanceCategory.TAKING + if meta.value != SubstanceCategory.TAKING: + log.debug( + f"Converted {meta.value} to " + f"{SubstanceCategory.TAKING} for concept ({concept.id} | {concept.name}): " + f"paragraph is {paragraph.type}" + ) + meta.value = SubstanceCategory.TAKING elif paragraph.type == ParagraphType.allergy: # problem is irrelevant and allergy is as is for meta in concept.meta: if meta.name == "relevance": - log.debug(f"Converted {meta.value} to " - f"{Relevance.IRRELEVANT} for concept ({concept.id} | {concept.name}): " - f"paragraph is {paragraph.type}") - meta.value = Relevance.IRRELEVANT + if meta.value != Relevance.IRRELEVANT: + log.debug( + f"Converted {meta.value} 
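The change repeated through process_paragraphs above is the same in every branch: a meta-annotation is only logged and overwritten when its value actually differs from the target, so no-op conversions stop flooding the debug log. A generic sketch of that guard; meta here is any object with a mutable value attribute.

import logging

log = logging.getLogger(__name__)


def convert_meta_value(meta, target, concept_id, concept_name, paragraph_type):
    # only convert (and log) when the value would actually change
    if meta.value != target:
        log.debug(
            f"Converted {meta.value} to {target} for concept "
            f"({concept_id} | {concept_name}): paragraph is {paragraph_type}"
        )
        meta.value = target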
to " + f"{Relevance.IRRELEVANT} for concept ({concept.id} | {concept.name}): " + f"paragraph is {paragraph.type}" + ) + meta.value = Relevance.IRRELEVANT if meta.name == "substance_category": - log.debug(f"Converted {meta.value} to " - f"{SubstanceCategory.ADVERSE_REACTION} for concept ({concept.id} | {concept.name}): " - f"paragraph is {paragraph.type}") - meta.value = SubstanceCategory.ADVERSE_REACTION - elif paragraph.type == ParagraphType.exam or paragraph.type == ParagraphType.ddx or paragraph.type == ParagraphType.plan: + # DO NOT CONVERT REACTIONS + if ( + meta.value != SubstanceCategory.ADVERSE_REACTION + and meta.value != SubstanceCategory.NOT_SUBSTANCE + ): + log.debug( + f"Converted {meta.value} to " + f"{SubstanceCategory.ADVERSE_REACTION} for concept ({concept.id} | {concept.name}): " + f"paragraph is {paragraph.type}" + ) + meta.value = SubstanceCategory.ADVERSE_REACTION + elif ( + paragraph.type == ParagraphType.exam + or paragraph.type == ParagraphType.ddx + or paragraph.type == ParagraphType.plan + ): # problem is irrelevant and allergy is irrelevant for meta in concept.meta: if meta.name == "relevance": - log.debug(f"Converted {meta.value} to " - f"{Relevance.IRRELEVANT} for concept ({concept.id} | {concept.name}): " - f"paragraph is {paragraph.type}") - meta.value = Relevance.IRRELEVANT + if meta.value != Relevance.IRRELEVANT: + log.debug( + f"Converted {meta.value} to " + f"{Relevance.IRRELEVANT} for concept ({concept.id} | {concept.name}): " + f"paragraph is {paragraph.type}" + ) + meta.value = Relevance.IRRELEVANT if meta.name == "substance_category": - log.debug(f"Converted {meta.value} to " - f"{SubstanceCategory.IRRELEVANT} for concept ({concept.id} | {concept.name}): " - f"paragraph is {paragraph.type}") - meta.value = SubstanceCategory.IRRELEVANT - + if meta.value != SubstanceCategory.IRRELEVANT: + log.debug( + f"Converted {meta.value} to " + f"{SubstanceCategory.IRRELEVANT} for concept ({concept.id} | {concept.name}): " + f"paragraph is {paragraph.type}" + ) + meta.value = SubstanceCategory.IRRELEVANT # print(len(prob_concepts)) # if more than 10 concepts in prob or imp or pmh sections, return only those and ignore all other concepts if len(prob_concepts) > 10: - log.debug(f"Ignoring concepts elsewhere in the document because " - f"concepts in prob, imp, pmh sections exceed 10: {len(prob_concepts)}") + log.debug( + f"Ignoring concepts elsewhere in the document because " + f"concepts in prob, imp, pmh sections exceed 10: {len(prob_concepts)}" + ) return prob_concepts else: return concepts - @staticmethod def deduplicate(concepts: List[Concept], record_concepts: Optional[List[Concept]]) -> List[Concept]: if record_concepts is not None: @@ -283,9 +323,7 @@ def deduplicate(concepts: List[Concept], record_concepts: Optional[List[Concept] @staticmethod def add_dosages_to_concepts( - dosage_extractor: DosageExtractor, - concepts: List[Concept], - note: Note + dosage_extractor: DosageExtractor, concepts: List[Concept], note: Note ) -> List[Concept]: """ Gets dosages for medication concepts @@ -296,18 +334,16 @@ def add_dosages_to_concepts( """ for ind, concept in enumerate(concepts): - next_med_concept = ( - concepts[ind + 1] - if len(concepts) > ind + 1 - else None - ) + next_med_concept = concepts[ind + 1] if len(concepts) > ind + 1 else None dosage_string = get_dosage_string(concept, next_med_concept, note.text) if len(dosage_string.split()) > 2: concept.dosage = dosage_extractor(dosage_string) concept.category = Category.MEDICATION if concept.dosage is not None else 
None if concept.dosage is not None: - log.debug(f"Extracted dosage for medication concept " - f"({concept.id} | {concept.name}): {concept.dosage.text} {concept.dosage.dose}") + log.debug( + f"Extracted dosage for medication concept " + f"({concept.id} | {concept.name}): {concept.dosage.text} {concept.dosage.dose}" + ) return concepts @@ -341,10 +377,18 @@ def __init__(self, cat: CAT, config: AnnotatorConfig = None): self.concept_types = [Category.PROBLEM] self.pipeline = ["preprocessor", "medcat", "paragrapher", "postprocessor", "deduplicator"] - self.negated_lookup = load_lookup_data("./data/negated.csv", as_dict=True) - self.historic_lookup = load_lookup_data("./data/historic.csv", as_dict=True) - self.suspected_lookup = load_lookup_data("./data/suspected.csv", as_dict=True) - self.filtering_blacklist = load_lookup_data("./data/problem_blacklist.csv", no_header=True) + self._load_problems_lookup_data() + + def _load_problems_lookup_data(self) -> None: + if not os.path.isdir(self.config.lookup_data_path): + raise RuntimeError(f"No lookup data configured: {self.config.lookup_data_path} does not exist!") + else: + self.negated_lookup = load_lookup_data(self.config.lookup_data_path + "negated.csv", as_dict=True) + self.historic_lookup = load_lookup_data(self.config.lookup_data_path + "historic.csv", as_dict=True) + self.suspected_lookup = load_lookup_data(self.config.lookup_data_path + "suspected.csv", as_dict=True) + self.filtering_blacklist = load_lookup_data( + self.config.lookup_data_path + "problem_blacklist.csv", no_header=True + ) def _process_meta_annotations(self, concept: Concept) -> Optional[Concept]: # Add, convert, or ignore concepts @@ -378,20 +422,23 @@ def _process_meta_annotations(self, concept: Concept) -> Optional[Concept]: if tag == " (negated)" and concept.negex: log.debug( f"Converted concept ({concept.id} | {concept.name}) to ({str(convert)} | {concept.name + tag}): " - f"negation detected by negex") + f"negation detected by negex" + ) else: - log.debug(f"Converted concept ({concept.id} | {concept.name}) to ({str(convert)} | {concept.name + tag}):" - f"detected by meta model") + log.debug( + f"Converted concept ({concept.id} | {concept.name}) to ({str(convert)} | {concept.name + tag}):" + f"detected by meta model" + ) concept.id = str(convert) concept.name += tag else: if concept.negex: - log.debug( - f"Removed concept ({concept.id} | {concept.name}): negation (negex) with no conversion match") + log.debug(f"Removed concept ({concept.id} | {concept.name}): negation (negex) with no conversion match") return None if concept.negex is None and Presence.NEGATED in meta_ann_values: log.debug( - f"Removed concept ({concept.id} | {concept.name}): negation (meta model) with no conversion match") + f"Removed concept ({concept.id} | {concept.name}): negation (meta model) with no conversion match" + ) return None if Presence.SUSPECTED in meta_ann_values: log.debug(f"Removed concept ({concept.id} | {concept.name}): suspected with no conversion match") @@ -409,9 +456,7 @@ def _process_meta_annotations(self, concept: Concept) -> Optional[Concept]: def _is_blacklist(self, concept): # filtering blacklist if int(concept.id) in self.filtering_blacklist.values: - log.debug( - f"Removed concept ({concept.id} | {concept.name}): concept in problems blacklist" - ) + log.debug(f"Removed concept ({concept.id} | {concept.name}): concept in problems blacklist") return True return False @@ -460,16 +505,32 @@ class MedsAllergiesAnnotator(Annotator): def __init__(self, cat: CAT, config: 
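ProblemsAnnotator now resolves its lookup tables from config.lookup_data_path and fails fast when the directory is missing. A sketch of the same check-then-load pattern; os.path.join is used here only to illustrate one way of not depending on a trailing slash in the configured path.

import os
import pandas as pd


def _read_dict(path: str) -> dict:
    # same shape as load_lookup_data(..., as_dict=True)
    return pd.read_csv(path, index_col=0).squeeze("columns").T.to_dict()


def load_problem_lookups(lookup_data_path: str) -> dict:
    if not os.path.isdir(lookup_data_path):
        raise RuntimeError(f"No lookup data configured: {lookup_data_path} does not exist!")
    return {
        "negated": _read_dict(os.path.join(lookup_data_path, "negated.csv")),
        "historic": _read_dict(os.path.join(lookup_data_path, "historic.csv")),
        "suspected": _read_dict(os.path.join(lookup_data_path, "suspected.csv")),
        "blacklist": pd.read_csv(os.path.join(lookup_data_path, "problem_blacklist.csv"), header=None),
    }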
AnnotatorConfig = None): super().__init__(cat, config) self.concept_types = [Category.MEDICATION, Category.ALLERGY, Category.REACTION] - self.pipeline = ["preprocessor", "medcat", "paragrapher", "postprocessor", "dosage_extractor", - "vtm_converter", "deduplicator"] - - # load the lookup data - self.valid_meds = load_lookup_data("./data/valid_meds.csv", no_header=True) - self.reactions_subset_lookup = load_lookup_data("./data/reactions_subset.csv", as_dict=True) - self.allergens_subset_lookup = load_lookup_data("./data/allergens_subset.csv", as_dict=True) - self.allergy_type_lookup = load_allergy_type_combinations("./data/allergy_type.csv") - self.vtm_to_vmp_lookup = load_lookup_data("./data/vtm_to_vmp.csv") - self.vtm_to_text_lookup = load_lookup_data("./data/vtm_to_text.csv", as_dict=True) + self.pipeline = [ + "preprocessor", + "medcat", + "paragrapher", + "postprocessor", + "dosage_extractor", + "vtm_converter", + "deduplicator", + ] + + self._load_med_allergy_lookup_data() + + def _load_med_allergy_lookup_data(self) -> None: + if not os.path.isdir(self.config.lookup_data_path): + raise RuntimeError(f"No lookup data configured: {self.config.lookup_data_path} does not exist!") + else: + self.valid_meds = load_lookup_data(self.config.lookup_data_path + "valid_meds.csv", no_header=True) + self.reactions_subset_lookup = load_lookup_data( + self.config.lookup_data_path + "reactions_subset.csv", as_dict=True + ) + self.allergens_subset_lookup = load_lookup_data( + self.config.lookup_data_path + "allergens_subset.csv", as_dict=True + ) + self.allergy_type_lookup = load_allergy_type_combinations(self.config.lookup_data_path + "allergy_type.csv") + self.vtm_to_vmp_lookup = load_lookup_data(self.config.lookup_data_path + "vtm_to_vmp.csv") + self.vtm_to_text_lookup = load_lookup_data(self.config.lookup_data_path + "vtm_to_text.csv", as_dict=True) def _validate_meds(self, concept) -> bool: # check if substance is valid med @@ -481,16 +542,19 @@ def _validate_and_convert_substance(self, concept) -> bool: # check if substance is valid substance for allergy - if it is, convert it to Epic subset and return that concept lookup_result = self.allergens_subset_lookup.get(int(concept.id)) if lookup_result is not None: - tag = " (converted)" - log.debug(f"Converted concept ({concept.id} | {concept.name}) to " - f"({lookup_result['subsetId']} | {concept.name + tag}): valid Epic allergen subset") + log.debug( + f"Converted concept ({concept.id} | {concept.name}) to " + f"({lookup_result['subsetId']} | {concept.name}): valid Epic allergen subset" + ) concept.id = str(lookup_result["subsetId"]) # then check the allergen type from lookup result - e.g. 
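_validate_and_convert_substance maps a detected substance onto the Epic allergen subset and tags its allergen type. The sketch below shows the lookup shape this relies on, inferred from how allergens_subset_lookup is indexed above; the sample entry and ids are illustrative.

# illustrative stand-in for the dictionary loaded from lookup_data/allergens_subset.csv
allergens_subset_lookup = {
    12345: {"subsetId": 67890, "allergenType": "DRUG"},
}


def convert_substance(concept_id: int, concept_name: str):
    result = allergens_subset_lookup.get(concept_id)
    if result is None:
        return None  # not in the subset: caller leaves the concept unconverted
    return {
        "id": str(result["subsetId"]),
        "name": concept_name,
        "allergen_type": str(result["allergenType"]).lower(),
    }


# convert_substance(12345, "penicillin")
# -> {"id": "67890", "name": "penicillin", "allergen_type": "drug"}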
drug, food try: - concept.category = AllergenType(str(lookup_result['allergenType']).lower()) - log.debug(f"Assigned substance concept ({concept.id} | {concept.name}) " - f"to allergen type category {concept.category}") + concept.category = AllergenType(str(lookup_result["allergenType"]).lower()) + log.debug( + f"Assigned substance concept ({concept.id} | {concept.name}) " + f"to allergen type category {concept.category}" + ) except ValueError as e: log.warning(f"Allergen type not found for {concept.__str__()}: {e}") @@ -503,10 +567,10 @@ def _validate_and_convert_reaction(self, concept) -> bool: # check if substance is valid reaction - if it is, convert it to Epic subset and return that concept lookup_result = self.reactions_subset_lookup.get(int(concept.id), None) if lookup_result is not None: - tag = " (converted)" - log.debug(f"Converted concept ({concept.id} | {concept.name}) to " - f"({lookup_result} | {concept.name + tag}): valid Epic reaction subset") - + log.debug( + f"Converted concept ({concept.id} | {concept.name}) to " + f"({lookup_result} | {concept.name}): valid Epic reaction subset" + ) concept.id = str(lookup_result) return True else: @@ -522,11 +586,22 @@ def _validate_and_convert_concepts(self, concept: Concept) -> Concept: self._convert_allergy_type_to_code(concept) self._convert_allergy_severity_to_code(concept) concept.category = Category.ALLERGY + else: + log.warning(f"Double-checking if concept ({concept.id} | {concept.name}) is in reaction subset") + if self._validate_and_convert_reaction(concept) and ( + ReactionPos.BEFORE_SUBSTANCE in meta_ann_values or ReactionPos.AFTER_SUBSTANCE in meta_ann_values + ): + concept.category = Category.REACTION + else: + log.warning( + f"Reaction concept ({concept.id} | {concept.name}) not in subset or reaction_pos is NOT_REACTION" + ) if SubstanceCategory.TAKING in meta_ann_values: if self._validate_meds(concept): concept.category = Category.MEDICATION if SubstanceCategory.NOT_SUBSTANCE in meta_ann_values and ( - ReactionPos.BEFORE_SUBSTANCE in meta_ann_values or ReactionPos.AFTER_SUBSTANCE in meta_ann_values): + ReactionPos.BEFORE_SUBSTANCE in meta_ann_values or ReactionPos.AFTER_SUBSTANCE in meta_ann_values + ): if self._validate_and_convert_reaction(concept): concept.category = Category.REACTION @@ -540,9 +615,9 @@ def _link_reactions_to_allergens(concept_list: List[Concept], note: Note, link_d for reaction_concept in reaction_concepts: nearest_allergy_concept = None min_distance = inf - meta_ann_values = [ - meta_ann.value for meta_ann in reaction_concept.meta - ] if reaction_concept.meta is not None else [] + meta_ann_values = ( + [meta_ann.value for meta_ann in reaction_concept.meta] if reaction_concept.meta is not None else [] + ) for allergy_concept in allergy_concepts: # skip if allergy is after and meta is before_substance @@ -552,15 +627,19 @@ def _link_reactions_to_allergens(concept_list: List[Concept], note: Note, link_d elif ReactionPos.AFTER_SUBSTANCE in meta_ann_values and allergy_concept.start > reaction_concept.start: continue else: - distance = calculate_word_distance(reaction_concept.start, reaction_concept.end, - allergy_concept.start, allergy_concept.end, - note) - log.debug(f"Calculated distance between reaction {reaction_concept.name} " - f"and allergen {allergy_concept.name}: {distance}") + distance = calculate_word_distance( + reaction_concept.start, reaction_concept.end, allergy_concept.start, allergy_concept.end, note + ) + log.debug( + f"Calculated distance between reaction {reaction_concept.name} " + 
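_link_reactions_to_allergens attaches each reaction to the closest allergen within link_distance words. The sketch below shows the distance-and-threshold selection with a simplified word-distance stand-in; the real implementation also skips allergens on the wrong side when a reaction_pos meta-annotation is present, and works on Concept objects rather than dicts.

from math import inf


def word_distance(start1: int, end1: int, start2: int, end2: int, text: str) -> int:
    # simplified stand-in for calculate_word_distance: whitespace tokens between spans
    if start1 > start2:
        start1, end1, start2, end2 = start2, end2, start1, end1
    if end1 > len(text) or end2 > len(text):
        return -1
    return len(text[end1:start2].split())


def link_nearest_allergen(reaction: dict, allergens: list, text: str, link_distance: int = 5):
    nearest, min_distance = None, inf
    for allergen in allergens:
        distance = word_distance(reaction["start"], reaction["end"], allergen["start"], allergen["end"], text)
        if distance == -1:
            continue  # invalid indices, mirrors the warning branch above
        if distance <= link_distance and distance < min_distance:
            nearest, min_distance = allergen, distance
    return nearest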
f"and allergen {allergy_concept.name}: {distance}" + ) if distance == -1: - log.warning(f"Indices for {reaction_concept.name} or {allergy_concept.name} invalid: " - f"({reaction_concept.start}, {reaction_concept.end})" - f"({allergy_concept.start}, {allergy_concept.end})") + log.warning( + f"Indices for {reaction_concept.name} or {allergy_concept.name} invalid: " + f"({reaction_concept.start}, {reaction_concept.end})" + f"({allergy_concept.start}, {allergy_concept.end})" + ) continue if distance <= link_distance and distance < min_distance: @@ -569,8 +648,10 @@ def _link_reactions_to_allergens(concept_list: List[Concept], note: Note, link_d if nearest_allergy_concept is not None: nearest_allergy_concept.linked_concepts.append(reaction_concept) - log.debug(f"Linked reaction concept {reaction_concept.name} to " - f"allergen concept {nearest_allergy_concept.name}") + log.debug( + f"Linked reaction concept {reaction_concept.name} to " + f"allergen concept {nearest_allergy_concept.name}" + ) # Remove the linked REACTION concepts from the main list updated_concept_list = [concept for concept in concept_list if concept.category != Category.REACTION] @@ -592,17 +673,21 @@ def _convert_allergy_severity_to_code(concept: Concept) -> bool: log.warning(f"No severity annotation associated with ({concept.id} | {concept.name})") return False - log.debug(f"Linked severity concept ({concept.linked_concepts[-1].id} | {concept.linked_concepts[-1].name}) " - f"to allergen concept ({concept.id} | {concept.name}): valid meta model output") + log.debug( + f"Linked severity concept ({concept.linked_concepts[-1].id} | {concept.linked_concepts[-1].name}) " + f"to allergen concept ({concept.id} | {concept.name}): valid meta model output" + ) return True def _convert_allergy_type_to_code(self, concept: Concept) -> bool: # get the ALLERGYTYPE meta-annotation - allergy_type = [meta_ann for meta_ann in concept.meta if meta_ann.name == "allergytype"] + allergy_type = [meta_ann for meta_ann in concept.meta if meta_ann.name == "allergy_type"] if len(allergy_type) != 1: - log.warning(f"Unable to map allergy type code: allergytype meta-annotation " - f"not found for concept {concept.__str__()}") + log.warning( + f"Unable to map allergy type code: allergy_type meta-annotation " + f"not found for concept {concept.__str__()}" + ) return False else: allergy_type = allergy_type[0].value @@ -613,17 +698,22 @@ def _convert_allergy_type_to_code(self, concept: Concept) -> bool: # add resulting allergy type concept as to linked_concept if allergy_type_lookup_result is not None: - concept.linked_concepts.append(Concept(id=str(allergy_type_lookup_result[0]), - name=allergy_type_lookup_result[1], - category=Category.ALLERGY_TYPE)) - log.debug(f"Linked allergytype concept ({allergy_type_lookup_result[0]} | {allergy_type_lookup_result[1]})" - f" to allergen concept ({concept.id} | {concept.name}): valid meta model output + allergytype lookup") + concept.linked_concepts.append( + Concept( + id=str(allergy_type_lookup_result[0]), + name=allergy_type_lookup_result[1], + category=Category.ALLERGY_TYPE, + ) + ) + log.debug( + f"Linked allergy_type concept ({allergy_type_lookup_result[0]} | {allergy_type_lookup_result[1]})" + f" to allergen concept ({concept.id} | {concept.name}): valid meta model output + allergytype lookup" + ) else: log.warning(f"Allergen and adverse reaction type combination not found: {lookup_combination}") return True - def postprocess(self, concepts: List[Concept], note: Note) -> List[Concept]: # deepcopy so we still have 
reference to original list of concepts all_concepts = deepcopy(concepts) @@ -653,22 +743,27 @@ def convert_VTM_to_VMP_or_text(self, concepts: List[Concept]) -> List[Concept]: med_concepts_no_dose = [concept for concept in concepts if concept not in med_concepts_with_dose] # Create a temporary DataFrame to match vtmId, dose, and unit - temp_df = pd.DataFrame({'vtmId': [int(concept.id) for concept in med_concepts_with_dose], - 'dose': [float(concept.dosage.dose.value) for concept in med_concepts_with_dose], - 'unit': [concept.dosage.dose.unit for concept in med_concepts_with_dose]}) + temp_df = pd.DataFrame( + { + "vtmId": [int(concept.id) for concept in med_concepts_with_dose], + "dose": [float(concept.dosage.dose.value) for concept in med_concepts_with_dose], + "unit": [concept.dosage.dose.unit for concept in med_concepts_with_dose], + } + ) # Merge with the lookup df to get vmpId - merged_df = temp_df.merge(self.vtm_to_vmp_lookup, on=['vtmId', 'dose', 'unit'], how='left') + merged_df = temp_df.merge(self.vtm_to_vmp_lookup, on=["vtmId", "dose", "unit"], how="left") # Update id in the concepts list for index, concept in enumerate(med_concepts_with_dose): # Convert VTM to VMP id - vmp_id = merged_df.at[index, 'vmpId'] + vmp_id = merged_df.at[index, "vmpId"] if not pd.isna(vmp_id): log.debug( f"Converted ({concept.id} | {concept.name}) to " f"({int(vmp_id)} | {concept.name + ' ' + str(int(concept.dosage.dose.value)) + concept.dosage.dose.unit} " - f"tablets): valid extracted dosage + VMP lookup") + f"tablets): valid extracted dosage + VMP lookup" + ) concept.id = str(int(vmp_id)) concept.name += " " + str(int(concept.dosage.dose.value)) + str(concept.dosage.dose.unit) + " tablets" # If found VMP match change the dosage to 1 tablet @@ -679,7 +774,8 @@ def convert_VTM_to_VMP_or_text(self, concepts: List[Concept]) -> List[Concept]: lookup_result = self.vtm_to_text_lookup.get(int(concept.id)) if lookup_result is not None: log.debug( - f"Converted ({concept.id} | {concept.name}) to (None | {lookup_result}: no match to VMP dosage lookup)") + f"Converted ({concept.id} | {concept.name}) to (None | {lookup_result}: no match to VMP dosage lookup)" + ) concept.id = None concept.name = lookup_result @@ -687,19 +783,17 @@ def convert_VTM_to_VMP_or_text(self, concepts: List[Concept]) -> List[Concept]: for concept in med_concepts_no_dose: lookup_result = self.vtm_to_text_lookup.get(int(concept.id)) if lookup_result is not None: - log.debug( - f"Converted ({concept.id} | {concept.name}) to (None | {lookup_result}): no dosage detected") + log.debug(f"Converted ({concept.id} | {concept.name}) to (None | {lookup_result}): no dosage detected") concept.id = None concept.name = lookup_result return concepts - def __call__( self, note: Note, record_concepts: Optional[List[Concept]] = None, - dosage_extractor: Optional[DosageExtractor] = None + dosage_extractor: Optional[DosageExtractor] = None, ): if "preprocessor" not in self.config.disable: note = self.preprocess(note) diff --git a/src/miade/concept.py b/src/miade/concept.py index a866f0a..1e22036 100644 --- a/src/miade/concept.py +++ b/src/miade/concept.py @@ -31,7 +31,6 @@ def __init__( meta_anns: Optional[List[MetaAnnotations]] = None, debug_dict: Optional[Dict] = None, ): - self.name = name self.id = id self.category = category @@ -54,7 +53,9 @@ def from_entity(cls, entity: [Dict]): return Concept( id=entity["cui"], - name=entity["source_value"], # can also use detected_name which is spell checked but delimited by ~ e.g. 
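convert_VTM_to_VMP_or_text batches the dosed medication concepts into a DataFrame and left-joins them against the vtm_to_vmp lookup on (vtmId, dose, unit). A tiny sketch of that join; only the column names come from the code above, the rows are illustrative.

import pandas as pd

# illustrative stand-in for lookup_data/vtm_to_vmp.csv
vtm_to_vmp_lookup = pd.DataFrame({"vtmId": [111], "dose": [75.0], "unit": ["mg"], "vmpId": [999]})

temp_df = pd.DataFrame({"vtmId": [111, 111], "dose": [75.0, 20.0], "unit": ["mg", "mg"]})

merged_df = temp_df.merge(vtm_to_vmp_lookup, on=["vtmId", "dose", "unit"], how="left")
# row 0 gets vmpId 999; row 1 has no match, so vmpId is NaN and that concept
# falls back to the vtm_to_text lookup instead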
liver~failure + name=entity[ + "source_value" + ], # can also use detected_name which is spell checked but delimited by ~ e.g. liver~failure category=None, start=entity["start"], end=entity["end"], @@ -72,11 +73,7 @@ def __hash__(self): return hash((self.id, self.name, self.category)) def __eq__(self, other): - return ( - self.id == other.id - and self.name == other.name - and self.category == other.category - ) + return self.id == other.id and self.name == other.name and self.category == other.category def __lt__(self, other): return int(self.id) < int(other.id) diff --git a/src/miade/core.py b/src/miade/core.py index f9e8bb5..2e63791 100644 --- a/src/miade/core.py +++ b/src/miade/core.py @@ -3,15 +3,14 @@ import yaml import logging -from negspacy.negation import Negex +from negspacy.negation import Negex # noqa: F401 from pathlib import Path from typing import List, Optional, Dict from .concept import Concept, Category from .note import Note -from .annotators import Annotator, ProblemsAnnotator, MedsAllergiesAnnotator +from .annotators import Annotator, ProblemsAnnotator, MedsAllergiesAnnotator # noqa: F401 from .dosageextractor import DosageExtractor -from .utils.metaannotationstypes import SubstanceCategory from .utils.miade_cat import MiADE_CAT from .utils.modelfactory import ModelFactory from .utils.annotatorconfig import AnnotatorConfig @@ -29,11 +28,15 @@ def create_annotator(name: str, model_factory: ModelFactory): """ name = name.lower() if name not in model_factory.models: - raise ValueError(f"MedCAT model for {name} does not exist: either not configured in config.yaml or " - f"missing from models directory") + raise ValueError( + f"MedCAT model for {name} does not exist: either not configured in config.yaml or " + f"missing from models directory" + ) if name in model_factory.annotators.keys(): - return model_factory.annotators[name](cat=model_factory.models.get(name), config=model_factory.configs.get(name)) + return model_factory.annotators[name]( + cat=model_factory.models.get(name), config=model_factory.configs.get(name) + ) else: log.warning(f"Annotator {name} does not exist, loading generic Annotator") return Annotator(model_factory.models[name]) @@ -48,6 +51,7 @@ class NoteProcessor: :param device (str) whether inference should be run on cpu or gpu - default "cpu" :param custom_annotators (List[Annotators]) List of custom annotators """ + def __init__( self, model_directory: Path, @@ -55,7 +59,7 @@ def __init__( log_level: int = logging.INFO, dosage_extractor_log_level: int = logging.INFO, device: str = "cpu", - custom_annotators: Optional[List[Annotator]] = None + custom_annotators: Optional[List[Annotator]] = None, ): logging.getLogger("miade").setLevel(log_level) logging.getLogger("miade.dosageextractor").setLevel(dosage_extractor_log_level) @@ -122,7 +126,7 @@ def _load_model_factory(self, custom_annotators: Optional[List[Annotator]] = Non continue mapped_models[name] = cat_model else: - log.warning(f"No model ids configured!") + log.warning("No model ids configured!") mapped_annotators = {} # {name: } @@ -140,7 +144,7 @@ def _load_model_factory(self, custom_annotators: Optional[List[Annotator]] = Non except AttributeError as e: log.warning(f"{annotator_string} not found: {e}") else: - log.warning(f"No annotators configured!") + log.warning("No annotators configured!") mapped_configs = {} if "general" in config_dict: @@ -152,13 +156,10 @@ def _load_model_factory(self, custom_annotators: Optional[List[Annotator]] = Non else: log.warning("No general settings configured, 
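Concept defines __hash__ and __eq__ over (id, name, category), which is what lets deduplicate() filter newly extracted concepts against the existing record with plain set membership. A toy sketch with a frozen dataclass standing in for the real Concept; the ids and names are illustrative.

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ConceptKey:
    # a frozen dataclass hashes and compares over all fields, like Concept's (id, name, category)
    id: str
    name: str
    category: Optional[str] = None


extracted = [ConceptKey("271737000", "anaemia"), ConceptKey("22298006", "myocardial infarction")]
record = [ConceptKey("271737000", "anaemia")]

new_concepts = [c for c in extracted if c not in set(record)]
# -> only the myocardial infarction concept survives deduplication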
using default settings.") - model_factory_config = {"models": mapped_models, - "annotators": mapped_annotators, - "configs": mapped_configs} + model_factory_config = {"models": mapped_models, "annotators": mapped_annotators, "configs": mapped_configs} return ModelFactory(**model_factory_config) - def add_annotator(self, name: str) -> None: """ Adds annotators to processor @@ -167,7 +168,9 @@ def add_annotator(self, name: str) -> None: """ try: annotator = create_annotator(name, self.model_factory) - log.info(f"Added {type(annotator).__name__} to processor with config {self.model_factory.configs.get(name)}") + log.info( + f"Added {type(annotator).__name__} to processor with config {self.model_factory.configs.get(name)}" + ) except Exception as e: raise Exception(f"Error creating annotator: {e}") @@ -214,11 +217,9 @@ def process(self, note: Note, record_concepts: Optional[List[Concept]] = None) - return concepts - def get_concept_dicts(self, - note: Note, - filter_uncategorized: bool = True, - record_concepts: Optional[List[Concept]] = None - ) -> List[Dict]: + def get_concept_dicts( + self, note: Note, filter_uncategorized: bool = True, record_concepts: Optional[List[Concept]] = None + ) -> List[Dict]: """ Returns concepts in dictionary format :param note: (Note) note containing text to extract concepts from @@ -233,10 +234,12 @@ def get_concept_dicts(self, continue concept_dict = concept.__dict__ if concept.dosage is not None: - concept_dict["dosage"] = {"dose": concept.dosage.dose.dict() if concept.dosage.dose else None, - "duration": concept.dosage.duration.dict() if concept.dosage.duration else None, - "frequency": concept.dosage.frequency.dict() if concept.dosage.frequency else None, - "route": concept.dosage.route.dict() if concept.dosage.route else None} + concept_dict["dosage"] = { + "dose": concept.dosage.dose.dict() if concept.dosage.dose else None, + "duration": concept.dosage.duration.dict() if concept.dosage.duration else None, + "frequency": concept.dosage.frequency.dict() if concept.dosage.frequency else None, + "route": concept.dosage.route.dict() if concept.dosage.route else None, + } if concept.meta is not None: meta_anns = [] for meta in concept.meta: @@ -249,4 +252,3 @@ def get_concept_dicts(self, concept_list.append(concept_dict) return concept_list - diff --git a/src/miade/dosage.py b/src/miade/dosage.py index 1acd30c..24b48f8 100644 --- a/src/miade/dosage.py +++ b/src/miade/dosage.py @@ -62,9 +62,7 @@ class Route(BaseModel): code_system: Optional[str] = ROUTE_CODE_SYSTEM -def parse_dose( - text: str, quantities: List[str], units: List[str], results: Dict -) -> Optional[Dose]: +def parse_dose(text: str, quantities: List[str], units: List[str], results: Dict) -> Optional[Dose]: """ :param text: (str) string containing dose :param quantities: (list) list of quantity entities NER @@ -99,7 +97,7 @@ def parse_dose( else: try: quantity_dosage.value = float(quantities[0]) - except: + except ValueError: quantity_dosage.value = float(re.sub(r"[^\d.]+", "", quantities[0])) quantity_dosage.unit = units[0] elif len(quantities) == 2 and len(units) == 2: @@ -107,7 +105,7 @@ def parse_dose( try: quantity_dosage.low = float(quantities[0]) quantity_dosage.high = float(quantities[1]) - except: + except ValueError: quantity_dosage.low = float(re.sub(r"[^\d.]+", "", quantities[0])) quantity_dosage.high = float(re.sub(r"[^\d.]+", "", quantities[1])) if units[0] == units[1]: @@ -119,8 +117,7 @@ def parse_dose( # use caliber results as backup if results["units"] is not None: log.debug( - 
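The bare except clauses in parse_dose are narrowed to ValueError above; the fallback itself strips non-numeric characters from the NER quantity before converting. A small sketch of that fallback; the sample strings are illustrative.

import re


def to_float(quantity: str) -> float:
    try:
        return float(quantity)
    except ValueError:
        # e.g. "2x" or "75mg" -> keep digits and dots only
        return float(re.sub(r"[^\d.]+", "", quantity))


# to_float("75") -> 75.0, to_float("2x") -> 2.0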
f"Inconclusive dose entities {quantities}, " - f"using lookup results {results['qty']} {results['units']}" + f"Inconclusive dose entities {quantities}, " f"using lookup results {results['qty']} {results['units']}" ) quantity_dosage.unit = results["units"] # only autofill 1 if non-quantitative units e.g. tab, cap, puff @@ -165,7 +162,7 @@ def parse_frequency(text: str, results: Dict) -> Optional[Frequency]: if results["freq"] is not None and results["time"] is not None: try: frequency_dosage.value = results["time"] / results["freq"] - except ZeroDivisionError as e: + except ZeroDivisionError: frequency_dosage.value = None # here i convert time to hours if not institution specified # (every X hrs as opposed to X times day) but it's arbitrary really... @@ -327,13 +324,9 @@ def from_doc(cls, doc: Doc, calculate: bool = True): # if duration not given in text could extract this from total dose if given if total_dose is not None and dose is not None and doc._.results["freq"]: if dose.value is not None: - daily_dose = float(dose.value) * ( - round(doc._.results["freq"] / doc._.results["time"]) - ) + daily_dose = float(dose.value) * (round(doc._.results["freq"] / doc._.results["time"])) elif dose.high is not None: - daily_dose = float(dose.high) * ( - round(doc._.results["freq"] / doc._.results["time"]) - ) + daily_dose = float(dose.high) * (round(doc._.results["freq"] / doc._.results["time"])) duration = parse_duration( text=duration_text, diff --git a/src/miade/dosageextractor.py b/src/miade/dosageextractor.py index 9d8e41a..d3269c6 100644 --- a/src/miade/dosageextractor.py +++ b/src/miade/dosageextractor.py @@ -5,10 +5,9 @@ from typing import Optional from .dosage import Dosage -from .drugdoseade.preprocessor import Preprocessor -from .drugdoseade.pattern_matcher import PatternMatcher -from .drugdoseade.entities_refiner import EntitiesRefiner - +from .drugdoseade.preprocessor import Preprocessor # noqa: F401 +from .drugdoseade.pattern_matcher import PatternMatcher # noqa: F401 +from .drugdoseade.entities_refiner import EntitiesRefiner # noqa: F401 log = logging.getLogger(__name__) @@ -46,17 +45,12 @@ def extract(self, text: str, calculate: bool = True) -> Optional[Dosage]: """ doc = self.dosage_extractor(text) - log.debug( - f"NER results: {[(e.text, e.label_, e._.total_dose) for e in doc.ents]}" - ) + log.debug(f"NER results: {[(e.text, e.label_, e._.total_dose) for e in doc.ents]}") log.debug(f"Lookup results: {doc._.results}") dosage = Dosage.from_doc(doc=doc, calculate=calculate) - if all( - v is None - for v in [dosage.dose, dosage.frequency, dosage.route, dosage.duration] - ): + if all(v is None for v in [dosage.dose, dosage.frequency, dosage.route, dosage.duration]): return None return dosage diff --git a/src/miade/drugdoseade/entities_refiner.py b/src/miade/drugdoseade/entities_refiner.py index 569e737..48a67c3 100644 --- a/src/miade/drugdoseade/entities_refiner.py +++ b/src/miade/drugdoseade/entities_refiner.py @@ -14,11 +14,7 @@ def EntitiesRefiner(doc): new_ents = [] for ind, ent in enumerate(doc.ents): # combine consecutive labels with the same tag - if ( - ent.label_ == "DURATION" - or ent.label_ == "FREQUENCY" - or ent.label_ == "DOSAGE" - ) and ind != 0: + if (ent.label_ == "DURATION" or ent.label_ == "FREQUENCY" or ent.label_ == "DOSAGE") and ind != 0: prev_ent = doc.ents[ind - 1] if prev_ent.label_ == ent.label_: new_ent = Span(doc, prev_ent.start, ent.end, label=ent.label) diff --git a/src/miade/drugdoseade/pattern_matcher.py b/src/miade/drugdoseade/pattern_matcher.py index 
6243e47..316dee1 100644 --- a/src/miade/drugdoseade/pattern_matcher.py +++ b/src/miade/drugdoseade/pattern_matcher.py @@ -18,11 +18,7 @@ @spacy.registry.misc("patterns_lookup_table.v1") def create_patterns_dict(): patterns_data = pkgutil.get_data(__name__, "../data/patterns.csv") - patterns_dict = ( - pd.read_csv(io.BytesIO(patterns_data), index_col=0) - .squeeze("columns") - .T.to_dict() - ) + patterns_dict = pd.read_csv(io.BytesIO(patterns_data), index_col=0).squeeze("columns").T.to_dict() return patterns_dict @@ -67,9 +63,7 @@ def __call__(self, doc: Doc) -> Doc: # rule-based matching based on structure of dosage - HIE medication e.g. take 2 every day, 24 tablets expression = r"(?Pstart [\w\s,-]+ ), (?P\d+) (?P[a-z]+ )?$" for match in re.finditer(expression, dose_string): - dose_string = match.group( - "dose_string" - ) # remove total dose component for lookup + dose_string = match.group("dose_string") # remove total dose component for lookup start, end = match.span("total_dose") total_dose_span = doc.char_span(start, end, alignment_mode="contract") total_dose_span.label_ = "DOSAGE" @@ -81,9 +75,7 @@ def __call__(self, doc: Doc) -> Doc: unit_span = doc.char_span(start, end, alignment_mode="contract") unit_span.label_ = "FORM" unit_span._.total_dose = True - doc._.results[ - "units" - ] = unit_span.text # set unit in results dict as well + doc._.results["units"] = unit_span.text # set unit in results dict as well new_entities.append(unit_span) # lookup patterns from CALIBERdrugdose - returns dosage results in doc._.results attribute diff --git a/src/miade/drugdoseade/preprocessor.py b/src/miade/drugdoseade/preprocessor.py index dc92297..0168687 100644 --- a/src/miade/drugdoseade/preprocessor.py +++ b/src/miade/drugdoseade/preprocessor.py @@ -81,8 +81,7 @@ def __call__(self, doc: Doc) -> Doc: # remove numbers relating to strength of med e.g. aspirin 200mg tablets... processed_text = re.sub( - r" (\d+\.?\d*) (mg|ml|g|mcg|microgram|gram|%)" - r"(\s|/)(tab|cap|gel|cream|dose|pessaries)", + r" (\d+\.?\d*) (mg|ml|g|mcg|microgram|gram|%)" r"(\s|/)(tab|cap|gel|cream|dose|pessaries)", "", processed_text, ) @@ -102,9 +101,7 @@ def __call__(self, doc: Doc) -> Doc: if replacement == " ": log.debug(f"Removed multiword match '{words}'") else: - log.debug( - f"Replaced multiword match '{words}' with '{replacement}'" - ) + log.debug(f"Replaced multiword match '{words}' with '{replacement}'") processed_text = new_text # numbers replace 2 diff --git a/src/miade/drugdoseade/utils.py b/src/miade/drugdoseade/utils.py index 7f846b8..8f0ac7c 100644 --- a/src/miade/drugdoseade/utils.py +++ b/src/miade/drugdoseade/utils.py @@ -92,9 +92,7 @@ def numbers_replace(text): text, ) # 3 weeks... - text = re.sub( - r" ([\d.]+) (week) ", lambda m: " {:g} days ".format(int(m.group(1)) * 7), text - ) + text = re.sub(r" ([\d.]+) (week) ", lambda m: " {:g} days ".format(int(m.group(1)) * 7), text) # 3 months ... 
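numbers_replace normalises spelled-out schedules to common units before the CALIBER lookup; the hunk above only reflows one of those substitutions. A sketch of the week-to-days rule on an illustrative string.

import re


def weeks_to_days(text: str) -> str:
    # same pattern as above: expects the number and "week" to be space-delimited
    return re.sub(r" ([\d.]+) (week) ", lambda m: " {:g} days ".format(int(m.group(1)) * 7), text)


# weeks_to_days("for 3 week course") -> "for 21 days course"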
NB assume 30 days in a month text = re.sub( r" ([\d.]+) (month) ", diff --git a/src/miade/metaannotations.py b/src/miade/metaannotations.py index 5226f31..9bdb16a 100644 --- a/src/miade/metaannotations.py +++ b/src/miade/metaannotations.py @@ -1,8 +1,16 @@ from typing import Optional from pydantic import BaseModel, validator - -from .utils.metaannotationstypes import * - +from enum import Enum + +from .utils.metaannotationstypes import ( + Presence, + Relevance, + Laterality, + ReactionPos, + SubstanceCategory, + AllergyType, + Severity, +) META_ANNS_DICT = { "presence": Presence, @@ -11,7 +19,7 @@ "substance_category": SubstanceCategory, "reaction_pos": ReactionPos, "allergy_type": AllergyType, - "severity": Severity + "severity": Severity, } @@ -20,7 +28,7 @@ class MetaAnnotations(BaseModel): value: Enum confidence: Optional[float] - @validator('value', pre=True) + @validator("value", pre=True) def validate_value(cls, value, values): enum_dict = META_ANNS_DICT if isinstance(value, str): @@ -36,8 +44,4 @@ def validate_value(cls, value, values): return value def __eq__(self, other): - return ( - self.name == other.name - and self.value == other.value - ) - + return self.name == other.name and self.value == other.value diff --git a/src/miade/model_builders/cdbbuilder.py b/src/miade/model_builders/cdbbuilder.py index 840c52d..33f11c0 100644 --- a/src/miade/model_builders/cdbbuilder.py +++ b/src/miade/model_builders/cdbbuilder.py @@ -1,4 +1,3 @@ -import os from pathlib import Path from shutil import copy from typing import List, Optional @@ -60,24 +59,18 @@ def preprocess_snomed(self, output_dir: Path = Path.cwd()) -> Path: print("Exporting preprocessed SNOMED to csv...") if self.snomed_subset_path is not None: - snomed_subset = pd.read_csv( - str(self.snomed_subset_path), header=0, dtype={"cui": object} - ) + snomed_subset = pd.read_csv(str(self.snomed_subset_path), header=0, dtype={"cui": object}) else: snomed_subset = None if self.snomed_exclusions_path is not None: - snomed_exclusions = pd.read_csv( - str(self.snomed_exclusions_path), sep="\n", header=None - ) + snomed_exclusions = pd.read_csv(str(self.snomed_exclusions_path), sep="\n", header=None) snomed_exclusions.columns = ["cui"] else: snomed_exclusions = None output_file = output_dir / Path("preprocessed_snomed.csv") - df = self.snomed.to_concept_df( - subset_list=snomed_subset, exclusion_list=snomed_exclusions - ) + df = self.snomed.to_concept_df(subset_list=snomed_subset, exclusion_list=snomed_exclusions) df.to_csv(str(output_file), index=False) return output_file @@ -105,7 +98,7 @@ def preprocess(self): if self.elg_data_path: self.vocab_files.append(str(self.preprocess_elg(self.temp_dir))) if self.custom_data_paths: - string_paths = [str(path) for path in self.custom_data_paths] + string_paths = [str(path) for path in self.custom_data_paths] self.vocab_files.extend(string_paths) def create_cdb(self) -> CDB: diff --git a/src/miade/model_builders/preprocess_elg.py b/src/miade/model_builders/preprocess_elg.py index c3323e9..5127aeb 100644 --- a/src/miade/model_builders/preprocess_elg.py +++ b/src/miade/model_builders/preprocess_elg.py @@ -1,4 +1,4 @@ -from pandas import DataFrame, read_csv, isnull +from pandas import DataFrame, read_csv from pathlib import Path diff --git a/src/miade/model_builders/preprocess_snomeduk.py b/src/miade/model_builders/preprocess_snomeduk.py index 275fd82..11242e8 100644 --- a/src/miade/model_builders/preprocess_snomeduk.py +++ b/src/miade/model_builders/preprocess_snomeduk.py @@ -2,7 +2,6 @@ with a few 
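metaannotations.py now imports the enums explicitly and maps each meta-annotation name onto its Enum through META_ANNS_DICT before validation. The sketch below shows that mapping with only two enums spelled out; their member values are assumptions made for illustration, the real definitions live in utils/metaannotationstypes.py.

from enum import Enum


class Presence(Enum):
    # member values assumed for illustration
    CONFIRMED = "confirmed"
    NEGATED = "negated"
    SUSPECTED = "suspected"


class Relevance(Enum):
    PRESENT = "present"
    HISTORIC = "historic"
    IRRELEVANT = "irrelevant"


META_ANNS_DICT = {"presence": Presence, "relevance": Relevance}


def to_enum(name: str, raw_value: str) -> Enum:
    enum_type = META_ANNS_DICT.get(name)
    if enum_type is None:
        raise ValueError(f"Unknown meta-annotation name: {name}")
    return enum_type(raw_value)


# to_enum("relevance", "historic") -> Relevance.HISTORIC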
minor changes adapted to reading snomed UK folder paths""" import os -import json import re import hashlib import pandas as pd @@ -11,9 +10,7 @@ def parse_file(filename, first_row_header=True, columns=None): with open(filename, encoding="utf-8") as f: entities = [[n.strip() for n in line.split("\t")] for line in f] - return pd.DataFrame( - entities[1:], columns=entities[0] if first_row_header else columns - ) + return pd.DataFrame(entities[1:], columns=entities[0] if first_row_header else columns) class Snomed: @@ -64,9 +61,7 @@ def to_concept_df(self, subset_list=None, exclusion_list=None): if uk_code is None or snomed_v is None: raise FileNotFoundError("Could not find file matching pattern") - int_terms = parse_file( - f"{contents_path}/sct2_Concept_{uk_code}Snapshot_{snomed_v}_{snomed_release}.txt" - ) + int_terms = parse_file(f"{contents_path}/sct2_Concept_{uk_code}Snapshot_{snomed_v}_{snomed_release}.txt") active_terms = int_terms[int_terms.active == "1"] del int_terms @@ -86,16 +81,10 @@ def to_concept_df(self, subset_list=None, exclusion_list=None): del active_terms del active_descs - active_with_primary_desc = _[ - _["typeId"] == "900000000000003001" - ] # active description - active_with_synonym_desc = _[ - _["typeId"] == "900000000000013009" - ] # active synonym + active_with_primary_desc = _[_["typeId"] == "900000000000003001"] # active description + active_with_synonym_desc = _[_["typeId"] == "900000000000013009"] # active synonym del _ - active_with_all_desc = pd.concat( - [active_with_primary_desc, active_with_synonym_desc] - ) + active_with_all_desc = pd.concat([active_with_primary_desc, active_with_synonym_desc]) active_snomed_df = active_with_all_desc[["id_x", "term", "typeId"]] del active_with_all_desc @@ -110,12 +99,8 @@ def to_concept_df(self, subset_list=None, exclusion_list=None): ) active_snomed_df.reset_index(drop=True, inplace=True) - temp_df = active_snomed_df[active_snomed_df["name_status"] == "P"][ - ["cui", "name"] - ] - temp_df["description_type_ids"] = temp_df["name"].str.extract( - r"\((\w+\s?.?\s?\w+.?\w+.?\w+.?)\)$" - ) + temp_df = active_snomed_df[active_snomed_df["name_status"] == "P"][["cui", "name"]] + temp_df["description_type_ids"] = temp_df["name"].str.extract(r"\((\w+\s?.?\s?\w+.?\w+.?\w+.?)\)$") active_snomed_df = pd.merge( active_snomed_df, temp_df.loc[:, ["cui", "description_type_ids"]], @@ -129,10 +114,7 @@ def to_concept_df(self, subset_list=None, exclusion_list=None): active_snomed_df["type_ids"] = ( active_snomed_df["description_type_ids"] .dropna() - .apply( - lambda x: int(hashlib.sha256(x.encode("utf-8")).hexdigest(), 16) - % 10**8 - ) + .apply(lambda x: int(hashlib.sha256(x.encode("utf-8")).hexdigest(), 16) % 10**8) ) df2merge.append(active_snomed_df) @@ -184,8 +166,6 @@ def list_all_relationships(self): active_relat = int_relat[int_relat.active == "1"] del int_relat - all_rela.extend( - [relationship for relationship in active_relat["typeId"].unique()] - ) + all_rela.extend([relationship for relationship in active_relat["typeId"].unique()]) return all_rela diff --git a/src/miade/model_builders/vocabbuilder.py b/src/miade/model_builders/vocabbuilder.py index d15ff90..643f596 100644 --- a/src/miade/model_builders/vocabbuilder.py +++ b/src/miade/model_builders/vocabbuilder.py @@ -29,9 +29,7 @@ def create_new_vocab( make_vocab = MakeVocab(cdb=cdb, config=config) make_vocab.make(training_data_list, out_folder=str(output_dir)) - make_vocab.add_vectors( - in_path=str(output_dir / "data.txt"), unigram_table_size=unigram_table_size - ) + 
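In to_concept_df a compact numeric type id is derived by hashing the description-type string extracted from each fully specified name, so repeated CDB builds assign stable ids. A sketch of that derivation.

import hashlib


def type_id(description_type: str) -> int:
    # deterministic id in the range 0..99999999
    return int(hashlib.sha256(description_type.encode("utf-8")).hexdigest(), 16) % 10**8


# type_id("disorder") always returns the same integer across runs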
make_vocab.add_vectors(in_path=str(output_dir / "data.txt"), unigram_table_size=unigram_table_size) self.vocab = make_vocab.vocab return self.vocab @@ -43,11 +41,6 @@ def update_vocab(self) -> Vocab: self.vocab.make_unigram_table() return self.vocab - def make_model_pack(self, - cdb: CDB, - save_name: str, - output_dir: Path = Path.cwd() - ) -> None: + def make_model_pack(self, cdb: CDB, save_name: str, output_dir: Path = Path.cwd()) -> None: cat = CAT(cdb=cdb, config=cdb.config, vocab=self.vocab) cat.create_model_pack(str(output_dir), save_name) - diff --git a/src/miade/note.py b/src/miade/note.py index 7877820..3c98815 100644 --- a/src/miade/note.py +++ b/src/miade/note.py @@ -14,7 +14,14 @@ def load_regex_config_mappings(filename: str) -> Dict: regex_config = pkgutil.get_data(__name__, filename) - data = pd.read_csv(io.BytesIO(regex_config),index_col=0,).squeeze("columns").T.to_dict() + data = ( + pd.read_csv( + io.BytesIO(regex_config), + index_col=0, + ) + .squeeze("columns") + .T.to_dict() + ) regex_lookup = {} for paragraph, regex in data.items(): @@ -41,16 +48,16 @@ def __init__(self, text: str, regex_config_path: str = "./data/regex_para_chunk. def clean_text(self) -> None: # Replace all types of spaces with a single normal space, preserving "\n" - self.text = re.sub(r'(?:(?!\n)\s)+', ' ', self.text) + self.text = re.sub(r"(?:(?!\n)\s)+", " ", self.text) # Remove en dashes that are not between two numbers - self.text = re.sub(r'(? None: paragraphs = re.split(r"\n\n+", self.text) @@ -61,7 +68,7 @@ def get_paragraphs(self) -> None: paragraph_type = ParagraphType.prose # Use re.search to find everything before first \n - match = re.search(r'^(.*?)(?:\n|$)([\s\S]*)', text) + match = re.search(r"^(.*?)(?:\n|$)([\s\S]*)", text) # Check if a match is found if match: diff --git a/src/miade/paragraph.py b/src/miade/paragraph.py index b134d04..62bedeb 100644 --- a/src/miade/paragraph.py +++ b/src/miade/paragraph.py @@ -26,8 +26,4 @@ def __str__(self): return str(self.__dict__) def __eq__(self, other): - return ( - self.type == other.type - and self.start == other.start - and self.end == other.end - ) + return self.type == other.type and self.start == other.start and self.end == other.end diff --git a/src/miade/utils/annotatorconfig.py b/src/miade/utils/annotatorconfig.py index 1102ff5..50cd89b 100644 --- a/src/miade/utils/annotatorconfig.py +++ b/src/miade/utils/annotatorconfig.py @@ -3,5 +3,6 @@ class AnnotatorConfig(BaseModel): + lookup_data_path: Optional[str] = "./lookup_data/" negation_detection: Optional[str] = "negex" - disable: List[str] = [] \ No newline at end of file + disable: List[str] = [] diff --git a/src/miade/utils/logger.py b/src/miade/utils/logger.py index ad3f63a..a1cad54 100644 --- a/src/miade/utils/logger.py +++ b/src/miade/utils/logger.py @@ -5,9 +5,7 @@ def add_handlers(log): if len(log.handlers) == 0: - formatter = logging.Formatter( - fmt="[%(asctime)s] [%(levelname)s] %(name)s.%(funcName)s(): %(message)s" - ) + formatter = logging.Formatter(fmt="[%(asctime)s] [%(levelname)s] %(name)s.%(funcName)s(): %(message)s") fh = logging.FileHandler("miade.log") ch = logging.StreamHandler() diff --git a/src/miade/utils/miade_cat.py b/src/miade/utils/miade_cat.py index 7c1f698..aef9190 100644 --- a/src/miade/utils/miade_cat.py +++ b/src/miade/utils/miade_cat.py @@ -3,9 +3,9 @@ import pandas as pd from copy import deepcopy -from typing import Union, List, Tuple, Optional, Dict, Iterable, Set +from typing import List, Tuple, Optional, Dict, Set -from tqdm.autonotebook import 
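Note.clean_text normalises whitespace while preserving newlines, so get_paragraphs can still split the note on blank lines. A sketch of that first substitution on an illustrative string.

import re


def collapse_spaces(text: str) -> str:
    # collapse any run of whitespace that contains no newline into a single space
    return re.sub(r"(?:(?!\n)\s)+", " ", text)


# collapse_spaces("Problems:\t  anaemia\n\nPlan:  review")
# -> "Problems: anaemia\n\nPlan: review"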
tqdm, trange +from tqdm.autonotebook import trange from spacy.tokens import Span, Doc from medcat.cat import CAT @@ -17,6 +17,7 @@ logger = logging.getLogger("cat") + class MiADE_CAT(CAT): """Experimental - overriding medcat write out function - more control over spacy pipeline: add negex results""" @@ -64,10 +65,7 @@ def _doc_to_out( out_ent["pretty_name"] = self.cdb.get_name(cui) out_ent["cui"] = cui out_ent["type_ids"] = list(self.cdb.cui2type_ids.get(cui, "")) - out_ent["types"] = [ - self.cdb.addl_info["type_id2name"].get(tui, "") - for tui in out_ent["type_ids"] - ] + out_ent["types"] = [self.cdb.addl_info["type_id2name"].get(tui, "") for tui in out_ent["type_ids"]] out_ent["source_value"] = ent.text out_ent["detected_name"] = str(ent._.detected_name) out_ent["acc"] = float(ent._.context_similarity) @@ -76,9 +74,7 @@ def _doc_to_out( out_ent["end"] = ent.end_char for addl in addl_info: tmp = self.cdb.addl_info.get(addl, {}).get(cui, []) - out_ent[addl.split("2")[-1]] = ( - list(tmp) if type(tmp) == set else tmp - ) + out_ent[addl.split("2")[-1]] = list(tmp) if type(tmp) == set else tmp out_ent["id"] = ent._.id out_ent["meta_anns"] = {} @@ -87,12 +83,8 @@ def _doc_to_out( out_ent["end_tkn"] = ent.end if context_left > 0 and context_right > 0: - out_ent["context_left"] = doc_tokens[ - max(ent.start - context_left, 0) : ent.start - ] - out_ent["context_right"] = doc_tokens[ - ent.end : min(ent.end + context_right, len(doc_tokens)) - ] + out_ent["context_left"] = doc_tokens[max(ent.start - context_left, 0) : ent.start] + out_ent["context_right"] = doc_tokens[ent.end : min(ent.end + context_right, len(doc_tokens))] out_ent["context_center"] = doc_tokens[ent.start : ent.end] if hasattr(ent._, "meta_anns") and ent._.meta_anns: @@ -105,10 +97,7 @@ def _doc_to_out( else: out["entities"][ent._.id] = cui - if ( - cnf_annotation_output.get("include_text_in_output", False) - or out_with_text - ): + if cnf_annotation_output.get("include_text_in_output", False) or out_with_text: out["text"] = doc.text return out @@ -132,7 +121,6 @@ def train_supervised( checkpoint: Optional[Checkpoint] = None, is_resumed: bool = False, ) -> Tuple: - checkpoint = self._init_ckpts(is_resumed, checkpoint) # Backup filters @@ -149,9 +137,7 @@ def train_supervised( test_set = data train_set = data else: - train_set, test_set, _, _ = make_mc_train_test( - data, self.cdb, test_size=test_size - ) + train_set, test_set, _, _ = make_mc_train_test(data, self.cdb, test_size=test_size) if print_stats > 0: fp, fn, tp, p, r, f1, cui_counts, examples = self._print_stats( @@ -184,9 +170,7 @@ def train_supervised( self.unlink_concept_name(ann["cui"], ann["value"]) latest_trained_step = checkpoint.count if checkpoint is not None else 0 - current_epoch, current_project, current_document = self._get_training_start( - train_set, latest_trained_step - ) + current_epoch, current_project, current_document = self._get_training_start(train_set, latest_trained_step) for epoch in trange( current_epoch, @@ -221,9 +205,7 @@ def train_supervised( ) if project_filter: - filters["cuis"] = intersect_nonempty_set( - project_filter, filters["cuis"] - ) + filters["cuis"] = intersect_nonempty_set(project_filter, filters["cuis"]) for idx_doc in trange( current_document, @@ -243,9 +225,7 @@ def train_supervised( cui = ann["cui"] start = ann["start"] end = ann["end"] - spacy_entity = tkns_from_doc( - spacy_doc=spacy_doc, start=start, end=end - ) + spacy_entity = tkns_from_doc(spacy_doc=spacy_doc, start=start, end=end) deleted = ann.get("deleted", False) 
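_doc_to_out attaches left/centre/right token windows to each entity when context sizes are configured. A sketch of the slicing with a plain token list standing in for the spaCy Doc; tokens and indices are illustrative.

doc_tokens = ["pt", "denies", "chest", "pain", "or", "sob"]
ent_start, ent_end = 2, 4  # tokens "chest pain"
context_left = context_right = 2

left = doc_tokens[max(ent_start - context_left, 0):ent_start]              # ["pt", "denies"]
center = doc_tokens[ent_start:ent_end]                                     # ["chest", "pain"]
right = doc_tokens[ent_end:min(ent_end + context_right, len(doc_tokens))]  # ["or", "sob"]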
self.add_and_train_concept( cui=cui, @@ -288,9 +268,7 @@ def train_supervised( name = synth_data.name.values[i] start = synth_data.start.values[i] end = synth_data.end.values[i] - spacy_entity = tkns_from_doc( - spacy_doc=spacy_doc, start=start, end=end - ) + spacy_entity = tkns_from_doc(spacy_doc=spacy_doc, start=start, end=end) self.add_and_train_concept( cui=cui, name=name, diff --git a/src/miade/utils/miade_meta_cat.py b/src/miade/utils/miade_meta_cat.py index 996ca99..487c153 100644 --- a/src/miade/utils/miade_meta_cat.py +++ b/src/miade/utils/miade_meta_cat.py @@ -18,6 +18,7 @@ logger = logging.getLogger("meta_cat") + # Hacky as hell, just for the dashboard, NOT permanent solution - will not merge with main branch def create_batch_piped_data(data: List, start_ind: int, end_ind: int, device: torch.device, pad_id: int) -> Tuple: """Creates a batch given data and start/end that denote batch size, will also add @@ -41,7 +42,7 @@ def create_batch_piped_data(data: List, start_ind: int, end_ind: int, device: to Center positions for the data """ max_seq_len = max([len(x[0]) for x in data]) - x = [x[0][0:max_seq_len] + [pad_id]*max(0, max_seq_len - len(x[0])) for x in data[start_ind:end_ind]] + x = [x[0][0:max_seq_len] + [pad_id] * max(0, max_seq_len - len(x[0])) for x in data[start_ind:end_ind]] cpos = [x[1] for x in data[start_ind:end_ind]] y = None if len(data[0]) == 3: @@ -54,17 +55,17 @@ def create_batch_piped_data(data: List, start_ind: int, end_ind: int, device: to return x, cpos, y -def print_report(epoch: int, running_loss: List, all_logits: List, y: Any, name: str = 'Train') -> None: - r''' Prints some basic stats during training +def print_report(epoch: int, running_loss: List, all_logits: List, y: Any, name: str = "Train") -> None: + r"""Prints some basic stats during training Args: epoch running_loss all_logits y name - ''' + """ if all_logits: - print(f'Epoch: {epoch} ' + "*"*50 + f" {name}") + print(f"Epoch: {epoch} " + "*" * 50 + f" {name}") print(classification_report(y, np.argmax(np.concatenate(all_logits, axis=0), axis=1))) @@ -75,11 +76,11 @@ def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: T data config """ - device = torch.device(config.general['device']) # Create a torch device - batch_size_eval = config.general['batch_size_eval'] - pad_id = config.model['padding_idx'] - ignore_cpos = config.model['ignore_cpos'] - class_weights = config.train['class_weights'] + device = torch.device(config.general["device"]) # Create a torch device + batch_size_eval = config.general["batch_size_eval"] + pad_id = config.model["padding_idx"] + ignore_cpos = config.model["ignore_cpos"] + class_weights = config.train["class_weights"] if class_weights is not None: class_weights = torch.FloatTensor(class_weights).to(device) @@ -96,8 +97,9 @@ def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: T with torch.no_grad(): for i in range(num_batches): - x, cpos, y = create_batch_piped_data(data, i * batch_size_eval, (i + 1) * batch_size_eval, - device=device, pad_id=pad_id) + x, cpos, y = create_batch_piped_data( + data, i * batch_size_eval, (i + 1) * batch_size_eval, device=device, pad_id=pad_id + ) logits = model(x, cpos, ignore_cpos=ignore_cpos) loss = criterion(logits, y) @@ -105,39 +107,46 @@ def eval_model(model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: T running_loss.append(loss.item()) all_logits.append(logits.detach().cpu().numpy()) - print_report(0, running_loss, all_logits, y=y_eval, name='Eval') + print_report(0, 
running_loss, all_logits, y=y_eval, name="Eval") - score_average = config.train['score_average'] + score_average = config.train["score_average"] predictions = np.argmax(np.concatenate(all_logits, axis=0), axis=1) precision, recall, f1, support = precision_recall_fscore_support(y_eval, predictions, average=score_average) - labels = [name for (name, _) in sorted(config.general['category_value2id'].items(), key=lambda x: x[1])] + labels = [name for (name, _) in sorted(config.general["category_value2id"].items(), key=lambda x: x[1])] confusion = pd.DataFrame( - data=confusion_matrix(y_eval, predictions, ), + data=confusion_matrix( + y_eval, + predictions, + ), columns=["true " + label for label in labels], index=["predicted " + label for label in labels], ) - examples: Dict = {'FP': {}, 'FN': {}, 'TP': {}} - id2category_value = {v: k for k, v in config.general['category_value2id'].items()} + examples: Dict = {"FP": {}, "FN": {}, "TP": {}} + id2category_value = {v: k for k, v in config.general["category_value2id"].items()} for i, p in enumerate(predictions): y = id2category_value[y_eval[i]] p = id2category_value[p] c = data[i][1] tkns = data[i][0] assert tokenizer.hf_tokenizers is not None - text = tokenizer.hf_tokenizers.decode(tkns[0:c]) + " <<" + tokenizer.hf_tokenizers.decode( - tkns[c:c + 1]).strip() + ">> " + \ - tokenizer.hf_tokenizers.decode(tkns[c + 1:]) + text = ( + tokenizer.hf_tokenizers.decode(tkns[0:c]) + + " <<" + + tokenizer.hf_tokenizers.decode(tkns[c : c + 1]).strip() + + ">> " + + tokenizer.hf_tokenizers.decode(tkns[c + 1 :]) + ) info = "Predicted: {}, True: {}".format(p, y) if p != y: # We made a mistake - examples['FN'][y] = examples['FN'].get(y, []) + [(info, text)] - examples['FP'][p] = examples['FP'].get(p, []) + [(info, text)] + examples["FN"][y] = examples["FN"].get(y, []) + [(info, text)] + examples["FP"][p] = examples["FP"].get(p, []) + [(info, text)] else: - examples['TP'][y] = examples['TP'].get(y, []) + [(info, text)] + examples["TP"][y] = examples["TP"].get(y, []) + [(info, text)] - return {'precision': precision, 'recall': recall, 'f1': f1, 'examples': examples, 'confusion matrix': confusion} + return {"precision": precision, "recall": recall, "f1": f1, "examples": examples, "confusion matrix": confusion} def prepare_from_miade_csv( @@ -181,11 +190,7 @@ def prepare_from_miade_csv( e_ind = p_ind ln = e_ind - s_ind - tkns = ( - tkns[:cpos] - + tokenizer(replace_center)["input_ids"] - + tkns[cpos + ln + 1 :] - ) + tkns = tkns[:cpos] + tokenizer(replace_center)["input_ids"] + tkns[cpos + ln + 1 :] value = data[category_name].values[i] sample = [tkns, cpos, value] @@ -222,9 +227,7 @@ def train( # Create directories if they don't exist if t_config["auto_save_model"]: if save_dir_path is None: - raise Exception( - "The `save_dir_path` argument is required if `aut_save_model` is `True` in the config" - ) + raise Exception("The `save_dir_path` argument is required if `aut_save_model` is `True` in the config") else: os.makedirs(save_dir_path, exist_ok=True) @@ -276,9 +279,7 @@ def train( g_config["category_value2id"] = category_value2id else: # We already have everything, just get the data - data, _ = encode_category_values( - data, existing_category_value2id=category_value2id - ) + data, _ = encode_category_values(data, existing_category_value2id=category_value2id) # Make sure the config number of classes is the same as the one found in the data if len(category_value2id) != self.config.model["nclasses"]: @@ -287,22 +288,16 @@ def train( self.config.model["nclasses"], 
len(category_value2id) ) ) - logger.warning( - "Auto-setting the nclasses value in config and rebuilding the model." - ) + logger.warning("Auto-setting the nclasses value in config and rebuilding the model.") self.config.model["nclasses"] = len(category_value2id) self.model = self.get_model(embeddings=self.embeddings) - report = train_model( - self.model, data=data, config=self.config, save_dir_path=save_dir_path - ) + report = train_model(self.model, data=data, config=self.config, save_dir_path=save_dir_path) # If autosave, then load the best model here if t_config["auto_save_model"]: if save_dir_path is None: - raise Exception( - "The `save_dir_path` argument is required if `aut_save_model` is `True` in the config" - ) + raise Exception("The `save_dir_path` argument is required if `aut_save_model` is `True` in the config") else: path = os.path.join(save_dir_path, "model.dat") device = torch.device(g_config["device"]) @@ -315,31 +310,35 @@ def train( return report def eval(self, json_path: str) -> Dict: - """Evaluate from json. - - """ + """Evaluate from json.""" g_config = self.config.general t_config = self.config.train - with open(json_path, 'r') as f: + with open(json_path, "r") as f: data_loaded: Dict = json.load(f) # Prepare the data assert self.tokenizer is not None - data = prepare_from_json(data_loaded, g_config['cntx_left'], g_config['cntx_right'], self.tokenizer, - cui_filter=t_config['cui_filter'], - replace_center=g_config['replace_center'], prerequisites=t_config['prerequisites'], - lowercase=g_config['lowercase']) + data = prepare_from_json( + data_loaded, + g_config["cntx_left"], + g_config["cntx_right"], + self.tokenizer, + cui_filter=t_config["cui_filter"], + replace_center=g_config["replace_center"], + prerequisites=t_config["prerequisites"], + lowercase=g_config["lowercase"], + ) # Check is the name there - category_name = g_config['category_name'] + category_name = g_config["category_name"] if category_name not in data: raise Exception("The category name does not exist in this json file.") data = data[category_name] # We already have everything, just get the data - category_value2id = g_config['category_value2id'] + category_value2id = g_config["category_value2id"] data, _ = encode_category_values(data, existing_category_value2id=category_value2id) # Run evaluation diff --git a/src/miade/utils/modelfactory.py b/src/miade/utils/modelfactory.py index abe0100..bbeb5a1 100644 --- a/src/miade/utils/modelfactory.py +++ b/src/miade/utils/modelfactory.py @@ -11,7 +11,7 @@ class ModelFactory(BaseModel): annotators: Dict[str, Type[Annotator]] configs: Dict[str, AnnotatorConfig] - @validator('annotators') + @validator("annotators") def validate_annotators(cls, annotators): for annotator_name, annotator_class in annotators.items(): if not issubclass(annotator_class, Annotator): @@ -19,4 +19,4 @@ def validate_annotators(cls, annotators): return annotators class Config: - arbitrary_types_allowed = True \ No newline at end of file + arbitrary_types_allowed = True diff --git a/src/scripts/build_model_pack.py b/src/scripts/build_model_pack.py index b92be9f..b234e75 100644 --- a/src/scripts/build_model_pack.py +++ b/src/scripts/build_model_pack.py @@ -19,11 +19,8 @@ def build_model_pack( unigram_table_size: int, output_dir: Path, ): - # TODO: option to input list of concept csv files - cdb_builder = CDBBuilder( - snomed_data_path=snomed_data_path, fdb_data_path=fdb_data_path, config=config - ) + cdb_builder = CDBBuilder(snomed_data_path=snomed_data_path, fdb_data_path=fdb_data_path, 
config=config) cdb_builder.preprocess_snomed(output_dir=output_dir) cdb = cdb_builder.create_cdb(["preprocessed_snomed.csv"]) @@ -43,7 +40,6 @@ def build_model_pack( if __name__ == "__main__": - parser = argparse.ArgumentParser() parser.add_argument("config_file") args = parser.parse_args() @@ -52,9 +48,7 @@ def build_model_pack( with open(config_file, "r") as stream: config = yaml.safe_load(stream) - with open( - Path(config["unsupervised_training_data_file"]), "r", encoding="utf-8" - ) as training_data: + with open(Path(config["unsupervised_training_data_file"]), "r", encoding="utf-8") as training_data: training_data_list = [line.strip() for line in training_data] # Load MedCAT configuration diff --git a/src/scripts/miade.py b/src/scripts/miade.py index 80f5e92..edcd851 100644 --- a/src/scripts/miade.py +++ b/src/scripts/miade.py @@ -83,11 +83,7 @@ def build_model_pack( cat.config.version["ontology"] = ontology current_date = datetime.datetime.now().strftime("%b_%Y") - name = ( - f"miade_{tag}_blank_modelpack_{current_date}" - if tag is not None - else f"miade_blank_modelpack_{current_date}" - ) + name = f"miade_{tag}_blank_modelpack_{current_date}" if tag is not None else f"miade_blank_modelpack_{current_date}" cat.create_model_pack(str(output), name) log.info(f"Saved model pack at {output}/{name}_{cat.config.version['id']}") @@ -123,9 +119,7 @@ def train( if checkpoint: log.info(f"Checkpoint steps configured to {checkpoint}") cat.config.general["checkpoint"]["steps"] = checkpoint - cat.config.general["checkpoint"]["output_dir"] = os.path.join( - Path.cwd(), "checkpoints" - ) + cat.config.general["checkpoint"]["output_dir"] = os.path.join(Path.cwd(), "checkpoints") cat.train(training_data) @@ -246,7 +240,7 @@ def create_bbpe_tokenizer( data.append(tokenizer.encode(line).tokens) step += 1 - log.info(f"Started training word2vec model with tokenized text...") + log.info("Started training word2vec model with tokenized text...") w2v = Word2Vec(data, vector_size=300, min_count=1) log.info(f"Creating embeddings matrix, vocab size {tokenizer.get_vocab_size()}") @@ -277,19 +271,17 @@ def create_metacats( ): log.info(f"Loading tokenizer from {tokenizer_path}/...") tokenizer = TokenizerWrapperBPE.load(str(tokenizer_path)) - log.info(f"Loading embeddings from embeddings.npy...") + log.info("Loading embeddings from embeddings.npy...") embeddings = np.load(str(os.path.join(tokenizer_path, "embeddings.npy"))) assert len(embeddings) == tokenizer.get_size(), ( - f"Tokenizer and embeddings not the same size {len(embeddings)}, " - f"{tokenizer.get_size()}" + f"Tokenizer and embeddings not the same size {len(embeddings)}, " f"{tokenizer.get_size()}" ) metacat = MetaCAT(tokenizer=tokenizer, embeddings=embeddings) for category in category_names: - metacat.config.general[ - 'description'] = f"MiADE blank {category} MetaCAT model" - metacat.config.general['category_name'] = category + metacat.config.general["description"] = f"MiADE blank {category} MetaCAT model" + metacat.config.general["category_name"] = category metacat.save(str(os.path.join(output, f"meta_{category}"))) log.info(f"Saved meta_{category} at {output}") @@ -310,9 +302,7 @@ def train_metacat( description = f"MiADE meta-annotations model {model_path.stem} trained on {annotation_path.stem}" mc.config.general["description"] = description - mc.config.general["category_name"] = model_path.stem.split("_")[ - -1 - ] # meta folder name should be e.g. 
meta_presence + mc.config.general["category_name"] = model_path.stem.split("_")[-1] # meta folder name should be e.g. meta_presence mc.config.general["cntx_left"] = cntx_left mc.config.general["cntx_right"] = cntx_right mc.config.train["nepochs"] = nepochs @@ -361,17 +351,11 @@ def add_metacat_models( stats[categories[-1]] = report log.info(f"Creating CAT with MetaCAT models {categories}...") - cat_w_meta = CAT( - cdb=cat.cdb, vocab=cat.vocab, config=cat.config, meta_cats=meta_cats - ) + cat_w_meta = CAT(cdb=cat.cdb, vocab=cat.vocab, config=cat.config, meta_cats=meta_cats) if description is None: log.info("Automatically populating description field of model card...") - description = ( - cat.config.version["description"] - + " | Packaged with MetaCAT model(s) " - + ", ".join(categories) - ) + description = cat.config.version["description"] + " | Packaged with MetaCAT model(s) " + ", ".join(categories) cat.config.version["description"] = description for category in categories: diff --git a/streamlit_app/app.py b/streamlit_app/app.py index e43c29e..de29ed0 100644 --- a/streamlit_app/app.py +++ b/streamlit_app/app.py @@ -1,3 +1,5 @@ +# ruff: noqa: F811 + import os import json from time import sleep @@ -21,7 +23,13 @@ from medcat.cat import CAT from miade.utils.miade_meta_cat import MiADE_MetaCAT -from utils import * +from utils import ( + load_documents, + load_annotations, + get_valid_annotations, + get_probs_meta_classes_data, + get_meds_meta_classes_data, +) load_dotenv(find_dotenv()) @@ -39,10 +47,12 @@ def new_write(string): stdout.write = new_write yield + @st.cache_data def load_csv_data(csv_path): return pd.read_csv(csv_path) + @st.cache_data def get_label_counts(name, train, synth): real_counts = {} @@ -55,18 +65,21 @@ def get_label_counts(name, train, synth): synthetic_counts = synthetic_labels.value_counts().to_dict() return real_counts, synthetic_counts + @st.cache_data def get_chart_data(labels, label_count, synth_add_count): return pd.DataFrame( - {"real": [label_count.get(labels[i], 0) for i in range(len(labels))], - "synthetic": synth_add_count.values()}, - index=category_labels) + {"real": [label_count.get(labels[i], 0) for i in range(len(labels))], "synthetic": synth_add_count.values()}, + index=category_labels, + ) + @st.cache_data def make_train_data(synth_df, name, labels, synth_add_count): - return pd.concat([synth_df[synth_df[name] == label][:synth_add_count[label]] - for label in labels], - ignore_index=True) + return pd.concat( + [synth_df[synth_df[name] == label][: synth_add_count[label]] for label in labels], ignore_index=True + ) + @st.cache_resource def load_metacat_model(path): @@ -81,6 +94,7 @@ def load_metacat_model(path): name = None return model, name + @st.cache_resource def load_medcat_model(path): try: @@ -90,6 +104,7 @@ def load_medcat_model(path): model = None return model + MIN_HEIGHT = 27 MAX_HEIGHT = 800 ROW_HEIGHT = 35 @@ -99,21 +114,20 @@ def load_medcat_model(path): TEST_JSON_OPTIONS = [f for f in os.listdir(os.getenv("TEST_JSON_DIR")) if ".json" in f] MEDCAT_OPTIONS = [f for f in os.listdir(os.getenv("SAVE_DIR")) if ".zip" in f] -MODEL_OPTIONS = ["/".join(f[0].split("/")[-2:]) for f in os.walk(os.getenv("MODELS_DIR")) - if 'meta_' in f[0].split("/")[-1] and ".ipynb_checkpoints" not in f[0]] +MODEL_OPTIONS = [ + "/".join(f[0].split("/")[-2:]) + for f in os.walk(os.getenv("MODELS_DIR")) + if "meta_" in f[0].split("/")[-1] and ".ipynb_checkpoints" not in f[0] +] -st.set_page_config( - layout="wide", page_icon="🖱️", page_title="MiADE train app" -) 
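The dashboard loaders reformatted above lean on Streamlit's two caching decorators: model packs are held behind `st.cache_resource`, while cheap-to-reload dataframes sit behind `st.cache_data`. A rough sketch of that loader pattern, assuming only that `MetaCAT.load` accepts a saved-model directory; the paths are placeholders, not the dashboard's environment variables.

    # Sketch of the cached-loader pattern used by the training dashboard.
    import pandas as pd
    import streamlit as st
    from medcat.meta_cat import MetaCAT

    @st.cache_resource  # heavyweight, non-copyable objects (models)
    def load_metacat_model(path: str):
        try:
            return MetaCAT.load(path)
        except Exception:
            return None

    @st.cache_data  # serialisable data (CSV contents)
    def load_csv_data(csv_path: str) -> pd.DataFrame:
        return pd.read_csv(csv_path)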
+st.set_page_config(layout="wide", page_icon="🖱️", page_title="MiADE train app") st.title("🖱️ MiADE Training Dashboard") -st.write( - """Hello! Train, test, and experiment with MedCAT models used in MiADE""" -) +st.write("""Hello! Train, test, and experiment with MedCAT models used in MiADE""") def present_confusion_matrix(model, data): data_name = Path(data).stem - model_name = model.config.general['category_name'] + model_name = model.config.general["category_name"] title = f"{model_name} evaluated against\n{data_name}" evaluation = model.eval(data) @@ -121,7 +135,8 @@ def present_confusion_matrix(model, data): cm = evaluation["confusion matrix"].values label_names = [name.split()[-1] for name in list(evaluation["confusion matrix"].columns)] stats_text = "\n\nPrecision={:0.3f}\nRecall={:0.3f}\nF1 Score={:0.3f}".format( - evaluation['precision'], evaluation['recall'], evaluation['f1']) + evaluation["precision"], evaluation["recall"], evaluation["f1"] + ) group_counts = ["{0:0.0f}".format(value) for value in cm.flatten()] group_percentages = ["{0:.2%}".format(value) for value in cm.flatten() / np.sum(cm)] @@ -135,16 +150,16 @@ def present_confusion_matrix(model, data): square=True, xticklabels=label_names, yticklabels=label_names, - fmt='', + fmt="", ) - conf.set(xlabel='True' + stats_text, ylabel='Predicted', title=title) + conf.set(xlabel="True" + stats_text, ylabel="Predicted", title=title) st.write(evaluation) return plt def add_metacat_models( - model: str, - meta_cats_path: List, + model: str, + meta_cats_path: List, ): out_dir = os.getenv("SAVE_DIR", "./") cat = CAT.load_model_pack(str(model)) @@ -160,7 +175,7 @@ def add_metacat_models( description = cat.config.version["description"] + " | Packaged with MetaCAT model(s) " + ", ".join(categories) cat.config.version["description"] = description - save_name = Path(model).stem.rsplit("_", 3)[0] + "_w_meta_" + datetime.now().strftime('%b_%Y').lower() + save_name = Path(model).stem.rsplit("_", 3)[0] + "_w_meta_" + datetime.now().strftime("%b_%Y").lower() try: cat_w_meta.create_model_pack(save_dir_path=out_dir, model_pack_name=save_name) st.success("Saved MedCAT modelpack at " + out_dir + save_name + "_" + cat_w_meta.config.version["id"]) @@ -197,6 +212,7 @@ def aggrid_interactive_table(df: pd.DataFrame): return selection + # side bar st.sidebar.subheader("Select model to train") model_path = st.sidebar.selectbox("Select MetaCAT model path", MODEL_OPTIONS) @@ -221,7 +237,7 @@ def aggrid_interactive_table(df: pd.DataFrame): metacat_paths = [os.path.join(os.getenv("MODELS_DIR"), "/".join(path.split("/")[-2:])) for path in selected_models] selected_medcat = st.selectbox("Select MedCAT modelpack to package with:", MEDCAT_OPTIONS) medcat_path = os.getenv("SAVE_DIR") + selected_medcat - submit = st.form_submit_button(label='Save') + submit = st.form_submit_button(label="Save") if submit: add_metacat_models(medcat_path, metacat_paths) # update options probably a better way to do this @@ -237,8 +253,10 @@ def aggrid_interactive_table(df: pd.DataFrame): col1, col2, col3 = st.columns(3) with col1: - st.markdown("**Adjust** the sliders to vary the amount of synthetic data " - " you want to include in the training data in addition to your annotations:") + st.markdown( + "**Adjust** the sliders to vary the amount of synthetic data " + " you want to include in the training data in addition to your annotations:" + ) train_json_path = st.selectbox("Select annotated training data", TRAIN_JSON_OPTIONS) train_csv = train_json_path.replace(".json", ".csv") @@ 
-266,9 +284,7 @@ def aggrid_interactive_table(df: pd.DataFrame): all_synth_df = load_csv_data(synth_csv_path) if mc is not None: category_labels = list(mc.config.general["category_value2id"].keys()) - real_label_counts, synthetic_label_counts = get_label_counts( - model_name, train_data_df, all_synth_df - ) + real_label_counts, synthetic_label_counts = get_label_counts(model_name, train_data_df, all_synth_df) if real_label_counts: max_class = max(real_label_counts.values()) else: @@ -279,10 +295,12 @@ def aggrid_interactive_table(df: pd.DataFrame): synth_add_dict = {} for i in range(len(category_labels)): - synth_add_dict[category_labels[i]] = st.slider(category_labels[i] + " (synthetic)", - min_value=0, - max_value=synthetic_label_counts.get(category_labels[i], 0), - value=max_class - real_label_counts.get(category_labels[i], 0)) + synth_add_dict[category_labels[i]] = st.slider( + category_labels[i] + " (synthetic)", + min_value=0, + max_value=synthetic_label_counts.get(category_labels[i], 0), + value=max_class - real_label_counts.get(category_labels[i], 0), + ) with col2: st.markdown("**Visualise** the ratio of real and synthetic in your overall training set:") if mc is not None: @@ -296,7 +314,7 @@ def aggrid_interactive_table(df: pd.DataFrame): if mc is not None: st.dataframe(synth_train_df[["text", model_name]], height=500) - if st.button('Train'): + if st.button("Train"): if mc is not None: with st.spinner("Training MetaCAT..."): date_id = datetime.now().strftime("%y%m%d%H%M%S") @@ -323,7 +341,7 @@ def aggrid_interactive_table(df: pd.DataFrame): if synth_data_column is not None: synth_count = synth_data_column.value_counts().to_dict() if not train_count: - train_length = 1 #min. num data in json + train_length = 1 # min. num data in json else: train_length = len(train_data_df) total_count = train_length + len(synth_train_df) @@ -341,15 +359,18 @@ def aggrid_interactive_table(df: pd.DataFrame): with st.expander("Expand to see training logs"): output = st.empty() with st_capture(output.code): - report = mc.train(json_path=train_json_path, - synthetic_csv_path=data_save_name, - save_dir_path=model_save_name) + report = mc.train( + json_path=train_json_path, synthetic_csv_path=data_save_name, save_dir_path=model_save_name + ) st.success(f"Done! 
Model saved at {model_save_name}") st.write("Training report:") st.write(report) - MODEL_OPTIONS = ["/".join(f[0].split("/")[-2:]) for f in os.walk(os.getenv("MODELS_DIR")) - if 'meta_' in f[0].split("/")[-1] and ".ipynb_checkpoints" not in f[0]] + MODEL_OPTIONS = [ + "/".join(f[0].split("/")[-2:]) + for f in os.walk(os.getenv("MODELS_DIR")) + if "meta_" in f[0].split("/")[-1] and ".ipynb_checkpoints" not in f[0] + ] else: st.error("No model loaded") @@ -373,7 +394,7 @@ def aggrid_interactive_table(df: pd.DataFrame): cm.pyplot(plt) if is_save: try: - plt.savefig(model_path + "/confusion_matrix.png", format='png', bbox_inches="tight", dpi=200) + plt.savefig(model_path + "/confusion_matrix.png", format="png", bbox_inches="tight", dpi=200) except Exception as e: st.error(f"Could not save image: {e}") @@ -406,5 +427,5 @@ def aggrid_interactive_table(df: pd.DataFrame): cat = load_medcat_model(medcat_path) output = cat.get_entities(text) doc = cat(text) - visualize_ner(doc, title=None, show_table=False, displacy_options={"colors":{"concept":"#F17156"}}) + visualize_ner(doc, title=None, show_table=False, displacy_options={"colors": {"concept": "#F17156"}}) st.write(output) diff --git a/streamlit_app/utils.py b/streamlit_app/utils.py index 23869cd..1ee0dec 100644 --- a/streamlit_app/utils.py +++ b/streamlit_app/utils.py @@ -1,24 +1,23 @@ import pandas as pd -from medcat.meta_cat import MetaCAT -from medcat.config_meta_cat import ConfigMetaCAT -from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBPE from typing import Optional def load_documents(data): documents = {} - for i in range(0,len(data['projects'][0]['documents'])): - documents[data['projects'][0]['documents'][i]['id']] = data['projects'][0]['documents'][i]['text'] + for i in range(0, len(data["projects"][0]["documents"])): + documents[data["projects"][0]["documents"][i]["id"]] = data["projects"][0]["documents"][i]["text"] return documents def load_annotations(data): annotations = [] - for i in range(0,len(data['projects'][0]['documents'])): - document_id = data['projects'][0]['documents'][i]['id'] - annotations.extend([Annotation.from_dict(ann, document_id) for ann in data['projects'][0]['documents'][i]['annotations']]) + for i in range(0, len(data["projects"][0]["documents"])): + document_id = data["projects"][0]["documents"][i]["id"] + annotations.extend( + [Annotation.from_dict(ann, document_id) for ann in data["projects"][0]["documents"][i]["annotations"]] + ) return annotations @@ -30,14 +29,17 @@ def get_valid_annotations(data): return annotations -def get_probs_meta_classes_data(documents, annotations, ): +def get_probs_meta_classes_data( + documents, + annotations, +): r_labels = [] p_labels = [] l_labels = [] cuis = [] names = [] texts = [] - tokens = [] + # tokens = [] for ann in annotations: r_labels.append(ann.meta_relevance) p_labels.append(ann.meta_presence) @@ -61,17 +63,24 @@ def get_probs_meta_classes_data(documents, annotations, ): # tkns = doc_text['tokens'][t_start:t_end] # tokens.append(tkns) - df = pd.DataFrame({"text": texts, - "cui": cuis, - "name": names, - # "tokens": tokens, - "relevance": r_labels, - "presence": p_labels, - "laterality (generic)": l_labels, }) + df = pd.DataFrame( + { + "text": texts, + "cui": cuis, + "name": names, + # "tokens": tokens, + "relevance": r_labels, + "presence": p_labels, + "laterality (generic)": l_labels, + } + ) return df -def get_meds_meta_classes_data(documents, annotations, ): +def get_meds_meta_classes_data( + documents, + annotations, +): substance_labels = 
[] allergy_labels = [] severity_labels = [] @@ -79,7 +88,7 @@ def get_meds_meta_classes_data(documents, annotations, ): cuis = [] names = [] texts = [] - tokens = [] + # tokens = [] for ann in annotations: substance_labels.append(ann.meta_substance_cat) allergy_labels.append(ann.meta_allergy_type) @@ -104,41 +113,43 @@ def get_meds_meta_classes_data(documents, annotations, ): # tkns = doc_text['tokens'][t_start:t_end] # tokens.append(tkns) - df = pd.DataFrame({"text": texts, - "cui": cuis, - "name": names, - # "tokens": tokens, - "substance_category": substance_labels, - "allergy_type": allergy_labels, - "severity": severity_labels, - "reaction_pos": reaction_labels}) + df = pd.DataFrame( + { + "text": texts, + "cui": cuis, + "name": names, + # "tokens": tokens, + "substance_category": substance_labels, + "allergy_type": allergy_labels, + "severity": severity_labels, + "reaction_pos": reaction_labels, + } + ) return df - - class Annotation: def __init__( - self, - alternative, - id, - document_id, - cui, - value, - deleted, - start, - end, - irrelevant, - killed, - manually_created, - meta_laterality, - meta_presence, - meta_relevance, - meta_allergy_type, - meta_substance_cat, - meta_severity, - meta_reaction_pos, - dictionary + self, + alternative, + id, + document_id, + cui, + value, + deleted, + start, + end, + irrelevant, + killed, + manually_created, + meta_laterality, + meta_presence, + meta_relevance, + meta_allergy_type, + meta_substance_cat, + meta_severity, + meta_reaction_pos, + dictionary, ): self.alternative = alternative self.id = id @@ -173,40 +184,40 @@ def from_dict(cls, d, document_id): meta_anns = d.get("meta_anns") if meta_anns is not None: - meta_ann_l = meta_anns.get('laterality (generic)') + meta_ann_l = meta_anns.get("laterality (generic)") if meta_ann_l is not None: - meta_laterality = meta_ann_l['value'] - meta_ann_r = meta_anns.get('relevance') + meta_laterality = meta_ann_l["value"] + meta_ann_r = meta_anns.get("relevance") if meta_ann_r is not None: - meta_relevance = meta_ann_r['value'] - meta_ann_p = meta_anns.get('presence') + meta_relevance = meta_ann_r["value"] + meta_ann_p = meta_anns.get("presence") if meta_ann_p is not None: - meta_presence = meta_ann_p['value'] + meta_presence = meta_ann_p["value"] - meta_ann_allergy = meta_anns.get('allergy_type') + meta_ann_allergy = meta_anns.get("allergy_type") if meta_ann_allergy is not None: - meta_allergy_type = meta_ann_allergy['value'] - meta_ann_substance = meta_anns.get('substance_category') + meta_allergy_type = meta_ann_allergy["value"] + meta_ann_substance = meta_anns.get("substance_category") if meta_ann_substance is not None: - meta_substance_cat = meta_ann_substance['value'] - meta_ann_severity = meta_anns.get('severity') + meta_substance_cat = meta_ann_substance["value"] + meta_ann_severity = meta_anns.get("severity") if meta_ann_severity is not None: - meta_severity = meta_ann_severity['value'] - meta_ann_reaction = meta_anns.get('reaction_pos') + meta_severity = meta_ann_severity["value"] + meta_ann_reaction = meta_anns.get("reaction_pos") if meta_ann_reaction is not None: - meta_reaction_pos = meta_ann_reaction['value'] + meta_reaction_pos = meta_ann_reaction["value"] return cls( - alternative=d['alternative'], - id=d['id'], + alternative=d["alternative"], + id=d["id"], document_id=document_id, - cui=d['cui'], - value=d['value'], - deleted=d['deleted'], - start=d['start'], - end=d['end'], - irrelevant=d['irrelevant'], - killed=d['killed'], - manually_created=d['manually_created'], + cui=d["cui"], 
+ value=d["value"], + deleted=d["deleted"], + start=d["start"], + end=d["end"], + irrelevant=d["irrelevant"], + killed=d["killed"], + manually_created=d["manually_created"], meta_laterality=meta_laterality, meta_presence=meta_presence, meta_relevance=meta_relevance, @@ -245,45 +256,23 @@ def __str__(self): def __eq__(self, other): return ( - self.alternative == other.alternative - and - self.cui == other.cui - and - self.document_id == other.document_id - and - self.deleted == other.deleted - and - self.start == other.start - and - self.end == other.end - and - self.irrelevant == other.irrelevant - and - self.killed == other.killed - and - self.manually_created == other.manually_created - and - self.meta_laterality == other.meta_laterality - and - self.meta_presence == other.meta_presence - and - self.meta_relevance == other.meta_relevance - and - self.meta_substance_cat == other.meta_substance_cat - and - self.meta_allergy_type == other.meta_allergy_type - and - self.meta_severity == other.meta_severity - and - self.meta_reaction_pos == other.meta_reaction_pos - + self.alternative == other.alternative + and self.cui == other.cui + and self.document_id == other.document_id + and self.deleted == other.deleted + and self.start == other.start + and self.end == other.end + and self.irrelevant == other.irrelevant + and self.killed == other.killed + and self.manually_created == other.manually_created + and self.meta_laterality == other.meta_laterality + and self.meta_presence == other.meta_presence + and self.meta_relevance == other.meta_relevance + and self.meta_substance_cat == other.meta_substance_cat + and self.meta_allergy_type == other.meta_allergy_type + and self.meta_severity == other.meta_severity + and self.meta_reaction_pos == other.meta_reaction_pos ) def is_same_model_annotation(self, other): - return ( - self.cui == other.cui - and - self.start == other.start - and - self.end == other.end - ) + return self.cui == other.cui and self.start == other.start and self.end == other.end diff --git a/tests/conftest.py b/tests/conftest.py index b1dacc0..0162ff4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,15 @@ from miade.note import Note from miade.concept import Concept, Category from miade.metaannotations import MetaAnnotations -from miade.utils.metaannotationstypes import * +from miade.utils.metaannotationstypes import ( + Presence, + Relevance, + Laterality, + ReactionPos, + SubstanceCategory, + AllergyType, + Severity, +) from miade.utils.miade_cat import MiADE_CAT @@ -19,12 +27,16 @@ def model_directory_path() -> Path: @pytest.fixture(scope="function") def test_problems_medcat_model() -> MiADE_CAT: - return MiADE_CAT.load_model_pack(str("./tests/data/models/miade_problems_blank_modelpack_Jun_2023_df349473b9d260a9.zip")) + return MiADE_CAT.load_model_pack( + str("./tests/data/models/miade_problems_blank_modelpack_Jun_2023_df349473b9d260a9.zip") + ) @pytest.fixture(scope="function") def test_meds_algy_medcat_model() -> MiADE_CAT: - return MiADE_CAT.load_model_pack(str("./tests/data/models/miade_meds_allergy_blank_modelpack_Jun_2023_75e13bf042cc55b8.zip")) + return MiADE_CAT.load_model_pack( + str("./tests/data/models/miade_meds_allergy_blank_modelpack_Jun_2023_75e13bf042cc55b8.zip") + ) @pytest.fixture(scope="function") @@ -34,21 +46,21 @@ def test_note() -> Note: @pytest.fixture(scope="function") def test_negated_note() -> Note: - return Note( - text="Patient does not have liver failure. Patient is taking paracetamol 500mg oral tablets." 
- ) + return Note(text="Patient does not have liver failure. Patient is taking paracetamol 500mg oral tablets.") @pytest.fixture(scope="function") def test_duplicated_note() -> Note: return Note( text="Patient has liver failure. The liver failure is quite bad. Patient is taking " - "paracetamol 500mg oral tablets. decrease paracetamol 500mg oral tablets dosage." + "paracetamol 500mg oral tablets. decrease paracetamol 500mg oral tablets dosage." ) + @pytest.fixture(scope="function") def test_clean_and_paragraphing_note() -> Note: - return Note(""" + return Note( + """ This is an example of text with various types of spaces: \tTabs, \u00A0Non-breaking spaces, \u2003Em spaces, \u2002En spaces. Some lines may contain only punctuation and spaces, like this: @@ -82,7 +94,9 @@ def test_clean_and_paragraphing_note() -> Note: imp:: Penicillin - """) + """ + ) + @pytest.fixture(scope="function") def test_paragraph_chunking_concepts() -> List[Concept]: @@ -234,6 +248,7 @@ def test_paragraph_chunking_concepts() -> List[Concept]: ), ] + @pytest.fixture(scope="function") def temp_dir() -> Path: return Path("./tests/data/temp") @@ -302,33 +317,22 @@ def test_med_concepts() -> List[Concept]: start=0, end=19, ), - Concept( - id="1", name="Paracetamol", category=Category.MEDICATION, start=32, end=43 - ), - Concept( - id="2", name="Aspirin", category=Category.MEDICATION, start=99, end=107 - ), - Concept( - id="3", name="Doxycycline", category=Category.MEDICATION, start=144, end=156 - ), + Concept(id="1", name="Paracetamol", category=Category.MEDICATION, start=32, end=43), + Concept(id="2", name="Aspirin", category=Category.MEDICATION, start=99, end=107), + Concept(id="3", name="Doxycycline", category=Category.MEDICATION, start=144, end=156), ] @pytest.fixture(scope="function") def test_miade_doses() -> (List[Note], pd.DataFrame): extracted_doses = pd.read_csv("./tests/examples/common_doses_for_miade.csv") - return [ - Note(text=dose) for dose in extracted_doses.dosestring.to_list() - ], extracted_doses + return [Note(text=dose) for dose in extracted_doses.dosestring.to_list()], extracted_doses @pytest.fixture(scope="function") def test_miade_med_concepts() -> List[Concept]: data = pd.read_csv("./tests/examples/common_doses_for_miade.csv") - return [ - Concept(id="387337001", name=drug, category=Category.MEDICATION) - for drug in data.drug.to_list() - ] + return [Concept(id="387337001", name=drug, category=Category.MEDICATION) for drug in data.drug.to_list()] @pytest.fixture(scope="function") @@ -455,7 +459,7 @@ def test_meta_annotations_concepts() -> List[Concept]: meta_anns=[ MetaAnnotations(name="presence", value=Presence.CONFIRMED), MetaAnnotations(name="relevance", value=Relevance.PRESENT), - MetaAnnotations(name="laterality (generic)", value=Laterality.LEFT) + MetaAnnotations(name="laterality (generic)", value=Laterality.LEFT), ], ), Concept( @@ -466,7 +470,7 @@ def test_meta_annotations_concepts() -> List[Concept]: meta_anns=[ MetaAnnotations(name="presence", value=Presence.NEGATED), MetaAnnotations(name="relevance", value=Relevance.PRESENT), - MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY) + MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY), ], ), Concept( @@ -477,7 +481,7 @@ def test_meta_annotations_concepts() -> List[Concept]: meta_anns=[ MetaAnnotations(name="presence", value=Presence.NEGATED), MetaAnnotations(name="relevance", value=Relevance.PRESENT), - MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY) + 
MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY), ], ), Concept( @@ -488,7 +492,7 @@ def test_meta_annotations_concepts() -> List[Concept]: meta_anns=[ MetaAnnotations(name="presence", value=Presence.SUSPECTED), MetaAnnotations(name="relevance", value=Relevance.PRESENT), - MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY) + MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY), ], ), Concept( @@ -499,7 +503,7 @@ def test_meta_annotations_concepts() -> List[Concept]: meta_anns=[ MetaAnnotations(name="presence", value=Presence.SUSPECTED), MetaAnnotations(name="relevance", value=Relevance.PRESENT), - MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY) + MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY), ], ), Concept( @@ -510,7 +514,7 @@ def test_meta_annotations_concepts() -> List[Concept]: meta_anns=[ MetaAnnotations(name="presence", value=Presence.SUSPECTED), MetaAnnotations(name="relevance", value=Relevance.PRESENT), - MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY) + MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY), ], ), Concept( @@ -521,7 +525,7 @@ def test_meta_annotations_concepts() -> List[Concept]: meta_anns=[ MetaAnnotations(name="presence", value=Presence.CONFIRMED), MetaAnnotations(name="relevance", value=Relevance.HISTORIC), - MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY) + MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY), ], ), Concept( @@ -532,7 +536,7 @@ def test_meta_annotations_concepts() -> List[Concept]: meta_anns=[ MetaAnnotations(name="presence", value=Presence.CONFIRMED), MetaAnnotations(name="relevance", value=Relevance.IRRELEVANT), - MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY) + MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY), ], ), Concept( @@ -543,7 +547,7 @@ def test_meta_annotations_concepts() -> List[Concept]: meta_anns=[ MetaAnnotations(name="presence", value=Presence.CONFIRMED), MetaAnnotations(name="relevance", value=Relevance.HISTORIC), - MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY) + MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY), ], ), ] @@ -556,80 +560,113 @@ def test_filtering_list_concepts() -> List[Concept]: Concept(id="13543005", name="Pressure", category=Category.PROBLEM), Concept(id="19342008", name="Subacute disease", category=Category.PROBLEM), Concept(id="76797004", name="Failure", category=Category.PROBLEM), - Concept(id="123", name="real concept", category=Category.PROBLEM) + Concept(id="123", name="real concept", category=Category.PROBLEM), ] + @pytest.fixture(scope="function") def test_meds_allergy_note() -> Note: return Note( text="Intolerant of eggs mild rash. Allergies: moderate nausea due to penicillin. Taking paracetamol for pain." 
) + @pytest.fixture(scope="function") def test_substance_concepts_with_meta_anns() -> List[Concept]: return [ - Concept(id="226021002", name="Eggs", start=14, end=17, meta_anns=[ - MetaAnnotations(name="reactionpos", value=ReactionPos.NOT_REACTION), - MetaAnnotations(name="category", value=SubstanceCategory.ADVERSE_REACTION), - MetaAnnotations(name="allergytype", value=AllergyType.INTOLERANCE), - MetaAnnotations(name="severity", value=Severity.MILD), - ]), - Concept(id="159002", name="Penicillin", start=64, end=73, meta_anns=[ - MetaAnnotations(name="reactionpos", value=ReactionPos.NOT_REACTION), - MetaAnnotations(name="category", value=SubstanceCategory.ADVERSE_REACTION), - MetaAnnotations(name="allergytype", value=AllergyType.ALLERGY), - MetaAnnotations(name="severity", value=Severity.MODERATE), - ]), - Concept(id="140004", name="Rash", start=24, end=27, meta_anns=[ - MetaAnnotations(name="reactionpos", value=ReactionPos.AFTER_SUBSTANCE), - MetaAnnotations(name="category", value=SubstanceCategory.NOT_SUBSTANCE), - MetaAnnotations(name="allergytype", value=AllergyType.UNSPECIFIED), - MetaAnnotations(name="severity", value=Severity.UNSPECIFIED), - ]), - Concept(id="832007", name="Nausea", start=50, end=55, meta_anns=[ - MetaAnnotations(name="reactionpos", value=ReactionPos.BEFORE_SUBSTANCE), - MetaAnnotations(name="category", value=SubstanceCategory.NOT_SUBSTANCE), - MetaAnnotations(name="allergytype", value=AllergyType.UNSPECIFIED), - MetaAnnotations(name="severity", value=Severity.UNSPECIFIED), - ]), - Concept(id="7336002", name="Paracetamol", start=83, end=93, - dosage=Dosage( - dose=Dose(value=50, unit="mg"), - duration=None, - frequency=None, - route=None - ), - meta_anns=[ - MetaAnnotations(name="reactionpos", value=ReactionPos.NOT_REACTION), - MetaAnnotations(name="category", value=SubstanceCategory.TAKING), - MetaAnnotations(name="allergytype", value=AllergyType.UNSPECIFIED), - MetaAnnotations(name="severity", value=Severity.UNSPECIFIED), - ]), + Concept( + id="226021002", + name="Eggs", + start=14, + end=17, + meta_anns=[ + MetaAnnotations(name="reaction_pos", value=ReactionPos.NOT_REACTION), + MetaAnnotations(name="category", value=SubstanceCategory.ADVERSE_REACTION), + MetaAnnotations(name="allergy_type", value=AllergyType.INTOLERANCE), + MetaAnnotations(name="severity", value=Severity.MILD), + ], + ), + Concept( + id="159002", + name="Penicillin", + start=64, + end=73, + meta_anns=[ + MetaAnnotations(name="reaction_pos", value=ReactionPos.NOT_REACTION), + MetaAnnotations(name="category", value=SubstanceCategory.ADVERSE_REACTION), + MetaAnnotations(name="allergy_type", value=AllergyType.ALLERGY), + MetaAnnotations(name="severity", value=Severity.MODERATE), + ], + ), + Concept( + id="140004", + name="Rash", + start=24, + end=27, + meta_anns=[ + MetaAnnotations(name="reaction_pos", value=ReactionPos.AFTER_SUBSTANCE), + MetaAnnotations(name="category", value=SubstanceCategory.NOT_SUBSTANCE), + MetaAnnotations(name="allergy_type", value=AllergyType.UNSPECIFIED), + MetaAnnotations(name="severity", value=Severity.UNSPECIFIED), + ], + ), + Concept( + id="832007", + name="Nausea", + start=50, + end=55, + meta_anns=[ + MetaAnnotations(name="reaction_pos", value=ReactionPos.BEFORE_SUBSTANCE), + MetaAnnotations(name="category", value=SubstanceCategory.ADVERSE_REACTION), + MetaAnnotations(name="allergy_type", value=AllergyType.UNSPECIFIED), + MetaAnnotations(name="severity", value=Severity.UNSPECIFIED), + ], + ), + Concept( + id="7336002", + name="Paracetamol", + start=83, + end=93, + 
dosage=Dosage(dose=Dose(value=50, unit="mg"), duration=None, frequency=None, route=None), + meta_anns=[ + MetaAnnotations(name="reaction_pos", value=ReactionPos.NOT_REACTION), + MetaAnnotations(name="category", value=SubstanceCategory.TAKING), + MetaAnnotations(name="allergy_type", value=AllergyType.UNSPECIFIED), + MetaAnnotations(name="severity", value=Severity.UNSPECIFIED), + ], + ), ] + @pytest.fixture(scope="function") def test_vtm_concepts() -> List[Concept]: return [ Concept( - id="302007", name="Spiramycin", category=Category.MEDICATION, - dosage=Dosage( - dose=Dose(value=10, unit="mg"), - duration=None, - frequency=None, - route=None, - ), + id="302007", + name="Spiramycin", + category=Category.MEDICATION, + dosage=Dosage( + dose=Dose(value=10, unit="mg"), + duration=None, + frequency=None, + route=None, ), + ), Concept( - id="7336002", name="Paracetamol", category=Category.MEDICATION, - dosage=Dosage( - dose=Dose(value=50, unit="mg"), - duration=None, - frequency=None, - route=None, - ), + id="7336002", + name="Paracetamol", + category=Category.MEDICATION, + dosage=Dosage( + dose=Dose(value=50, unit="mg"), + duration=None, + frequency=None, + route=None, ), + ), Concept( - id="7947003", name="Aspirin", category=Category.MEDICATION, + id="7947003", + name="Aspirin", + category=Category.MEDICATION, dosage=Dosage( dose=None, duration=None, @@ -637,12 +674,11 @@ def test_vtm_concepts() -> List[Concept]: route=Route(full_name="Oral", value="C38288"), ), ), + Concept(id="6247001", name="Folic acid", category=Category.MEDICATION, dosage=None), Concept( - id="6247001", name="Folic acid", category=Category.MEDICATION, - dosage=None - ), - Concept( - id="350057002", name="Selenium", category=Category.MEDICATION, + id="350057002", + name="Selenium", + category=Category.MEDICATION, dosage=Dosage( dose=Dose(value=50, unit="microgram"), duration=None, @@ -651,7 +687,9 @@ def test_vtm_concepts() -> List[Concept]: ), ), Concept( - id="350057002", name="Selenium", category=Category.MEDICATION, + id="350057002", + name="Selenium", + category=Category.MEDICATION, dosage=Dosage( dose=Dose(value=10, unit="microgram"), duration=None, @@ -659,4 +697,4 @@ def test_vtm_concepts() -> List[Concept]: route=None, ), ), - ] \ No newline at end of file + ] diff --git a/tests/data/models/config.yaml b/tests/data/models/config.yaml index 254b78c..1f6f91a 100644 --- a/tests/data/models/config.yaml +++ b/tests/data/models/config.yaml @@ -8,8 +8,10 @@ annotators: custom: CustomAnnotator general: problems: + lookup_data_path: "./lookup_data/" negation_detection: negex # negex or metacat or none disable: [] meds/allergies: + lookup_data_path: "./lookup_data/" negation_detection: None disable: [] diff --git a/tests/test_annotator.py b/tests/test_annotator.py index 71e94f2..6b5165b 100644 --- a/tests/test_annotator.py +++ b/tests/test_annotator.py @@ -3,6 +3,7 @@ from miade.dosage import Dose, Frequency, Dosage, Route from miade.dosageextractor import DosageExtractor + def test_dosage_text_splitter(test_meds_algy_medcat_model, test_med_concepts, test_med_note): annotator = MedsAllergiesAnnotator(test_meds_algy_medcat_model) dosage_extractor = DosageExtractor() @@ -14,9 +15,7 @@ def test_dosage_text_splitter(test_meds_algy_medcat_model, test_med_concepts, te assert concepts[2].dosage.text == "aspirin IM q daily x 2 weeks with concurrent " assert concepts[3].dosage.text == "DOXYCYCLINE 500mg tablets for two weeks" - assert concepts[0].dosage.dose == Dose( - source="75 mg", value=75, unit="mg", low=None, high=None - ) + 
assert concepts[0].dosage.dose == Dose(source="75 mg", value=75, unit="mg", low=None, high=None) assert concepts[0].dosage.frequency == Frequency( source="start 75 mg every day ", @@ -32,6 +31,7 @@ def test_dosage_text_splitter(test_meds_algy_medcat_model, test_med_concepts, te def test_calculate_word_distance(): from miade.note import Note + note = Note("the quick broooooown fox jumped over the lazy dog") start1, end1 = 10, 20 start2, end2 = 10, 20 @@ -62,7 +62,6 @@ def test_calculate_word_distance(): assert calculate_word_distance(start1, end1, start2, end2, note) == 1 - def test_deduplicate( test_problems_medcat_model, test_duplicate_concepts_note, @@ -79,14 +78,11 @@ def test_deduplicate( Concept(id="7", name="test2", category=Category.MEDICATION), Concept(id="5", name="test2", category=Category.PROBLEM), ] - assert annotator.deduplicate( - concepts=test_self_duplicate_concepts_note, record_concepts=None) == [ + assert annotator.deduplicate(concepts=test_self_duplicate_concepts_note, record_concepts=None) == [ Concept(id="1", name="test1", category=Category.PROBLEM), Concept(id="2", name="test2", category=Category.MEDICATION), ] - assert annotator.deduplicate( - concepts=test_duplicate_concepts_note, record_concepts=None - ) == [ + assert annotator.deduplicate(concepts=test_duplicate_concepts_note, record_concepts=None) == [ Concept(id="1", name="test1", category=Category.PROBLEM), Concept(id="2", name="test2", category=Category.PROBLEM), Concept(id="3", name="test2", category=Category.PROBLEM), @@ -95,9 +91,7 @@ def test_deduplicate( Concept(id="5", name="test2", category=Category.PROBLEM), Concept(id="6", name="test2", category=Category.MEDICATION), ] - assert annotator.deduplicate( - concepts=test_duplicate_concepts_note, record_concepts=[] - ) == [ + assert annotator.deduplicate(concepts=test_duplicate_concepts_note, record_concepts=[]) == [ Concept(id="1", name="test1", category=Category.PROBLEM), Concept(id="2", name="test2", category=Category.PROBLEM), Concept(id="3", name="test2", category=Category.PROBLEM), @@ -113,27 +107,17 @@ def test_deduplicate( Concept(id=None, name="vtm1", category=Category.MEDICATION), Concept(id=None, name="vtm3", category=Category.MEDICATION), ] - assert ( - annotator.deduplicate( - concepts=[], record_concepts=test_duplicate_concepts_record - ) - == [] - ) + assert annotator.deduplicate(concepts=[], record_concepts=test_duplicate_concepts_record) == [] + def test_meta_annotations(test_problems_medcat_model, test_meta_annotations_concepts): annotator = ProblemsAnnotator(test_problems_medcat_model) assert annotator.postprocess(test_meta_annotations_concepts) == [ Concept(id="274826007", name="Nystagmus (negated)", category=Category.PROBLEM), # negex true, meta ignored - Concept( - id="302064001", name="Lymphangitis (negated)", category=Category.PROBLEM - ), # negex true, meta ignored - Concept( - id="431956005", name="Arthritis (suspected)", category=Category.PROBLEM - ), # negex false, meta processed - Concept( - id="413241009", name="Gastritis (suspected)", category=Category.PROBLEM - ), + Concept(id="302064001", name="Lymphangitis (negated)", category=Category.PROBLEM), # negex true, meta ignored + Concept(id="431956005", name="Arthritis (suspected)", category=Category.PROBLEM), # negex false, meta processed + Concept(id="413241009", name="Gastritis (suspected)", category=Category.PROBLEM), Concept( id="1847009", name="Endophthalmitis", @@ -158,11 +142,8 @@ def test_meta_annotations(test_problems_medcat_model, test_meta_annotations_conc Concept( 
id="1415005", name="Lymphangitis", category=Category.PROBLEM ), # negex false, meta processed but ignore negation - Concept( - id="413241009", name="Gastritis (suspected)", category=Category.PROBLEM - ), # negex false, meta processed - Concept(id="0000", name="historic concept", category=Category.PROBLEM - ), # historic with no conversion + Concept(id="413241009", name="Gastritis (suspected)", category=Category.PROBLEM), # negex false, meta processed + Concept(id="0000", name="historic concept", category=Category.PROBLEM), # historic with no conversion ] @@ -172,6 +153,7 @@ def test_problems_filtering_list(test_problems_medcat_model, test_filtering_list Concept(id="123", name="real concept", category=Category.PROBLEM), ] + def test_allergy_annotator(test_meds_algy_medcat_model, test_substance_concepts_with_meta_anns, test_meds_allergy_note): annotator = MedsAllergiesAnnotator(test_meds_algy_medcat_model) concepts = annotator.postprocess(test_substance_concepts_with_meta_anns, test_meds_allergy_note) @@ -194,6 +176,7 @@ def test_allergy_annotator(test_meds_algy_medcat_model, test_substance_concepts_ ] assert concepts[2].linked_concepts == [] + def test_vtm_med_conversions(test_meds_algy_medcat_model, test_vtm_concepts): annotator = MedsAllergiesAnnotator(test_meds_algy_medcat_model) concepts = annotator.convert_VTM_to_VMP_or_text(test_vtm_concepts) @@ -231,4 +214,4 @@ def test_vtm_med_conversions(test_meds_algy_medcat_model, test_vtm_concepts): frequency=None, duration=None, route=None, - ) \ No newline at end of file + ) diff --git a/tests/test_cdbbuilder.py b/tests/test_cdbbuilder.py index fddc58d..522c242 100644 --- a/tests/test_cdbbuilder.py +++ b/tests/test_cdbbuilder.py @@ -1,5 +1,3 @@ -from pathlib import Path - from miade.model_builders import CDBBuilder diff --git a/tests/test_core.py b/tests/test_core.py index d51e752..1b020e0 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -2,7 +2,12 @@ from miade.concept import Concept, Category from miade.annotators import Annotator from miade.metaannotations import MetaAnnotations -from miade.utils.metaannotationstypes import * +from miade.utils.metaannotationstypes import ( + Presence, + Relevance, + Laterality, +) + def test_core(model_directory_path, test_note, test_negated_note, test_duplicated_note): processor = NoteProcessor(model_directory_path) @@ -23,24 +28,34 @@ def test_core(model_directory_path, test_note, test_negated_note, test_duplicate ] assert processor.get_concept_dicts(test_note) == [ { - 'name': '00 liver failure', 'id': '59927004', 'category': 'PROBLEM', 'start': 12, 'end': 25, - 'dosage': None, 'linked_concepts': [], 'negex': False, 'meta': None, 'debug': None + "name": "00 liver failure", + "id": "59927004", + "category": "PROBLEM", + "start": 12, + "end": 25, + "dosage": None, + "linked_concepts": [], + "negex": False, + "meta": None, + "debug": None, }, { - 'name': 'paracetamol 500mg oral tablets', 'id': '322236009', 'category': 'MEDICATION', 'start': 40, 'end': 70, - 'dosage': { - 'dose': { - 'source': '500 mg by mouth tab', 'value': 500.0, 'unit': '{tbl}', 'low': None, 'high': None - }, - 'duration': None, - 'frequency': None, - 'route': { - 'source': 'by mouth', 'full_name': 'Oral', 'value': 'C38288', 'code_system': 'NCI Thesaurus' - } + "name": "paracetamol 500mg oral tablets", + "id": "322236009", + "category": "MEDICATION", + "start": 40, + "end": 70, + "dosage": { + "dose": {"source": "500 mg by mouth tab", "value": 500.0, "unit": "{tbl}", "low": None, "high": None}, + "duration": None, + 
"frequency": None, + "route": {"source": "by mouth", "full_name": "Oral", "value": "C38288", "code_system": "NCI Thesaurus"}, }, - 'linked_concepts': [], - 'negex': False, 'meta': None, 'debug': None - } + "linked_concepts": [], + "negex": False, + "meta": None, + "debug": None, + }, ] @@ -62,9 +77,11 @@ def test_adding_removing_annotators(model_directory_path): processor.print_model_cards() + def test_adding_custom_annotators(model_directory_path): class CustomAnnotator(Annotator): pass + processor = NoteProcessor(model_directory_path, custom_annotators=[CustomAnnotator]) processor.add_annotator("custom") @@ -73,6 +90,7 @@ class CustomAnnotator(Annotator): processor.remove_annotator("custom") assert len(processor.annotators) == 0 + def test_meta_from_entity(test_medcat_concepts): assert Concept.from_entity(test_medcat_concepts["0"]) == Concept( id="0", @@ -82,8 +100,8 @@ def test_meta_from_entity(test_medcat_concepts): end=11, meta_anns=[ MetaAnnotations(name="presence", value=Presence.NEGATED), - MetaAnnotations(name="relevance", value=Relevance.HISTORIC) - ] + MetaAnnotations(name="relevance", value=Relevance.HISTORIC), + ], ) assert Concept.from_entity(test_medcat_concepts["1"]) == Concept( id="0", @@ -94,7 +112,6 @@ def test_meta_from_entity(test_medcat_concepts): meta_anns=[ MetaAnnotations(name="presence", value=Presence.SUSPECTED, confidence=1), MetaAnnotations(name="relevance", value=Relevance.IRRELEVANT, confidence=1), - MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY, confidence=1) - ] + MetaAnnotations(name="laterality (generic)", value=Laterality.NO_LATERALITY, confidence=1), + ], ) - diff --git a/tests/test_dosageextractor.py b/tests/test_dosageextractor.py index d99a1d7..d23516c 100644 --- a/tests/test_dosageextractor.py +++ b/tests/test_dosageextractor.py @@ -1,4 +1,3 @@ -import pytest from pandas import isnull # from devtools import debug @@ -38,21 +37,13 @@ def test_dosage_extractor(test_miade_doses, test_miade_med_concepts): if not isnull(doses.timeinterval_value.values[ind]): assert dosage.frequency - assert round(dosage.frequency.value, 3) == round( - doses.timeinterval_value.values[ind], 3 - ) + assert round(dosage.frequency.value, 3) == round(doses.timeinterval_value.values[ind], 3) assert dosage.frequency.unit == doses.timeinterval_unit.values[ind] if not isnull(doses.institution_specified.values[ind]): assert dosage.frequency - assert ( - dosage.frequency.institutionSpecified - == doses.institution_specified.values[ind] - ) + assert dosage.frequency.institutionSpecified == doses.institution_specified.values[ind] if not isnull(doses.precondition_as_required.values[ind]): assert dosage.frequency - assert ( - dosage.frequency.preconditionAsRequired - == doses.precondition_as_required.values[ind] - ) + assert dosage.frequency.preconditionAsRequired == doses.precondition_as_required.values[ind] diff --git a/tests/test_install.py b/tests/test_install.py index fafce3f..1882e65 100644 --- a/tests/test_install.py +++ b/tests/test_install.py @@ -1,2 +1,2 @@ def test_install(): - import miade + import miade # noqa diff --git a/tests/test_note.py b/tests/test_note.py index 511e68c..329639b 100644 --- a/tests/test_note.py +++ b/tests/test_note.py @@ -2,10 +2,14 @@ from miade.paragraph import Paragraph, ParagraphType from miade.metaannotations import MetaAnnotations -from miade.utils.metaannotationstypes import * +from miade.utils.metaannotationstypes import ( + Presence, + Relevance, + SubstanceCategory, +) -def test_note(model_directory_path, 
test_clean_and_paragraphing_note, test_paragraph_chunking_concepts): +def test_note(model_directory_path, test_clean_and_paragraphing_note, test_paragraph_chunking_concepts): test_clean_and_paragraphing_note.clean_text() test_clean_and_paragraphing_note.get_paragraphs() @@ -28,8 +32,7 @@ def test_note(model_directory_path, test_clean_and_paragraphing_note, test_parag processor.add_annotator("meds/allergies") concepts = processor.annotators[0].process_paragraphs( - test_clean_and_paragraphing_note, - test_paragraph_chunking_concepts + test_clean_and_paragraphing_note, test_paragraph_chunking_concepts ) # prose assert concepts[0].meta == [ @@ -83,6 +86,7 @@ def test_note(model_directory_path, test_clean_and_paragraphing_note, test_parag # for concept in concepts: # print(concept) + def test_long_problem_list(): # TODO - pass \ No newline at end of file + pass
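Beyond the formatting changes, `AnnotatorConfig` gains a `lookup_data_path` field and the test `config.yaml` above gains matching keys under each annotator section. A minimal sketch of how one of those sections could be parsed into the model; this is assumed usage rather than the packaged loader, relying only on `yaml.safe_load` and the pydantic model shown earlier.

    # Sketch only: read the "problems" section of the test config into AnnotatorConfig.
    import yaml
    from miade.utils.annotatorconfig import AnnotatorConfig

    with open("./tests/data/models/config.yaml", "r") as stream:
        raw = yaml.safe_load(stream)

    problems_config = AnnotatorConfig(**raw["general"]["problems"])
    print(problems_config.lookup_data_path)    # ./lookup_data/
    print(problems_config.negation_detection)  # negex
    print(problems_config.disable)             # []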