Showing 3 changed files with 217 additions and 0 deletions.
@@ -0,0 +1,67 @@
from itertools import chain

import torch
from torch.distributions import Categorical
from transformers import AutoTokenizer

from data_juicer.format import load_formatter


class TextTokenDistCollector(object):
    """Tokenize a given dataset with a specified tokenizer and collect
    the distribution of its tokens.
    """

    def __init__(self, tokenizer):
        """
        Initialization method.

        :param tokenizer: tokenizer name on HuggingFace.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer,
                                                       trust_remote_code=True)
        self.vocab_size = len(self.tokenizer)

    def collect(self, data_path, text_key, num_proc=1) -> 'Categorical':
        """
        Tokenize the input dataset and collect its token distribution.

        :param data_path: path to the input dataset.
        :param text_key: field key of the text to tokenize and count.
        :param num_proc: number of processes used to count tokens.
        :return: token distribution as a `Categorical`.
        """
        formatter = load_formatter(data_path)
        dataset = formatter.load_dataset(num_proc=num_proc)
        assert text_key in dataset.features, \
            f'[{text_key}] not found in dataset'

        def prepare_tokenizer(tokenizer, text_key):
            """
            Prepare a tokenization function for the dataset.

            :param tokenizer: the tokenizer used to tokenize each sample.
            :param text_key: field key of the text to tokenize and count.
            """

            def _tokenize_fn(example):
                return tokenizer(example[text_key],
                                 add_special_tokens=False)

            return _tokenize_fn

        tokenize_proc = prepare_tokenizer(self.tokenizer, text_key)
        dataset = dataset.map(tokenize_proc,
                              num_proc=num_proc,
                              desc=f'tokenize {data_path.split("/")[-1]}')

        # count occurrences of each token id over the whole dataset;
        # Categorical normalizes the raw counts into a distribution
        token_count = torch.zeros(self.vocab_size, dtype=torch.int64)
        token_ids = torch.tensor(
            list(chain.from_iterable(dataset['input_ids'])))
        indices, counts = token_ids.unique(return_counts=True)
        token_count.scatter_(0, indices, counts.to(token_count.dtype))
        dist = Categorical(token_count)
        return dist
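
A minimal usage sketch of `TextTokenDistCollector`: the dataset path, text field, and tokenizer name below are illustrative assumptions, not part of this commit.

# Usage sketch (hypothetical paths and names): collect the token
# distribution of a local dataset and inspect its most frequent tokens.
collector = TextTokenDistCollector('gpt2')  # any HuggingFace tokenizer name
dist = collector.collect('data/demo.jsonl', text_key='text', num_proc=4)

# `dist.probs` holds the normalized token frequencies over the vocabulary
top_probs, top_ids = dist.probs.topk(10)
for token_id, prob in zip(top_ids.tolist(), top_probs.tolist()):
    print(f'{collector.tokenizer.decode([token_id])!r}: {prob:.6f}')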
@@ -0,0 +1,41 @@
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def draw_heatmap(data, xlabels, ylabels=None, figsize=None, triangle=False):
    """
    Draw a heatmap of the input data with the given labels.

    :param data: input data; `list`, `tuple`, `numpy array` and
        `torch tensor` are supported.
    :param xlabels: x-axis labels.
    :param ylabels: y-axis labels; if None, use xlabels.
    :param figsize: figure size.
    :param triangle: whether to display only the lower triangle.
    :return: the plotted figure.
    """
    ylabels = ylabels if ylabels is not None else xlabels
    figsize = figsize if figsize else (8 * 2.5, 6 * 2.5)
    _, ax = plt.subplots(figsize=figsize)
    mask = None
    if triangle:
        # mask the upper triangle (including the diagonal)
        mask = np.triu(np.ones_like(data))
    ax.tick_params(
        right=True,
        top=True,
        labelright=True,
        labeltop=True,
    )
    sns.heatmap(data,
                ax=ax,
                cmap='Oranges',
                annot=True,
                mask=mask,
                linewidths=.05,
                square=True,
                xticklabels=xlabels,
                yticklabels=ylabels,
                annot_kws={'size': 8})
    plt.subplots_adjust(left=.1, right=0.95, bottom=0.22, top=0.95)
    fig = plt.gcf()
    plt.show()
    return fig
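
A minimal usage sketch of `draw_heatmap`: the matrix values and labels below are made up for illustration.

# Usage sketch: plot a symmetric 3x3 similarity matrix (values are
# hypothetical), showing only its lower triangle.
import numpy as np

sim = np.array([[1.00, 0.82, 0.35],
                [0.82, 1.00, 0.41],
                [0.35, 0.41, 1.00]])
labels = ['dataset_a', 'dataset_b', 'dataset_c']
fig = draw_heatmap(sim, labels, triangle=True)
fig.savefig('similarity_heatmap.png')  # the returned figure can be saved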
@@ -0,0 +1,109 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical


class Measure(object):
    """Base class for measures on distributions."""
    name = 'base'

    def measure(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        return self.measure(*args, **kwargs)

    def _convert_to_tensor(self, p):
        """
        Convert input data to a torch tensor.

        :param p: input data; `scalar`, `list`, `tuple`, `torch binary
            file`, and `Categorical` are supported.
        :return: torch tensor.
        """
        if isinstance(p, Tensor):
            return p
        elif isinstance(p, Categorical):
            return p.probs
        elif isinstance(p, str):
            # a string is interpreted as the path to a torch binary file
            return torch.load(p)
        else:
            return torch.tensor(p)

    def _convert_to_categorical(self, p):
        """
        Convert input data to a torch Categorical.

        :param p: input data; `scalar`, `list`, `tuple`, `torch binary
            file`, and `Categorical` are supported.
        :return: torch Categorical.
        """
        if isinstance(p, Categorical):
            return p
        elif isinstance(p, Tensor):
            return Categorical(p)
        elif isinstance(p, str):
            # a string is interpreted as the path to a torch binary file
            return Categorical(torch.load(p))
        else:
            return Categorical(torch.tensor(p))


class KLDivMeasure(Measure):
    """
    Measure the Kullback-Leibler divergence KL(p || q).
    """
    name = 'kl_divergence'

    def measure(self, p, q):
        p = self._convert_to_categorical(p)
        q = self._convert_to_categorical(q)
        assert p.probs.shape == q.probs.shape, \
            'The two inputs have different shapes: ' \
            f'{p.probs.shape} != {q.probs.shape} in {self.name}'
        # F.kl_div takes log-probabilities as input and probabilities as
        # target, and computes KL(target || input)
        return F.kl_div(q.logits, p.probs, log_target=False, reduction='sum')


class JSDivMeasure(Measure):
    """
    Measure the Jensen-Shannon divergence JS(p || q).
    """
    name = 'js_divergence'

    def measure(self, p, q):
        # p and q are expected to be (normalized) probability vectors
        p = self._convert_to_tensor(p)
        q = self._convert_to_tensor(q)
        assert p.shape == q.shape, \
            'The two inputs have different shapes: ' \
            f'{p.shape} != {q.shape} in {self.name}'

        m = 0.5 * (p + q)
        kl_p = KLDivMeasure()(p, m)
        kl_q = KLDivMeasure()(q, m)
        js = 0.5 * (kl_p + kl_q)
        return js


class CrossEntropyMeasure(Measure):
    """
    Measure the cross-entropy H(p, q).
    """
    name = 'cross_entropy'

    def measure(self, p, q):
        p = self._convert_to_categorical(p)
        q = self._convert_to_categorical(q)
        assert p.probs.shape == q.probs.shape, \
            'The two inputs have different shapes: ' \
            f'{p.probs.shape} != {q.probs.shape} in {self.name}'
        # q.logits are already normalized log-probabilities, so this
        # evaluates -sum(p * log q)
        return F.cross_entropy(q.logits, p.probs, reduction='sum')


class EntropyMeasure(Measure):
    """
    Measure the entropy H(p).
    """
    name = 'entropy'

    def measure(self, p):
        p = self._convert_to_categorical(p)
        return p.entropy()
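
A minimal usage sketch of the measures above, on two toy distributions (the values are made up; all measures use natural logarithms, following `torch`):

# Usage sketch: compare two hypothetical 4-way distributions.
import torch

p = torch.tensor([0.1, 0.2, 0.3, 0.4])
q = torch.tensor([0.25, 0.25, 0.25, 0.25])

print(KLDivMeasure()(p, q))         # KL(p || q), asymmetric
print(JSDivMeasure()(p, q))         # JS(p || q), symmetric
print(CrossEntropyMeasure()(p, q))  # H(p, q) = H(p) + KL(p || q)
print(EntropyMeasure()(p))          # H(p)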