Skip to content

Commit

Permalink
Handle empty candidates in Predictions dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
thobson88 committed Dec 20, 2024
1 parent 6d7d48c commit 0fe401d
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 4 deletions.
11 changes: 7 additions & 4 deletions t_res/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,10 @@ def candidates_str(self, candidates: MentionCandidates, pad_mention: int=0, pad_
s += f"None"
return s

def candidates(self, ignore_empty_candidates: bool = True) -> List["MentionCandidates"]:
    """Return the flat list of MentionCandidates across all sentences.

    Args:
        ignore_empty_candidates: if True (the default), omit any
            MentionCandidates for which no candidate links were found
            (i.e. those where ``c.is_empty()`` is True). If False,
            return every MentionCandidates instance, empty or not.

    Returns:
        MentionCandidates instances in sentence order.

    NOTE(review): callers that need positional alignment with per-mention
    data (e.g. RelPredictions zipping against ``rel_scores``) must pass
    ``ignore_empty_candidates=False`` to get the unfiltered list.
    """
    if ignore_empty_candidates:
        # Drop mentions that carry no candidate links.
        return [
            c
            for sc in self.sentence_candidates
            for c in sc.candidates
            if not c.is_empty()
        ]
    return [c for sc in self.sentence_candidates for c in sc.candidates]

def is_empty(self, ignore_empty_candidates: bool=True) -> bool:
if ignore_empty_candidates:
Expand Down Expand Up @@ -785,13 +787,14 @@ def __post_init__(self):
Got {len(self.rel_scores)} instances and {len(super().candidates())} mentions.""")

# Override the candidates method to return REL linking predictions.
def candidates(self) -> List[MentionCandidates]:
def candidates(self, ignore_empty_candidates: bool=True) -> List[MentionCandidates]:

# Construct equivalent Candidate instances but with the REL scores in the PredictedLinks.
ret = list()
for c, rs in zip(super().candidates(), self.rel_scores):
if c.is_empty():
ret.append(c)
if not ignore_empty_candidates:
ret.append(c)
continue
predicted_links = c.best_match()
# Get the list of WikidataLink instances for which REL scores are available.
Expand Down
106 changes: 106 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,112 @@ def test_deezy_rel_wpubl_wmtops(tmp_path):
# assert predictions.candidates()[0].best_match().best_disambiguation_score() == 0.039 # TODO: reproduce this number.
assert predictions.candidates()[0].mention.ner_score == 1.0

@pytest.mark.resources(reason="Needs large resources")
def test_deezy_rel_wpubl(tmp_path):
    """End-to-end pipeline test: DeezyMatch ranking + REL disambiguation,
    with publication info.

    Runs the same text twice — first with the linker's
    "without_microtoponyms" parameter set to False, then to True — and
    checks that microtoponym mentions lose their candidate links in the
    second run, while still being present as empty MentionCandidates when
    requested with ignore_empty_candidates=False.
    """
    # The fine-tuned NER model must already exist on disk
    # (overwrite_training=False below means it is never retrained here).
    model_path = os.path.join(current_dir, "../resources/models/")
    assert os.path.isdir(model_path) is True

    recogniser = ner.CustomRecogniser(
        model_name="blb_lwm-ner-fine",
        train_dataset=os.path.join(current_dir,"sample_files/experiments/outputs/data/lwm/ner_fine_train.json"),
        test_dataset=os.path.join(current_dir,"sample_files/experiments/outputs/data/lwm/ner_fine_dev.json"),
        pipe=None,
        base_model="khosseini/bert_1760_1900",  # Base model to fine-tune
        model_path=model_path,
        training_args={
            "batch_size": 8,
            "num_train_epochs": 10,
            "learning_rate": 0.00005,
            "weight_decay": 0.0,
        },
        overwrite_training=False,  # Set to True if you want to overwrite model if existing
        do_test=False,  # Set to True if you want to train on test mode
    )

    # --------------------------------------
    # Instantiate the ranker:
    ranker = ranking.DeezyMatchRanker(
        resources_path=os.path.join(current_dir, "../resources/"),
        mentions_to_wikidata=dict(),
        wikidata_to_mentions=dict(),
        strvar_parameters={
            # Parameters to create the string pair dataset:
            "ocr_threshold": 60,
            "top_threshold": 85,
            "min_len": 5,
            "max_len": 15,
            "w2v_ocr_path": str(tmp_path),
            "w2v_ocr_model": "w2v_1800s_news",
            "overwrite_dataset": False,
        },
        deezy_parameters={
            # Paths and filenames of DeezyMatch models and data:
            "dm_path": os.path.join(current_dir,"../resources/deezymatch/"),
            "dm_cands": "wkdtalts",
            "dm_model": "w2v_ocr",
            "dm_output": "deezymatch_on_the_fly",
            # Ranking measures:
            "ranking_metric": "faiss",
            "selection_threshold": 50,
            "num_candidates": 1,
            "verbose": False,
            # DeezyMatch training:
            "overwrite_training": False,
            "do_test": False,
        },
    )

    # The REL disambiguator reads entity embeddings through this cursor, so
    # keep the connection open for the whole run.
    with sqlite3.connect(os.path.join(current_dir, "../resources/rel_db/embeddings_database.db")) as conn:
        cursor = conn.cursor()
        linker = linking.RelDisambLinker(
            resources_path=os.path.join(current_dir, "../resources/"),
            ranker=ranker,
            linking_resources=dict(),
            rel_params={
                "model_path": os.path.join(current_dir,"../resources/models/disambiguation/"),
                "data_path": os.path.join(current_dir,"sample_files/experiments/outputs/data/lwm/"),
                "training_split": "originalsplit",
                "db_embeddings": cursor,
                "with_publication": True,
                "without_microtoponyms": False,
                "do_test": False,
                "default_publname": "United Kingdom",
                "default_publwqid": "Q145",
            },
            overwrite_training=False,
        )

        geoparser = pipeline.Pipeline(recogniser=recogniser, ranker=ranker, linker=linker)
        text = "The charming seaside town of Swanage is noted for its Town Hall whose distinctive façade was designed by Edward Jerman, a pupil of Sir Christopher Wren. Also the Grosvenor Hotel, with its clock tower originally erected at the south end of London Bridge as a memorial to the Duke of Wellington."

        # Test with microtoponyms.
        predictions = geoparser.run(text, place_of_pub_wqid="Q203349", place_of_pub="Poole, Dorset")
        assert isinstance(predictions, RelPredictions)

        # With the "without_microtoponyms" parameter set to False, there are four candidates:
        assert len(predictions.candidates()) == 4
        assert predictions.candidates()[0].mention.mention == "Swanage"
        assert predictions.candidates()[1].mention.mention == "Town Hall"
        assert predictions.candidates()[2].mention.mention == "Grosvenor Hotel"
        assert predictions.candidates()[3].mention.mention == "London Bridge"

        # Test without microtoponyms.
        geoparser.linker.rel_params["without_microtoponyms"] = True
        predictions = geoparser.run(text, place_of_pub_wqid="Q203349", place_of_pub="Poole, Dorset")
        assert isinstance(predictions, RelPredictions)

        # With the "without_microtoponyms" parameter set to True, only one candidate remains:
        assert len(predictions.candidates()) == 1
        assert predictions.candidates()[0].mention.mention == "Swanage"

        # The MentionCandidates for the microtoponyms still exist, but they are empty
        # (i.e. contain no candidate links) because the Linker was configured to ignore them.
        assert len(predictions.candidates(ignore_empty_candidates=False)) == 4
        assert predictions.candidates(ignore_empty_candidates=False)[0].mention.mention == "Swanage"
        assert predictions.candidates(ignore_empty_candidates=False)[1].mention.mention == "Town Hall"
        assert predictions.candidates(ignore_empty_candidates=False)[2].mention.mention == "Grosvenor Hotel"
        assert predictions.candidates(ignore_empty_candidates=False)[3].mention.mention == "London Bridge"

@pytest.mark.resources(reason="Needs large resources")
def test_perfect_rel_wpubl_wmtops():
model_path = os.path.join(current_dir, "../resources/models/")
Expand Down

0 comments on commit 0fe401d

Please sign in to comment.