Skip to content

Commit

Permalink
sync
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorAxelsen committed Oct 14, 2024
1 parent dc22b94 commit f2e6834
Show file tree
Hide file tree
Showing 27 changed files with 4,747 additions and 16 deletions.
102 changes: 87 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
![Method](./figures/logo.png)

# Graph of Records: Boosting Retrieval Augmented Generation for Long-context Summarization with Graphs




<p align="center">
<a href="https://ulab-uiuc.github.io/GoR/">
<img alt="Build" src="https://img.shields.io/badge/Project-Page-blue">
</a>
<a href="https://arxiv.org/abs/xxx">
<img alt="Build" src="https://img.shields.io/badge/arXiv-xxxx.xxxxx-red?logo=arxiv">
</a>
<a href="xxx">
<!-- <a href="xxx">
<img alt="Build" src="https://img.shields.io/badge/Twitter-black?logo=X">
</a>
</a> -->
<a href="https://github.com/ulab-uiuc/GoR/blob/master/LICENSE">
<img alt="License" src="https://img.shields.io/badge/LICENSE-MIT-green">
</a>
Expand All @@ -33,37 +29,113 @@

<p align="center">
<a href="https://ulab-uiuc.github.io/GoR/">🌐 Project Page</a> |
<a href="https://arxiv.org/abs/xxx">📜 arXiv</a> |
<a href="xxx">📮 Twitter Post</a>
<a href="https://arxiv.org/abs/xxx">📜 arXiv</a>
<!-- <a href="xxx">📮 Twitter Post</a> -->
</p>


![Method](./figures/model.png)



## News

**[2024.10.1x]** 🌟Release GoR
<!-- **[2024.10.1x]** 🌟Release AcademicEval -->
**[2024.10.1x]** 🌟 GoR is released.
<!-- **[2024.10.1x]** 🌟Release Thought Retriever -->


## 📌Preliminary


## 📌Environment Setup
### Environment Setup

```bash
# python==3.10
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
pip install dgl==1.0.0+cu113 -f https://data.dgl.ai/wheels/cu113/repo.html
pip install scikit-learn
pip install openai==0.28
pip install pandas
pip install langchain
pip install langchain-core
pip install langchain-community
pip install langchain-experimental
pip install tiktoken
pip install tqdm
pip install bert_score
pip install rouge_score
pip install networkx
pip install faiss-gpu
pip install transformers
```

### Dataset Preparation

[QMSum](https://github.com/Yale-LILY/QMSum)
[WCEP](https://huggingface.co/datasets/ccdv/WCEP-10)
[Booksum](https://huggingface.co/datasets/kmfoda/booksum)
[GovReport](https://huggingface.co/datasets/ccdv/govreport-summarization/tree/refs%2Fconvert%2Fparquet/document)
[SQuALITY](https://github.com/nyu-mll/SQuALITY)


Save the downloaded files in the `./data/[DATASET_NAME]` folder.

## ⭐Experiments

> \[!IMPORTANT\]
>
> Before running the experiment, please configure your API KEY
> Before running the experiment, please configure your API KEY in `"get_llm_response_via_api"` in `utils.py`


## ⭐Experiments



### Query Simulation and Graph Construction

Generate simulated queries and construct graphs. The constructed graphs are saved in the `./graph` folder.

```bash
# DATASET Choices: qmsum, wcep, booksum, govreport, squality
# Training Set
python graph_construction.py --cuda 0 --dataset [DATASET] --train
# Test Set
python graph_construction.py --cuda 0 --dataset [DATASET]
```


### Training Preparation

Pre-compute BERTScore and save training data in the `./training_data` folder.



```bash
# DATASET Choices: qmsum, wcep, booksum, govreport, squality
python training_preparation.py --cuda 0 --dataset [DATASET]
```



### Training


```bash
# DATASET Choices: qmsum, wcep, booksum, govreport, squality
python train.py --cuda 0 --dataset [DATASET]
```


### Evaluation


```bash
# DATASET Choices: qmsum, wcep, booksum, govreport, squality
# Generate summary results
python eval.py --cuda 0 --dataset [DATASET]
# Evaluation
python sum_eval.py --cuda 0 --file_name ./result/[DATASET].json
```



Expand All @@ -87,7 +159,7 @@ x
```


<picture>
<!-- <picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=ulab-uiuc%2FGoR&theme=dark&type=Date">
<img width="100%" src="https://api.star-history.com/svg?repos=ulab-uiuc%2FGoR&type=Date">
</picture>
</picture> -->
204 changes: 204 additions & 0 deletions data_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import json
import pandas as pd


def clean_data(text):
    """Strip QMSum transcript annotation tokens and collapse spelled-out acronyms.

    Tokens such as ``{ vocalsound }`` are inserted by the transcription
    pipeline and carry no summary-relevant content; letter-by-letter
    acronyms (``a_m_i_`` etc.) are rewritten as plain words.
    """
    # Order matters only in that it mirrors the original replacement chain.
    substitutions = (
        ('{ vocalsound } ', ''),
        ('{ disfmarker } ', ''),
        ('a_m_i_', 'ami'),
        ('l_c_d_', 'lcd'),
        ('p_m_s', 'pms'),
        ('t_v_', 'tv'),
        ('{ pause } ', ''),
        ('{ nonvocalsound } ', ''),
        ('{ gap } ', ''),
    )
    for old, new in substitutions:
        text = text.replace(old, new)
    return text


def qm_process_data(train=True):
    """Load the QMSum split as dicts with topics, general queries, and a
    cleaned, speaker-annotated transcript string.

    Args:
        train: when True read the train split, otherwise the test split.
    """
    split = "train" if train else "test"
    with open("./data/QMSum/data/ALL/jsonl/{}.jsonl".format(split), 'r') as file:
        samples = [json.loads(line) for line in file]

    processed = []
    for sample in samples:
        # Flatten the turn list into one "Speaker/Content" annotated string.
        transcript = "\n".join(
            "Speaker: " + turn["speaker"] + "\n" + "Content: " + turn["content"]
            for turn in sample['meeting_transcripts']
        )
        processed.append({
            "topic_list": sample['topic_list'],
            "general_query_list": sample['general_query_list'],
            "meeting_transcripts": clean_data(transcript),
        })

    return processed


def booksum_process_data(train=True):
    """Load the Booksum split, keeping only rows whose chapter is longer
    than 8000 tokens (short texts are useless for long-context summarization).
    """
    path = './data/Booksum/train.csv' if train else './data/Booksum/test.csv'
    frame = pd.read_csv(path)
    # Filter out too short text for long-context summarization
    return [row for _, row in frame.iterrows() if row["chapter_length"] > 8000]


def wcep_process_data(train=True):
    """Load the WCEP split, keeping only clusters whose combined documents
    exceed 6000 words.
    """
    split = "train" if train else "test"
    with open('./data/WCEP/{}.txt'.format(split), 'r') as file:
        records = [json.loads(line) for line in file]

    kept = []
    for record in records:
        # Filter out too short text for long-context summarization
        total_words = len(" ".join(record["document"]).split())
        if total_words > 6000:
            kept.append(record)

    return kept


def gov_process_data(train=True):
    """Load the GovReport parquet shards, keeping only reports longer than
    8000 words.
    """
    if train:
        # Training data is sharded across two parquet files.
        frame = pd.concat([
            pd.read_parquet('./data/GovReport/document/train-00000-of-00002.parquet'),
            pd.read_parquet('./data/GovReport/document/train-00001-of-00002.parquet'),
        ])
    else:
        frame = pd.read_parquet('./data/GovReport/document/test-00000-of-00001.parquet')

    # Filter out too short text for long-context summarization
    return [row for _, row in frame.iterrows() if len(row["report"].split()) > 8000]


def squ_process_data(train=True):
    """Load SQuALITY stories.

    The training set is expanded with the dev split plus the first 25 test
    stories; the evaluation set skips those same 25 so the two never overlap.
    """
    data = []
    if train:
        # Expand the training set with train + dev in full.
        with open("./data/SQuALITY/data/v1-3/txt/train.jsonl", 'r') as file:
            data.extend(json.loads(line) for line in file)
        with open("./data/SQuALITY/data/v1-3/txt/dev.jsonl", 'r') as file:
            data.extend(json.loads(line) for line in file)
        # ...and borrow exactly the first 25 test stories.
        with open("./data/SQuALITY/data/v1-3/txt/test.jsonl", 'r') as file:
            for taken, line in enumerate(file):
                if taken == 25:
                    break
                data.append(json.loads(line))
    else:
        # Ensure that the test set does not overlap with the training set
        with open("./data/SQuALITY/data/v1-3/txt/test.jsonl", 'r') as file:
            for idx, line in enumerate(file):
                if idx >= 25:
                    data.append(json.loads(line))

    return list(data)


def get_processed_data(dataset, train=True):
    """Dispatch to the per-dataset loader.

    Args:
        dataset: one of 'qmsum', 'wcep', 'booksum', 'govreport', 'squality'.
        train: forwarded to the loader to select the train or test split.

    Raises:
        Exception: for any unrecognized dataset name.
    """
    if dataset == "qmsum":
        return qm_process_data(train=train)
    if dataset == "wcep":
        return wcep_process_data(train=train)
    if dataset == "booksum":
        return booksum_process_data(train=train)
    if dataset == "govreport":
        return gov_process_data(train=train)
    if dataset == "squality":
        return squ_process_data(train=train)
    raise Exception("Dataset Error")


def split_corpus_by_doc(dataset, sample, text_splitter):
    """Chunk a sample's source document(s) with the supplied text splitter.

    Args:
        dataset: dataset key selecting which field of *sample* holds the text.
        sample: one record from the corresponding loader.
        text_splitter: any object exposing ``split_text(str) -> list[str]``.

    Raises:
        Exception: for any unrecognized dataset name.
    """
    # Which field of the sample holds the raw text for each dataset.
    text_field = {
        "qmsum": "meeting_transcripts",
        "wcep": "document",
        "booksum": "chapter",
        "govreport": "report",
        "squality": "document",
    }
    if dataset not in text_field:
        raise Exception("Dataset Error")

    value = sample[text_field[dataset]]
    # WCEP stores a list of articles; every other dataset holds one string.
    doc_list = value if dataset == "wcep" else [value]

    chunk_list = []
    for doc in doc_list:
        chunk_list.extend(text_splitter.split_text(doc))

    return chunk_list


def eval_data_generation(dataset, sample):
    """Build evaluation records for one sample.

    Each record is a dict with keys ``rag_query`` (retrieval prompt),
    ``query`` (question shown to the LLM), and ``summary`` (reference).

    Raises:
        Exception: for any unrecognized dataset name.
    """
    if dataset == "qmsum":
        # One record per general query; the meeting's topic list is appended
        # to the retrieval prompt as an extra hint.
        all_topic = ", ".join(t["topic"] for t in sample["topic_list"])
        return [
            {
                "rag_query": q["query"] + " The topic list of the meeting transcript is: " + all_topic,
                "query": q["query"],
                "summary": q["answer"],
            }
            for q in sample["general_query_list"]
        ]

    if dataset == "squality":
        # The first question doubles as both prompts; all human responses
        # are kept as a list of references.
        question = sample["questions"][0]["question_text"]
        references = [r["response_text"] for r in sample["questions"][0]["responses"]]
        return [{"rag_query": question, "query": question, "summary": references}]

    # The remaining datasets share one fixed prompt and a single reference
    # field; only the wording and the field name differ.
    generic = {
        "wcep": ("Summarize the contents of this news event.", "summary"),
        "booksum": ("Summarize the contents of this story.", "summary_text"),
        "govreport": ("Summarize the contents of this report.", "summary"),
    }
    if dataset not in generic:
        raise Exception("Dataset Error")
    prompt, field = generic[dataset]
    return [{"rag_query": prompt, "query": prompt, "summary": sample[field]}]


if __name__ == '__main__':
    # This module is a library of data-loading helpers imported by the
    # pipeline scripts (e.g. graph construction / training); it has no
    # standalone CLI behavior of its own.
    pass
Loading

0 comments on commit f2e6834

Please sign in to comment.