From 5a8aa663850f0d528426003f2ff5a92e8e8f1e1f Mon Sep 17 00:00:00 2001 From: Wessel Date: Mon, 20 Jul 2020 20:50:33 +0200 Subject: [PATCH 1/3] Bugfix unicode compatibility --- codes/run.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/codes/run.py b/codes/run.py index 457c6fdf..e9d5b832 100644 --- a/codes/run.py +++ b/codes/run.py @@ -9,6 +9,7 @@ import logging import os import random +import unicodedata import numpy as np import torch @@ -123,7 +124,7 @@ def read_triple(file_path, entity2id, relation2id): triples = [] with open(file_path) as fin: for line in fin: - h, r, t = line.strip().split('\t') + h, r, t = map(lambda x: x.strip(), unicodedata.normalize('NFKC', line).split('\t')) triples.append((entity2id[h], relation2id[r], entity2id[t])) return triples @@ -160,12 +161,12 @@ def log_metrics(mode, step, metrics): def main(args): if (not args.do_train) and (not args.do_valid) and (not args.do_test): - raise ValueError('one of train/val/test mode must be choosed.') + raise ValueError('one of train/val/test mode must be chosen.') if args.init_checkpoint: override_config(args) elif args.data_path is None: - raise ValueError('one of init_checkpoint/data_path must be choosed.') + raise ValueError('one of init_checkpoint/data_path must be chosen.') if args.do_train and args.save_path is None: raise ValueError('Where do you want to save your trained model?') @@ -179,13 +180,13 @@ def main(args): with open(os.path.join(args.data_path, 'entities.dict')) as fin: entity2id = dict() for line in fin: - eid, entity = line.strip().split('\t') + eid, entity = map(lambda x: x.strip(), unicodedata.normalize('NFKC', line).split('\t')) entity2id[entity] = int(eid) with open(os.path.join(args.data_path, 'relations.dict')) as fin: relation2id = dict() for line in fin: - rid, relation = line.strip().split('\t') + rid, relation = map(lambda x: x.strip(), unicodedata.normalize('NFKC', line).split('\t')) relation2id[relation] = int(rid) # Read regions for Countries S* datasets From 20c37cdb134642a51e582386fb624239bbc4a180 Mon Sep 17 00:00:00 2001 From: Wessel Date: Mon, 20 Jul 2020 20:56:32 +0200 Subject: [PATCH 2/3] Fix indentation mistake --- codes/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codes/run.py b/codes/run.py index e9d5b832..8c827a26 100644 --- a/codes/run.py +++ b/codes/run.py @@ -124,7 +124,7 @@ def read_triple(file_path, entity2id, relation2id): triples = [] with open(file_path) as fin: for line in fin: - h, r, t = map(lambda x: x.strip(), unicodedata.normalize('NFKC', line).split('\t')) + h, r, t = map(lambda x: x.strip(), unicodedata.normalize('NFKC', line).split('\t')) triples.append((entity2id[h], relation2id[r], entity2id[t])) return triples From 325bee0dad4684d3231c0f68ae83a7cbbbed3b85 Mon Sep 17 00:00:00 2001 From: Wessel Date: Thu, 23 Jul 2020 10:26:45 +0200 Subject: [PATCH 3/3] Fix root cause of trailing space key issue I mistakenly thought I was dealing with a unicode issue due to the error I received. Upon investigating closer I realized that when the entity/relations are loaded and the entire line is split, trailing spaces are removed because the names are at the end. However, when loading triples this only occurs on the tail entities. Fixed by mapping str.split() on all components when loading triples. --- codes/run.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/codes/run.py b/codes/run.py index 8c827a26..6c1ec67a 100644 --- a/codes/run.py +++ b/codes/run.py @@ -9,7 +9,6 @@ import logging import os import random -import unicodedata import numpy as np import torch @@ -124,7 +123,10 @@ def read_triple(file_path, entity2id, relation2id): triples = [] with open(file_path) as fin: for line in fin: - h, r, t = map(lambda x: x.strip(), unicodedata.normalize('NFKC', line).split('\t')) + # The entity/relation dict have the element names at the end, when loading them + # the entire line is stripped, removing any trailing spaces. As such, we need + # to strip each element individually here as well. + h, r, t = map(str.strip, line.split('\t')) triples.append((entity2id[h], relation2id[r], entity2id[t])) return triples @@ -180,13 +182,13 @@ def main(args): with open(os.path.join(args.data_path, 'entities.dict')) as fin: entity2id = dict() for line in fin: - eid, entity = map(lambda x: x.strip(), unicodedata.normalize('NFKC', line).split('\t')) + eid, entity = line.strip().split('\t') entity2id[entity] = int(eid) with open(os.path.join(args.data_path, 'relations.dict')) as fin: relation2id = dict() for line in fin: - rid, relation = map(lambda x: x.strip(), unicodedata.normalize('NFKC', line).split('\t')) + rid, relation = line.strip().split('\t') relation2id[relation] = int(rid) # Read regions for Countries S* datasets