Added linker - we got something working baby
jenniferjiangkells committed Oct 24, 2024
1 parent 7a9bc21 commit 0c51696
Showing 5 changed files with 214 additions and 161 deletions.
147 changes: 81 additions & 66 deletions healthchain/pipeline/models/medcatlite/linker.py
@@ -38,7 +38,18 @@ def __init__(self, cdb: CDB, vocab: Vocab, config: Config) -> None:
self.cdb = cdb
self.vocab = vocab
self.config = config
logger.debug("ContextModel initialized")

def get_word_vector(self, word: str) -> Optional[np.ndarray]:
"""
Get the vector representation of a word from the vocabulary.
Args:
word (str): The word to look up.
Returns:
Optional[np.ndarray]: The vector representation of the word, or None if not found.
"""
return self.vocab.get(word, {}).get("vec")

def get_context_tokens(
self, entity: Span, doc: Doc, size: int
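A note on the new helper: get_word_vector funnels every vocabulary lookup through a single None-safe path. A minimal standalone sketch, assuming the vocab behaves like a dict of word -> {"vec": ndarray} records (toy values, not the real MedCAT vocab):

import numpy as np
from typing import Optional

vocab = {"diabetes": {"vec": np.array([0.1, 0.9])}}  # toy stand-in

def get_word_vector(word: str) -> Optional[np.ndarray]:
    # Unknown words fall through to None instead of raising KeyError
    return vocab.get(word, {}).get("vec")

print(get_word_vector("diabetes"))  # [0.1 0.9]
print(get_word_vector("unknown"))   # None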
@@ -85,7 +96,9 @@ def get_context_vectors(
Dict[str, np.ndarray]: The context vectors.
"""
vectors = {}
- for context_type, size in self.config.linking["context_vector_sizes"].items():
+ for context_type, size in self.config["linking"][
+ "context_vector_sizes"
+ ].items():
tokens_left, tokens_center, tokens_right = self.get_context_tokens(
entity, doc, size
)
@@ -117,9 +130,9 @@ def _get_token_vectors(
range(len(tokens)) if not reverse else range(len(tokens) - 1, -1, -1)
)
vectors = [
- self.cdb.weighted_average_function(step) * self.vocab.vec(t.lower_)
+ self.cdb["weighted_average_function"](step) * vec
for step, t in zip(step_range, tokens)
- if t.lower_ in self.vocab and self.vocab.vec(t.lower_) is not None
+ if (vec := self.get_word_vector(t.lower_)) is not None
]
logger.debug(f"Generated {len(vectors)} token vectors")
return vectors
@@ -137,29 +150,28 @@ def _get_center_token_vectors(
Returns:
List[np.ndarray]: The list of center token vectors.
"""
if self.config.linking["context_ignore_center_tokens"]:
if self.config["linking"]["context_ignore_center_tokens"]:
logger.debug("Ignoring center tokens")
return []

if (
cui
- and random.random() > self.config.linking["random_replacement_unsupervised"]
- and self.cdb.cui2names.get(cui)
+ and random.random()
+ > self.config["linking"]["random_replacement_unsupervised"]
+ and self.cdb["cui2names"].get(cui)
):
- new_tokens = random.choice(list(self.cdb.cui2names[cui])).split(
- self.config.general["separator"]
+ new_tokens = random.choice(list(self.cdb["cui2names"][cui])).split(
+ self.config["general"]["separator"]
)
vectors = [
- self.vocab.vec(t)
- for t in new_tokens
- if t in self.vocab and self.vocab.vec(t) is not None
+ vec for t in new_tokens if (vec := self.get_word_vector(t)) is not None
]
logger.debug(f"Using {len(vectors)} CUI-based center token vectors")
else:
vectors = [
- self.vocab.vec(t.lower_)
+ vec
for t in tokens
- if t.lower_ in self.vocab and self.vocab.vec(t.lower_) is not None
+ if (vec := self.get_word_vector(t.lower_)) is not None
]
logger.debug(f"Using {len(vectors)} original center token vectors")

@@ -193,19 +205,19 @@ def _similarity(self, cui: str, vectors: Dict[str, np.ndarray]) -> float:
Returns:
float: The similarity score.
"""
- cui_vectors = self.cdb.cui2context_vectors.get(cui, {})
+ cui_vectors = self.cdb["cui2context_vectors"].get(cui, {})
if (
not cui_vectors
- or self.cdb.cui2count_train[cui]
- < self.config.linking["train_count_threshold"]
+ or self.cdb["cui2count_train"][cui]
+ < self.config["linking"]["train_count_threshold"]
):
logger.debug(f"Insufficient training data for CUI {cui}")
return -1

similarity = sum(
self.config.linking["context_vector_weights"][context_type]
self.config["linking"]["context_vector_weights"][context_type]
* np.dot(unitvec(vectors[context_type]), unitvec(cui_vectors[context_type]))
- for context_type in self.config.linking["context_vector_weights"]
+ for context_type in self.config["linking"]["context_vector_weights"]
if context_type in vectors and context_type in cui_vectors
)
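The score is a weighted sum of cosine similarities, one term per context size, between the entity's context vectors and the concept's stored vectors. A self-contained sketch with toy weights and vectors, taking unitvec to be plain L2 normalization:

import numpy as np

def unitvec(v: np.ndarray) -> np.ndarray:
    return v / np.linalg.norm(v)

weights = {"short": 0.4, "long": 0.6}  # toy context_vector_weights
entity_vecs = {"short": np.array([1.0, 0.0]), "long": np.array([1.0, 1.0])}
cui_vecs = {"short": np.array([1.0, 0.2]), "long": np.array([0.9, 1.1])}

similarity = sum(
    weights[ct] * np.dot(unitvec(entity_vecs[ct]), unitvec(cui_vecs[ct]))
    for ct in weights
    if ct in entity_vecs and ct in cui_vecs
)
print(round(similarity, 3))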

@@ -228,23 +240,23 @@ def disambiguate(
Tuple[Optional[str], float]: The selected CUI and its similarity score.
"""
vectors = self.get_context_vectors(entity, doc)
filters = self.config.linking["filters"]
# filters = self.config["linking"]["filters"]

if self.config.linking["filter_before_disamb"]:
cuis = [cui for cui in cuis if filters.check_filters(cui)]
logger.debug(f"Filtered CUIs: {cuis}")
# if self.config["linking"]["filter_before_disamb"]:
# cuis = [cui for cui in cuis if filters.check_filters(cui)]
# logger.debug(f"Filtered CUIs: {cuis}")

- if not cuis:
- logger.debug("No CUIs left after filtering")
- return None, 0
+ # if not cuis:
+ # logger.debug("No CUIs left after filtering")
+ # return None, 0

similarities = [self._similarity(cui, vectors) for cui in cuis]
logger.debug(f"Initial similarities: {list(zip(cuis, similarities))}")

if self.config.linking.get("prefer_primary_name", 0) > 0:
if self.config["linking"].get("prefer_primary_name", 0) > 0:
self._adjust_primary_name_similarities(cuis, similarities, name)

if self.config.linking.get("prefer_frequent_concepts", 0) > 0:
if self.config["linking"].get("prefer_frequent_concepts", 0) > 0:
self._adjust_frequent_concept_similarities(cuis, similarities)

max_index = np.argmax(similarities)
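After the optional preference adjustments, the winner is simply the argmax over the adjusted similarities; under-trained concepts carry the sentinel score of -1 and lose automatically. In miniature:

import numpy as np

cuis = ["C001", "C002", "C003"]
similarities = [0.42, 0.71, -1.0]  # -1 marks an under-trained concept

max_index = int(np.argmax(similarities))
print(cuis[max_index], similarities[max_index])  # C002 0.71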
@@ -265,14 +277,14 @@ def _adjust_primary_name_similarities(
name (str): The name of the entity.
"""
for i, cui in enumerate(cuis):
- if similarities[i] > 0 and self.cdb.name2cuis2status.get(name, {}).get(
+ if similarities[i] > 0 and self.cdb["name2cuis2status"].get(name, {}).get(
cui, ""
) in {"P", "PD"}:
old_sim = similarities[i]
similarities[i] = min(
0.99,
similarities[i]
- * (1 + self.config.linking.get("prefer_primary_name", 0)),
+ * (1 + self.config["linking"].get("prefer_primary_name", 0)),
)
logger.debug(
f"Adjusted similarity for primary name {name}, CUI {cui}: {old_sim} -> {similarities[i]}"
@@ -288,11 +300,11 @@ def _adjust_frequent_concept_similarities(
cuis (List[str]): The list of candidate CUIs.
similarities (List[float]): The list of similarity scores.
"""
- counts = [self.cdb.cui2count_train.get(cui, 0) for cui in cuis]
+ counts = [self.cdb["cui2count_train"].get(cui, 0) for cui in cuis]
min_count = max(min(counts), 1)
scales = [
np.log10(count / min_count)
* self.config.linking.get("prefer_frequent_concepts", 0)
* self.config["linking"].get("prefer_frequent_concepts", 0)
if count > 10
else 0
for count in counts
@@ -305,17 +317,13 @@
)
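prefer_frequent_concepts adds a boost that grows with the log of each concept's training count relative to the rarest candidate, and concepts seen ten times or fewer get no boost at all. Worked numbers, assuming a preference factor of 0.1:

import numpy as np

counts = [1000, 100, 5]          # cui2count_train per candidate
prefer = 0.1                     # prefer_frequent_concepts
min_count = max(min(counts), 1)  # 5
scales = [
    np.log10(count / min_count) * prefer if count > 10 else 0
    for count in counts
]
print([round(s, 3) for s in scales])  # [0.23, 0.13, 0]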


@Language.factory("medcat_linker")
def create_linker(
nlp: Language, name: str, cdb: CDB, vocab: Vocab, config: Dict[str, Any]
):
return Linker(nlp, name, cdb, vocab, config)
@Language.factory("medcatlite_linker")
def create_linker(nlp: Language, name: str, linker_resources: Dict[str, Any]):
return Linker(nlp, name, linker_resources)
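Collapsing cdb, vocab, and config into a single linker_resources argument keeps the factory signature serializable: the heavy objects are produced by registered functions instead of being passed through the spaCy config. A guess at what the registered constructor in registry.py might look like; its actual body is not part of this diff:

import spacy
from typing import Any, Dict

@spacy.registry.misc("medcatlite.linker_resources")
def create_linker_resources(cdb: Any, vocab: Any) -> Dict[str, Any]:
    # Hypothetical body: bundle everything the Linker needs into one dict
    return {"cdb": cdb, "vocab": vocab, "config": cdb["config"]}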


class Linker:
- def __init__(
- self, nlp: Language, name: str, cdb: CDB, vocab: Vocab, config: Dict[str, Any]
- ):
+ def __init__(self, nlp: Language, name: str, linker_resources: Dict[str, Any]):
"""
Initialize the Linker class.
@@ -327,9 +335,9 @@ def __init__(
config (Dict[str, Any]): The configuration dictionary.
"""
self.name = name
- self.cdb = cdb
- self.vocab = vocab
- self.config = config
+ self.cdb = linker_resources["cdb"]
+ self.vocab = linker_resources["vocab"]
+ self.config = linker_resources["config"]
self.context_model = ContextModel(self.cdb, self.vocab, self.config)
self._setup_extensions()

@@ -338,12 +346,9 @@ def _setup_extensions(self):
Set up custom extensions for Doc and Span objects.
"""
custom_extensions = {
Doc: [("ents", [])],
Span: [
("detected_name", None),
("link_candidates", []),
("cui", None),
("context_similarity", None),
("cui", -1),
("context_similarity", -1),
],
}
for obj, extensions in custom_extensions.items():
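Note the defaults for cui and context_similarity change from None to -1, keeping the attributes numeric even before linking runs. The elided loop body presumably registers each pair along these lines:

from spacy.tokens import Span

for attr, default in [("cui", -1), ("context_similarity", -1)]:
    if not Span.has_extension(attr):
        Span.set_extension(attr, default=default)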
@@ -426,7 +431,7 @@ def _should_disambiguate(self, name: str, cuis: List[str]) -> bool:
len(name) < cnf_l["disamb_length_limit"]
or (
len(cuis) == 1
- and self.cdb.name2cuis2status[name][cuis[0]] in {"PD", "N"}
+ and self.cdb["name2cuis2status"][name][cuis[0]] in {"PD", "N"}
)
or len(cuis) > 1
)
@@ -460,7 +465,7 @@ def _is_valid_entity(self, cui: Optional[str], context_similarity: float) -> bool:
"""
return (
cui is not None
- and self._check_filters(cui)
+ # and self._check_filters(cui)
and self._check_similarity_threshold(cui, context_similarity)
)

@@ -494,7 +499,8 @@ def _check_similarity_threshold(self, cui: str, context_similarity: float) -> bool:
return context_similarity >= threshold
if th_type == "dynamic":
return (
- context_similarity >= self.cdb.cui2average_confidence[cui] * threshold
+ context_similarity
+ >= self.cdb["cui2average_confidence"][cui] * threshold
)

logger.warning(f"Unknown similarity threshold type: {th_type}")
Expand Down Expand Up @@ -523,15 +529,33 @@ def _post_process(self, doc: Doc):
self._apply_pretty_labels(doc)
self._map_entities_to_groups(doc)

- def _create_main_ann(self, doc: Doc):
+ def _create_main_ann(self, doc: Doc, tuis: Optional[List[str]] = None):
"""
Create the main annotation for the document.
Args:
doc (Doc): The spaCy document object.
+ tuis (Optional[List[str]]): List of Type Unique Identifiers to filter entities.
+ Returns:
+ None: Modifies the doc.ents in-place.
"""
- # Implement main annotation creation logic here
- pass
+ # Sort entities by length (longest first) to prioritize longer matches
+ doc._.ents.sort(key=lambda x: len(x.text), reverse=True)
+
+ tokens_covered = set()
+ main_annotations = []
+
+ for entity in doc._.ents:
+ if tuis is None or entity._.tui in tuis:
+ # Check if any token in the entity is already covered
+ if not any(token in tokens_covered for token in entity):
+ # Add all tokens of this entity to the covered set
+ tokens_covered.update(entity)
+ main_annotations.append(entity)
+
+ # Update doc.ents with the new main annotations
+ doc.ents = list(doc.ents) + main_annotations
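The new _create_main_ann resolves overlaps greedily: candidates are sorted longest-first and an entity is kept only if none of its tokens have been claimed, so "type 2 diabetes" beats a nested "diabetes". The token-coverage idea in isolation, with hypothetical (text, start, end) spans:

spans = [("type 2 diabetes", 0, 3), ("diabetes", 2, 3), ("fever", 5, 6)]
spans.sort(key=lambda s: len(s[0]), reverse=True)

covered = set()
kept = []
for text, start, end in spans:
    if not any(i in covered for i in range(start, end)):
        covered.update(range(start, end))
        kept.append(text)
print(kept)  # ['type 2 diabetes', 'fever']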

def _apply_pretty_labels(self, doc: Doc):
"""
@@ -562,17 +586,8 @@ def _map_entities_to_groups(self, doc: Doc):
Args:
doc (Doc): The spaCy document object.
"""
if self.config["general"]["map_cui_to_group"] and self.cdb.addl_info.get(
if self.config["general"]["map_cui_to_group"] and self.cdb["addl_info"].get(
"cui2group"
):
- self._map_ents_to_groups(doc)
-
- def _map_ents_to_groups(self, doc: Doc):
- """
- Map entities to groups based on CUI.
- Args:
- doc (Doc): The spaCy document object.
- """
- # Implement entity to group mapping logic here
- pass
+ for ent in doc.ents:
+ ent._.cui = self.cdb["addl_info"]["cui2group"].get(ent._.cui, ent._.cui)
22 changes: 18 additions & 4 deletions healthchain/pipeline/models/medcatlite/medcatlite.py
@@ -22,7 +22,12 @@
# ruff: noqa
from .tokenprocessor import TokenProcessor
from .ner import NER
- from .registry import create_token_processor_resources, create_ner_resources
+ from .linker import Linker
+ from .registry import (
+ create_token_processor_resources,
+ create_ner_resources,
+ create_linker_resources,
+ )


logger = logging.getLogger(__name__)
@@ -39,7 +44,7 @@ def __init__(
self.cdb = cdb
self.vocab = vocab
self.nlp = None
- self.create_pipeline()
+ self._create_pipeline()

@classmethod
def load_model_pack(cls, model_path: str):
@@ -64,7 +69,7 @@ def load_model_pack(cls, model_path: str):

return cls(cdb, vocab, config)

- def create_pipeline(self) -> Language:
+ def _create_pipeline(self) -> Language:
if self.config is None:
raise ValueError("Config not loaded. Call load_model_pack() first.")

@@ -109,7 +114,16 @@ def create_pipeline(self) -> Language:
},
)

- # self.nlp.add_pipe('medcat_linker', config={'cdb': self.cdb, 'vocab': self.vocab, 'config': self.config})
+ self.nlp.add_pipe(
+ "medcatlite_linker",
+ config={
+ "linker_resources": {
+ "@misc": "medcatlite.linker_resources",
+ "cdb": {"@misc": "medcatlite_cdb"},
+ "vocab": {"@misc": "medcatlite_vocab"},
+ }
+ },
+ )

return self.nlp
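For orientation: spaCy resolves the nested "@misc" blocks before the factory runs, so medcatlite_linker receives a fully built linker_resources dict. Roughly what happens under the hood, assuming medcatlite_cdb and medcatlite_vocab are registered elsewhere in the package:

from spacy.util import registry

config = {
    "resources": {
        "@misc": "medcatlite.linker_resources",
        "cdb": {"@misc": "medcatlite_cdb"},
        "vocab": {"@misc": "medcatlite_vocab"},
    }
}
# Registered functions are called bottom-up and their return values
# substituted, so the factory never sees the "@misc" references
resolved = registry.resolve(config)
linker_resources = resolved["resources"]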

Diffs for the remaining 3 changed files are not shown.