summarize.py
import os
import csv
import numpy
import torch
import random
import logging
import argparse
import warnings
from typing import Any
from datetime import datetime
from datasets import load_metric
# :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
def check_file_path(file_path: str) -> str:
    """
    Checks if the provided path exists and corresponds to a file.
    :param file_path: A string with the file path to an input document passed as argument in the command line.
    :return: A string with the validated input doc file path.
    """
    current_dir = os.path.dirname(__file__)
    full_file_path = os.path.join(current_dir, file_path)
    if os.path.isfile(full_file_path):
        return full_file_path
    else:
        raise FileNotFoundError(full_file_path)
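# For illustration: because check_file_path is used as the argparse "type" callable below, path validation runs
# at argument-parsing time. A hypothetical failing call (path resolved relative to this script's directory):
#   check_file_path("missing/doc.txt")  # raises FileNotFoundError before any model is downloaded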
def get_input(path: str) -> tuple:
    """
    Loads an input file containing an input document and its Gold Standard, separated by "***".
    :param path: A string with the file path to an input document passed as argument in the command line.
    :return: A tuple containing the article and the related Gold Standard.
    """
    try:
        temp = open(path).read().replace("\n", " ").split("***")
    except FileNotFoundError as error:
        logger.error(error)  # "ERROR: Article does not exist."
        raise
    else:
        input_doc = temp[0].strip()
        gold_standard = temp[1].strip()
        return input_doc, gold_standard
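# For illustration, the expected input file layout (hypothetical content; newlines are flattened to spaces
# before the split on "***"):
#   The full article text goes here, possibly spanning many lines.
#   ***
#   The gold-standard reference summary goes here.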
def get_tokenizer_model(model_name: str) -> tuple:
    """
    Determines the appropriate tokenizer and returns it with an object representing the required pre-trained model.
    :param model_name: A string with the model name passed as argument in the command line.
    :return: A tuple containing the tokenizer and the pre-trained model.
    """
    if "pegasus" in model_name:  # model_name already lowercase
        from transformers import PegasusForConditionalGeneration, PegasusTokenizer
        try:
            model_tokenizer = PegasusTokenizer.from_pretrained(model_name)  # download vocab
            pretrained_model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)  # model ~2GB
        except RuntimeError as error:
            logger.error(error)  # "ERROR: Couldn't load model or tokenizer."
            raise
    elif "bart" in model_name:
        from transformers import BartForConditionalGeneration, BartTokenizer
        try:
            model_tokenizer = BartTokenizer.from_pretrained(model_name)  # download vocab
            # Moved to "device" as well, so the model matches the inputs that get_summary places on "device".
            pretrained_model = BartForConditionalGeneration.from_pretrained(model_name).to(device)  # model ~500MB
        except RuntimeError as error:
            logger.error(error)  # "ERROR: Couldn't load model or tokenizer."
            raise
    else:
        raise AttributeError("ERROR: Wrong model name provided.")
    return model_tokenizer, pretrained_model
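# For illustration (a hypothetical call; "device" is defined at module scope when this file runs as a script):
#   tokenizer, model = get_tokenizer_model("facebook/bart-large-cnn")
#   -> (BartTokenizer, BartForConditionalGeneration), with the model already moved to "device"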
def get_summary(generation_type: str, pretrained_model: Any, input_doc: str, model_tokenizer: Any) -> list:
    """
    Generates a summary of the input document with the requested decoding technique.
    For Beam Search:
    - We reduce the risk of missing hidden high-probability word sequences by keeping the most likely num_beams
      hypotheses at each time step and eventually choosing the hypothesis that has the overall highest probability.
    - Early stopping so that generation is finished when all beam hypotheses reach the EOS token.
    - Techniques like applying a length penalty or setting the number of return sequences don't apply to summarization.
    - N-grams penalty not applied, to avoid reducing the model's expressivity.
    For Sampling:
    - The K most likely next words are filtered, and the probability mass is redistributed among only those K words.
    - Choose from the smallest possible set of words whose cumulative probability exceeds the top_p parameter.
    :param generation_type: The technique to use when generating the summary.
    :param pretrained_model: The pre-trained model to use when summarizing.
    :param input_doc: The article to summarize.
    :param model_tokenizer: The tokenizer to convert the raw article for generation.
    :return: A list containing the summary of each input document.
    """
    # Model's max. sequence length is 1024, so longer inputs are truncated. Returns a dict of "input_ids" and
    # "attention_mask" tensors of shape (1, seq_len). "pt" means PyTorch tensors.
    model_input = model_tokenizer(input_doc, truncation=True, padding='longest', return_tensors="pt").to(device)
    summary_ids = None
    if generation_type == "search":
        summary_ids = pretrained_model.generate(
            **model_input,
            num_beams=4,
            early_stopping=True,
            min_length=64,
            max_length=128,
            # length_penalty=4
        )
    elif generation_type == "sampling":
        summary_ids = pretrained_model.generate(
            model_input["input_ids"],
            do_sample=True,
            min_length=64,
            max_length=128,
            top_k=64,
            top_p=0.90
        )
    # Detokenize
    result = model_tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
    return result
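# For illustration, both branches return a list with one decoded string per input document, e.g. (hypothetical
# output): get_summary("search", model, text, tokenizer) -> ["A short abstractive summary..."]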
def evaluate_summary(model_output: list, gold_standard: str, rouge_metric: Any) -> list:
    """
    Uses the ROUGE-1 and ROUGE-L scores to get a numeric idea of the summary quality.
    :param model_output: The summary generated by the model.
    :param gold_standard: The reference summary.
    :param rouge_metric: The ROUGE metric to use.
    :return: A list with the calculated ROUGE metrics.
    """
    rouge_metric.add(prediction=model_output, reference=[gold_standard])  # both should be lists
    metric_results = rouge_metric.compute()
    results = [
        round(metric_results['rouge1'].mid.recall, 3),
        round(metric_results['rouge1'].mid.precision, 3),
        round(metric_results['rouge1'].mid.fmeasure, 3),
        round(metric_results['rougeL'].mid.recall, 3),
        round(metric_results['rougeL'].mid.precision, 3),
        round(metric_results['rougeL'].mid.fmeasure, 3)
    ]
    return results
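# Note: with the "datasets" ROUGE metric used here, compute() returns bootstrap-aggregated scores per ROUGE
# type; ".mid" above is the point estimate between the "low" and "high" confidence bounds.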
def report_results(
        input_doc: str,
        generation_technique: str,
        pretrained_model: str,
        model_output: list,
        output_score: list) -> None:
    """
    Sends the results to a CSV file for later analysis.
    :param input_doc: The input article to summarize.
    :param generation_technique: The technique used to generate the summary.
    :param pretrained_model: The summarizing model used.
    :param model_output: The summary generated by the model.
    :param output_score: The ROUGE metrics applied to the model's output and the reference summary.
    :return: None.
    """
    # One file per day
    with open(f"./results/runner_results{datetime.now().strftime('%m-%d-%y')}.csv", "a") as csv_file:
        writer = csv.writer(csv_file, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
        writer.writerow([
            input_doc,
            generation_technique,
            pretrained_model,
            model_output[0],
            output_score[0],  # Rouge1_recall
            output_score[1],  # Rouge1_precision
            output_score[2],  # Rouge1_F
            output_score[3],  # RougeL_recall
            output_score[4],  # RougeL_precision
            output_score[5]   # RougeL_F
        ])
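# Note: open(..., "a") does not create missing directories, so "./results" must already exist. A defensive
# alternative (an addition, not part of the original flow) would be to call os.makedirs("results", exist_ok=True)
# before opening the file.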
# :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::
if __name__ == "__main__":
    # Script setup
    warnings.filterwarnings("ignore")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logger = logging.getLogger(__name__)
    # Reproducibility measures
    torch.manual_seed(3)
    random.seed(3)
    numpy.random.seed(3)
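    # The torch seed above also makes the "sampling" generation branch repeatable across runs (on the same
    # hardware and library versions; PyTorch does not guarantee bit-exact reproducibility across devices).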
    # ............................. Command line argument parsing .....................................................
    parser = argparse.ArgumentParser(description="Trying summarization models.")
    parser.add_argument(
        "--model_name",
        required=True,
        type=str,
        choices=[
            "google/pegasus-xsum",
            "google/pegasus-cnn_dailymail",
            "google/pegasus-large",
            "facebook/bart-large-cnn",
            "facebook/bart-large-xsum",
            "sshleifer/distilbart-cnn-12-6",
            "sshleifer/distilbart-xsum-12-6"
        ]
    )
    parser.add_argument("--article", required=True, type=check_file_path, help="Add a relative path.")
    parser.add_argument("--generation", required=True, type=str, choices=["search", "sampling"])
    args = parser.parse_args()
    # ......................................... Evaluation metric .....................................................
    metric = load_metric("rouge")
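    # Note: load_metric is deprecated in recent versions of the "datasets" library; on newer installs the
    # equivalent call is evaluate.load("rouge") from the separate "evaluate" package. The call above matches
    # the older API this script was written against.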
    # ......................................... Read input doc ........................................................
    text, ref_summary = get_input(args.article)
    # ........................................ Create tokenizer and Model .............................................
    tokenizer, model = get_tokenizer_model(args.model_name.lower())
    # ............................................ Summarize! .........................................................
    summary = get_summary(args.generation.lower(), model, text, tokenizer)
    # .............................................. Evaluate! ........................................................
    score = evaluate_summary(summary, ref_summary, metric)
    # ................................................ Report .........................................................
    report_results(args.article, args.generation.lower(), args.model_name.lower(), summary, score)
    # print(f"Summary:\n{summary[0]}")
    # print(f"\nScore: {score}")
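    # Example invocation (hypothetical file path; the input file must contain the article and the reference
    # summary separated by "***"):
    #   python summarize.py --model_name facebook/bart-large-cnn --article data/article.txt --generation search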