Refactor using the Mixin design pattern: slim down the main file and reorganize the code hierarchy for better readability
1 parent 642320a, commit bdd1098
Showing 19 changed files with 808 additions and 721 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -0,0 +1,123 @@
import networkx as nx
from itertools import combinations


class EntNetworkMixin:
    """
    Entity network module:
    - uses entity co-occurrence across documents
    - builds a global social network
    - builds a social network centered on a given entity
    """
    def build_entity_graph(self, docs, min_freq=0, inv_index={}, used_types=[]):
        G = nx.Graph()
        links = {}
        if len(inv_index) == 0:
            for i, sent in enumerate(docs):
                entities_info = self.entity_linking(sent)
                if len(used_types) == 0:
                    entities = set(entity for span, (entity, type0) in entities_info)
                else:
                    # type0 comes wrapped in delimiter characters; compare the bare type name
                    entities = set(entity for span, (entity, type0) in entities_info if type0[1:-1] in used_types)
                for u, v in combinations(entities, 2):
                    pair0 = tuple(sorted((u, v)))
                    if pair0 not in links:
                        links[pair0] = 1
                    else:
                        links[pair0] += 1
        else:  # an inverted index is available, so lookups are faster
            if len(used_types) == 0:
                entities = self.entity_type_dict.keys()
            else:
                entities = iter(entity for (entity, type0) in self.entity_type_dict.items() if type0 in used_types)
            for u, v in combinations(entities, 2):
                pair0 = tuple(sorted((u, v)))
                ids = inv_index[u] & inv_index[v]
                if len(ids) > 0:
                    links[pair0] = len(ids)
        for (u, v) in links:
            if links[(u, v)] >= min_freq:
                G.add_edge(u, v, weight=links[(u, v)])
        self.entity_graph = G
        return G

    def build_word_ego_graph(self, docs, word, standard_name=True, min_freq=0, other_min_freq=-1, stopwords=None):
        '''Build a graph of word relations centered on a given keyword, from the texts.
        The keyword can denote a specific aspect (e.g. food, clothing, housing, or transport
        in such documents), so the ego graph gives a quick summary of that aspect.

        :param docs: list of texts
        :param word: the central keyword
        :param standard_name: normalize all entity mentions to their standard entity names
        :param stopwords: stopwords to filter out
        :param min_freq: minimum co-occurrence count with the central word for an edge to enter the graph, used to prune an excessive number of edges
        :param other_min_freq: minimum co-occurrence count for relations between non-central words
        :return: G (a networkx Graph)
        '''
        G = nx.Graph()
        links = {}
        if other_min_freq == -1:
            other_min_freq = min_freq
        for doc in docs:
            if stopwords:
                words = set(x for x in self.seg(doc, standard_name=standard_name) if x not in stopwords)
            else:
                words = self.seg(doc, standard_name=standard_name)
            if word in words:
                for u, v in combinations(words, 2):
                    pair0 = tuple(sorted((u, v)))
                    if pair0 not in links:
                        links[pair0] = 1
                    else:
                        links[pair0] += 1

        used_nodes = set([word])  # words in a relation pair must be related to the central word (>= min_freq co-occurrences)
        for (u, v) in links:
            w = links[(u, v)]
            if word in (u, v) and w >= min_freq:
                used_nodes.add(v if word == u else u)
                G.add_edge(u, v, weight=w)
            elif w >= other_min_freq:
                G.add_edge(u, v, weight=w)
        G = G.subgraph(used_nodes).copy()
        return G

    def build_entity_ego_graph(self, docs, word, min_freq=0, other_min_freq=-1, inv_index={}, used_types=[]):
        '''Entity-only version of build_word_ego_graph()
        '''
        G = nx.Graph()
        links = {}
        if other_min_freq == -1:
            other_min_freq = min_freq
        if len(inv_index) != 0:
            related_docs = self.search_entity(word, docs, inv_index)
        else:
            related_docs = []
            for doc in docs:
                entities_info = self.entity_linking(doc)
                entities = [entity0 for [[l, r], (entity0, type0)] in entities_info]
                if word in entities:
                    related_docs.append(doc)

        for i, sent in enumerate(related_docs):
            entities_info = self.entity_linking(sent)
            if len(used_types) == 0:
                entities = set(entity for span, (entity, type0) in entities_info)
            else:
                entities = set(entity for span, (entity, type0) in entities_info if type0[1:-1] in used_types)
            for u, v in combinations(entities, 2):
                pair0 = tuple(sorted((u, v)))
                if pair0 not in links:
                    links[pair0] = 1
                else:
                    links[pair0] += 1

        used_nodes = set([word])  # entities in a relation pair must be related to the central entity (>= min_freq co-occurrences)
        for (u, v) in links:
            w = links[(u, v)]
            if word in (u, v) and w >= min_freq:
                used_nodes.add(v if word == u else u)
                G.add_edge(u, v, weight=w)
            elif w >= other_min_freq:
                G.add_edge(u, v, weight=w)
        G = G.subgraph(used_nodes).copy()
        return G
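A short usage sketch of the module above. It assumes the mixin is consumed through the project's main class (here assumed to be harvesttext.HarvestText, with an add_entities(entity_mention_dict, entity_type_dict) setup method; both names are assumptions, not shown in this diff):

# Hypothetical usage; import path and add_entities signature are assumed.
from harvesttext import HarvestText

ht = HarvestText()
ht.add_entities({"刘备": ["刘备", "玄德"], "关羽": ["关羽", "云长"]},
                {"刘备": "人名", "关羽": "人名"})
docs = ["刘备和关羽桃园结义", "玄德与云长情同兄弟"]

G = ht.build_entity_graph(docs, min_freq=1)        # global co-occurrence network
G_ego = ht.build_entity_ego_graph(docs, "刘备")     # network centered on 刘备
print(G.edges(data=True), G_ego.edges(data=True))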
@@ -0,0 +1,38 @@
import numpy as np
from collections import defaultdict


class EntRetrieveMixin:
    """
    Entity retrieval module:
    - uses an inverted index to quickly retrieve the documents that contain a given
      entity, and to count the number of documents in which an entity appears
    """
    def build_index(self, docs, with_entity=True, with_type=True):
        inv_index = defaultdict(set)
        for i, sent in enumerate(docs):
            entities_info = self.entity_linking(sent)
            for span, (entity, type0) in entities_info:
                if with_entity:
                    inv_index[entity].add(i)
                if with_type:
                    inv_index[type0].add(i)
        return inv_index

    def get_entity_counts(self, docs, inv_index, used_type=[]):
        if len(used_type) > 0:
            entities = iter(x for x in self.entity_type_dict
                            if self.entity_type_dict[x] in used_type)
        else:
            entities = self.entity_type_dict.keys()
        cnt = {enty: len(inv_index[enty]) for enty in entities if enty in inv_index}
        return cnt

    def search_entity(self, query, docs, inv_index):
        words = query.split()
        if len(words) > 0:
            ids = inv_index[words[0]]
            for word in words[1:]:
                ids = ids & inv_index[word]  # intersect: keep only docs containing every query word
            np_docs = np.array(docs)[list(ids)]
            return np_docs.tolist()
        else:
            return []
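Continuing the sketch above for the retrieval module, under the same assumptions about the composed class. Note that the returned inv_index is the same structure that build_entity_graph and build_entity_ego_graph accept for faster lookups:

# Hypothetical usage; `ht` and `docs` continue from the previous sketch.
inv_index = ht.build_index(docs)                  # entity/type -> set of doc ids
counts = ht.get_entity_counts(docs, inv_index)    # entity -> number of docs it appears in
hits = ht.search_entity("刘备", docs, inv_index)   # docs mentioning 刘备

# The same index also speeds up graph construction:
G = ht.build_entity_graph(docs, inv_index=inv_index)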