Skip to content

Commit

Permalink
🚨 ruff format
Browse files (browse the repository at this point in the history)
  • Loading branch information
jamesbrandreth committed Nov 20, 2023
1 parent f8ea94a commit 6d103a5
Show file tree
Hide file tree
Showing 29 changed files with 647 additions and 680 deletions.
279 changes: 166 additions & 113 deletions src/miade/annotators.py

Large diffs are not rendered by default.

11 changes: 4 additions & 7 deletions src/miade/concept.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def __init__(
meta_anns: Optional[List[MetaAnnotations]] = None,
debug_dict: Optional[Dict] = None,
):

self.name = name
self.id = id
self.category = category
Expand All @@ -54,7 +53,9 @@ def from_entity(cls, entity: [Dict]):

return Concept(
id=entity["cui"],
name=entity["source_value"], # can also use detected_name which is spell checked but delimited by ~ e.g. liver~failure
name=entity[
"source_value"
], # can also use detected_name which is spell checked but delimited by ~ e.g. liver~failure
category=None,
start=entity["start"],
end=entity["end"],
Expand All @@ -72,11 +73,7 @@ def __hash__(self):
return hash((self.id, self.name, self.category))

def __eq__(self, other):
return (
self.id == other.id
and self.name == other.name
and self.category == other.category
)
return self.id == other.id and self.name == other.name and self.category == other.category

def __lt__(self, other):
return int(self.id) < int(other.id)
Expand Down
41 changes: 22 additions & 19 deletions src/miade/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@ def create_annotator(name: str, model_factory: ModelFactory):
"""
name = name.lower()
if name not in model_factory.models:
raise ValueError(f"MedCAT model for {name} does not exist: either not configured in config.yaml or "
f"missing from models directory")
raise ValueError(
f"MedCAT model for {name} does not exist: either not configured in config.yaml or "
f"missing from models directory"
)

if name in model_factory.annotators.keys():
return model_factory.annotators[name](cat=model_factory.models.get(name), config=model_factory.configs.get(name))
return model_factory.annotators[name](
cat=model_factory.models.get(name), config=model_factory.configs.get(name)
)
else:
log.warning(f"Annotator {name} does not exist, loading generic Annotator")
return Annotator(model_factory.models[name])
Expand All @@ -48,14 +52,15 @@ class NoteProcessor:
:param device (str) whether inference should be run on cpu or gpu - default "cpu"
:param custom_annotators (List[Annotators]) List of custom annotators
"""

def __init__(
self,
model_directory: Path,
model_config_path: Path = None,
log_level: int = logging.INFO,
dosage_extractor_log_level: int = logging.INFO,
device: str = "cpu",
custom_annotators: Optional[List[Annotator]] = None
custom_annotators: Optional[List[Annotator]] = None,
):
logging.getLogger("miade").setLevel(log_level)
logging.getLogger("miade.dosageextractor").setLevel(dosage_extractor_log_level)
Expand Down Expand Up @@ -152,13 +157,10 @@ def _load_model_factory(self, custom_annotators: Optional[List[Annotator]] = Non
else:
log.warning("No general settings configured, using default settings.")

model_factory_config = {"models": mapped_models,
"annotators": mapped_annotators,
"configs": mapped_configs}
model_factory_config = {"models": mapped_models, "annotators": mapped_annotators, "configs": mapped_configs}

return ModelFactory(**model_factory_config)


def add_annotator(self, name: str) -> None:
"""
Adds annotators to processor
Expand All @@ -167,7 +169,9 @@ def add_annotator(self, name: str) -> None:
"""
try:
annotator = create_annotator(name, self.model_factory)
log.info(f"Added {type(annotator).__name__} to processor with config {self.model_factory.configs.get(name)}")
log.info(
f"Added {type(annotator).__name__} to processor with config {self.model_factory.configs.get(name)}"
)
except Exception as e:
raise Exception(f"Error creating annotator: {e}")

Expand Down Expand Up @@ -214,11 +218,9 @@ def process(self, note: Note, record_concepts: Optional[List[Concept]] = None) -

return concepts

def get_concept_dicts(self,
note: Note,
filter_uncategorized: bool = True,
record_concepts: Optional[List[Concept]] = None
) -> List[Dict]:
def get_concept_dicts(
self, note: Note, filter_uncategorized: bool = True, record_concepts: Optional[List[Concept]] = None
) -> List[Dict]:
"""
Returns concepts in dictionary format
:param note: (Note) note containing text to extract concepts from
Expand All @@ -233,10 +235,12 @@ def get_concept_dicts(self,
continue
concept_dict = concept.__dict__
if concept.dosage is not None:
concept_dict["dosage"] = {"dose": concept.dosage.dose.dict() if concept.dosage.dose else None,
"duration": concept.dosage.duration.dict() if concept.dosage.duration else None,
"frequency": concept.dosage.frequency.dict() if concept.dosage.frequency else None,
"route": concept.dosage.route.dict() if concept.dosage.route else None}
concept_dict["dosage"] = {
"dose": concept.dosage.dose.dict() if concept.dosage.dose else None,
"duration": concept.dosage.duration.dict() if concept.dosage.duration else None,
"frequency": concept.dosage.frequency.dict() if concept.dosage.frequency else None,
"route": concept.dosage.route.dict() if concept.dosage.route else None,
}
if concept.meta is not None:
meta_anns = []
for meta in concept.meta:
Expand All @@ -249,4 +253,3 @@ def get_concept_dicts(self,
concept_list.append(concept_dict)

return concept_list

15 changes: 4 additions & 11 deletions src/miade/dosage.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ class Route(BaseModel):
code_system: Optional[str] = ROUTE_CODE_SYSTEM


def parse_dose(
text: str, quantities: List[str], units: List[str], results: Dict
) -> Optional[Dose]:
def parse_dose(text: str, quantities: List[str], units: List[str], results: Dict) -> Optional[Dose]:
"""
:param text: (str) string containing dose
:param quantities: (list) list of quantity entities NER
Expand Down Expand Up @@ -119,8 +117,7 @@ def parse_dose(
# use caliber results as backup
if results["units"] is not None:
log.debug(
f"Inconclusive dose entities {quantities}, "
f"using lookup results {results['qty']} {results['units']}"
f"Inconclusive dose entities {quantities}, " f"using lookup results {results['qty']} {results['units']}"
)
quantity_dosage.unit = results["units"]
# only autofill 1 if non-quantitative units e.g. tab, cap, puff
Expand Down Expand Up @@ -327,13 +324,9 @@ def from_doc(cls, doc: Doc, calculate: bool = True):
# if duration not given in text could extract this from total dose if given
if total_dose is not None and dose is not None and doc._.results["freq"]:
if dose.value is not None:
daily_dose = float(dose.value) * (
round(doc._.results["freq"] / doc._.results["time"])
)
daily_dose = float(dose.value) * (round(doc._.results["freq"] / doc._.results["time"]))
elif dose.high is not None:
daily_dose = float(dose.high) * (
round(doc._.results["freq"] / doc._.results["time"])
)
daily_dose = float(dose.high) * (round(doc._.results["freq"] / doc._.results["time"]))

duration = parse_duration(
text=duration_text,
Expand Down
9 changes: 2 additions & 7 deletions src/miade/dosageextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,12 @@ def extract(self, text: str, calculate: bool = True) -> Optional[Dosage]:
"""
doc = self.dosage_extractor(text)

log.debug(
f"NER results: {[(e.text, e.label_, e._.total_dose) for e in doc.ents]}"
)
log.debug(f"NER results: {[(e.text, e.label_, e._.total_dose) for e in doc.ents]}")
log.debug(f"Lookup results: {doc._.results}")

dosage = Dosage.from_doc(doc=doc, calculate=calculate)

if all(
v is None
for v in [dosage.dose, dosage.frequency, dosage.route, dosage.duration]
):
if all(v is None for v in [dosage.dose, dosage.frequency, dosage.route, dosage.duration]):
return None

return dosage
Expand Down
6 changes: 1 addition & 5 deletions src/miade/drugdoseade/entities_refiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@ def EntitiesRefiner(doc):
new_ents = []
for ind, ent in enumerate(doc.ents):
# combine consecutive labels with the same tag
if (
ent.label_ == "DURATION"
or ent.label_ == "FREQUENCY"
or ent.label_ == "DOSAGE"
) and ind != 0:
if (ent.label_ == "DURATION" or ent.label_ == "FREQUENCY" or ent.label_ == "DOSAGE") and ind != 0:
prev_ent = doc.ents[ind - 1]
if prev_ent.label_ == ent.label_:
new_ent = Span(doc, prev_ent.start, ent.end, label=ent.label)
Expand Down
14 changes: 3 additions & 11 deletions src/miade/drugdoseade/pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@
@spacy.registry.misc("patterns_lookup_table.v1")
def create_patterns_dict():
patterns_data = pkgutil.get_data(__name__, "../data/patterns.csv")
patterns_dict = (
pd.read_csv(io.BytesIO(patterns_data), index_col=0)
.squeeze("columns")
.T.to_dict()
)
patterns_dict = pd.read_csv(io.BytesIO(patterns_data), index_col=0).squeeze("columns").T.to_dict()

return patterns_dict

Expand Down Expand Up @@ -67,9 +63,7 @@ def __call__(self, doc: Doc) -> Doc:
# rule-based matching based on structure of dosage - HIE medication e.g. take 2 every day, 24 tablets
expression = r"(?P<dose_string>start [\w\s,-]+ ), (?P<total_dose>\d+) (?P<unit>[a-z]+ )?$"
for match in re.finditer(expression, dose_string):
dose_string = match.group(
"dose_string"
) # remove total dose component for lookup
dose_string = match.group("dose_string") # remove total dose component for lookup
start, end = match.span("total_dose")
total_dose_span = doc.char_span(start, end, alignment_mode="contract")
total_dose_span.label_ = "DOSAGE"
Expand All @@ -81,9 +75,7 @@ def __call__(self, doc: Doc) -> Doc:
unit_span = doc.char_span(start, end, alignment_mode="contract")
unit_span.label_ = "FORM"
unit_span._.total_dose = True
doc._.results[
"units"
] = unit_span.text # set unit in results dict as well
doc._.results["units"] = unit_span.text # set unit in results dict as well
new_entities.append(unit_span)

# lookup patterns from CALIBERdrugdose - returns dosage results in doc._.results attribute
Expand Down
7 changes: 2 additions & 5 deletions src/miade/drugdoseade/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def __call__(self, doc: Doc) -> Doc:

# remove numbers relating to strength of med e.g. aspirin 200mg tablets...
processed_text = re.sub(
r" (\d+\.?\d*) (mg|ml|g|mcg|microgram|gram|%)"
r"(\s|/)(tab|cap|gel|cream|dose|pessaries)",
r" (\d+\.?\d*) (mg|ml|g|mcg|microgram|gram|%)" r"(\s|/)(tab|cap|gel|cream|dose|pessaries)",
"",
processed_text,
)
Expand All @@ -102,9 +101,7 @@ def __call__(self, doc: Doc) -> Doc:
if replacement == " ":
log.debug(f"Removed multiword match '{words}'")
else:
log.debug(
f"Replaced multiword match '{words}' with '{replacement}'"
)
log.debug(f"Replaced multiword match '{words}' with '{replacement}'")
processed_text = new_text

# numbers replace 2
Expand Down
4 changes: 1 addition & 3 deletions src/miade/drugdoseade/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def numbers_replace(text):
text,
)
# 3 weeks...
text = re.sub(
r" ([\d.]+) (week) ", lambda m: " {:g} days ".format(int(m.group(1)) * 7), text
)
text = re.sub(r" ([\d.]+) (week) ", lambda m: " {:g} days ".format(int(m.group(1)) * 7), text)
# 3 months ... NB assume 30 days in a month
text = re.sub(
r" ([\d.]+) (month) ",
Expand Down
10 changes: 3 additions & 7 deletions src/miade/metaannotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"substance_category": SubstanceCategory,
"reaction_pos": ReactionPos,
"allergy_type": AllergyType,
"severity": Severity
"severity": Severity,
}


Expand All @@ -20,7 +20,7 @@ class MetaAnnotations(BaseModel):
value: Enum
confidence: Optional[float]

@validator('value', pre=True)
@validator("value", pre=True)
def validate_value(cls, value, values):
enum_dict = META_ANNS_DICT
if isinstance(value, str):
Expand All @@ -36,8 +36,4 @@ def validate_value(cls, value, values):
return value

def __eq__(self, other):
return (
self.name == other.name
and self.value == other.value
)

return self.name == other.name and self.value == other.value
14 changes: 4 additions & 10 deletions src/miade/model_builders/cdbbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,18 @@ def preprocess_snomed(self, output_dir: Path = Path.cwd()) -> Path:
print("Exporting preprocessed SNOMED to csv...")

if self.snomed_subset_path is not None:
snomed_subset = pd.read_csv(
str(self.snomed_subset_path), header=0, dtype={"cui": object}
)
snomed_subset = pd.read_csv(str(self.snomed_subset_path), header=0, dtype={"cui": object})
else:
snomed_subset = None

if self.snomed_exclusions_path is not None:
snomed_exclusions = pd.read_csv(
str(self.snomed_exclusions_path), sep="\n", header=None
)
snomed_exclusions = pd.read_csv(str(self.snomed_exclusions_path), sep="\n", header=None)
snomed_exclusions.columns = ["cui"]
else:
snomed_exclusions = None

output_file = output_dir / Path("preprocessed_snomed.csv")
df = self.snomed.to_concept_df(
subset_list=snomed_subset, exclusion_list=snomed_exclusions
)
df = self.snomed.to_concept_df(subset_list=snomed_subset, exclusion_list=snomed_exclusions)
df.to_csv(str(output_file), index=False)
return output_file

Expand Down Expand Up @@ -105,7 +99,7 @@ def preprocess(self):
if self.elg_data_path:
self.vocab_files.append(str(self.preprocess_elg(self.temp_dir)))
if self.custom_data_paths:
string_paths = [str(path) for path in self.custom_data_paths]
string_paths = [str(path) for path in self.custom_data_paths]
self.vocab_files.extend(string_paths)

def create_cdb(self) -> CDB:
Expand Down
Loading

0 comments on commit 6d103a5

Please sign in to comment.