Commit
Showing 5 changed files with 563 additions and 0 deletions.
@@ -0,0 +1,134 @@
from utils import *

# Explicit imports for the standard-library / third-party names used below;
# `utils` is assumed to provide the remaining helpers (summarize_research_direction,
# get_daily_papers, get_bert_embedding, summarize_research_field, generate_ideas).
import datetime
import requests
from xml.etree import ElementTree


class tools:
    def __init__(self):
        pass

    def get_user_profile(self, author_name):
        # Build a short research-direction summary for `author_name` from their arXiv papers.
        author_query = author_name.replace(" ", "+")
        url = f"http://export.arxiv.org/api/query?search_query=au:{author_query}&start=0&max_results=300"  # Adjust max_results if needed

        response = requests.get(url)
        papers_list = []

        if response.status_code == 200:
            root = ElementTree.fromstring(response.content)
            entries = root.findall('{http://www.w3.org/2005/Atom}entry')

            total_papers = 0
            papers_by_year = {}

            for entry in entries:
                title = entry.find('{http://www.w3.org/2005/Atom}title').text.strip()
                published = entry.find('{http://www.w3.org/2005/Atom}published').text.strip()
                abstract = entry.find('{http://www.w3.org/2005/Atom}summary').text.strip()
                authors_elements = entry.findall('{http://www.w3.org/2005/Atom}author')
                authors = [author.find('{http://www.w3.org/2005/Atom}name').text for author in authors_elements]
                link = entry.find('{http://www.w3.org/2005/Atom}id').text.strip()  # Paper link

                # Keep only entries where the specified author appears exactly in the author list
                if author_name in authors:
                    # Remove the specified author from the coauthors list for display
                    coauthors = [author for author in authors if author != author_name]
                    coauthors_str = ", ".join(coauthors)

                    papers_list.append({
                        "date": published,
                        "Title & Abstract": f"{title}; {abstract}",
                        "coauthors": coauthors_str,
                        "link": link  # Paper link
                    })

                    # Group matching entries by publication year for the sampling step below
                    total_papers += 1
                    date_obj = datetime.datetime.strptime(published, '%Y-%m-%dT%H:%M:%SZ')
                    year = date_obj.year
                    if year not in papers_by_year:
                        papers_by_year[year] = []
                    papers_by_year[year].append(entry)

            if total_papers > 40:
                # For prolific authors, sample at most two papers per year within five-year cycles
                for cycle_start in range(min(papers_by_year), max(papers_by_year) + 1, 5):
                    cycle_end = cycle_start + 4
                    for year in range(cycle_start, cycle_end + 1):
                        if year in papers_by_year:
                            selected_papers = papers_by_year[year][:2]
                            for paper in selected_papers:
                                title = paper.find('{http://www.w3.org/2005/Atom}title').text.strip()
                                abstract = paper.find('{http://www.w3.org/2005/Atom}summary').text.strip()
                                authors_elements = paper.findall('{http://www.w3.org/2005/Atom}author')
                                co_authors = [author.find('{http://www.w3.org/2005/Atom}name').text
                                              for author in authors_elements
                                              if author.find('{http://www.w3.org/2005/Atom}name').text != author_name]

                                papers_list.append({
                                    "Author": author_name,
                                    "Title & Abstract": f"{title}; {abstract}",
                                    "Date Period": f"{year}",
                                    "Cycle": f"{cycle_start}-{cycle_end}",
                                    "Co_author": ", ".join(co_authors)
                                })

            # Trim the list to the 10 most recent papers
            papers_list = papers_list[:10]

            personal_info = "; ".join([f"{details['Title & Abstract']}" for details in papers_list])
            info = summarize_research_direction(personal_info)
            return info

        else:
            print("Failed to fetch data from arXiv.")
            return ""
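
    # Example (hypothetical call, requires network access to the arXiv API):
    #   tools().get_user_profile("Jiaxuan You")
    # returns a free-text research summary produced by summarize_research_direction().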

    def get_recent_paper(self, num, domain):
        # Fetch recent papers for `domain` via the get_daily_papers() helper from `utils`.
        data_collector = []
        keywords = dict()
        keywords[domain] = domain

        for topic, keyword in keywords.items():
            data, _ = get_daily_papers(topic, query=keyword, max_results=num)
            data_collector.append(data)

        data_dict = {}
        for data in data_collector:
            for time in data.keys():
                papers = data[time]
                data_dict[time.strftime("%m/%d/%Y")] = papers

        return data_dict
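
    # The dictionary returned above maps "MM/DD/YYYY" date strings to the paper records
    # produced by get_daily_papers(), e.g. (hypothetical call):
    #   tools().get_recent_paper(10, "Machine Learning")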

    def idea_generation(self, profile, papers, domain):
        # Embed each time chunk's abstracts, summarize the field trend against the profile,
        # then generate ideas conditioned on that trend.
        time_chunks_embed = {}
        dataset = papers
        for time in dataset.keys():
            papers = dataset[time]['abstract']
            papers_embedding = get_bert_embedding(papers)
            time_chunks_embed[time] = papers_embedding

        self.trend, paper_link = summarize_research_field(profile, domain, dataset,
                                                          time_chunks_embed)  # trend
        self.idea = generate_ideas(self.trend)  # idea

        return self.idea
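

# Minimal usage sketch (hypothetical driver): it chains the three tools the same way as the
# commented-out example in the driver script later in this commit, and assumes network
# access plus the helpers exported by `utils`.
if __name__ == '__main__':
    agent_tools = tools()
    profile = agent_tools.get_user_profile("Jiaxuan You")
    recent_papers = agent_tools.get_recent_paper(10, "Machine Learning")
    print(agent_tools.idea_generation(profile, recent_papers, "Machine Learning"))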
@@ -0,0 +1,100 @@
import arxiv
from tqdm import tqdm

from utils import *


def author_position(author, author_list):
    # Return the 1-based position of `author` in `author_list` (case-insensitive),
    # or "NULL" if the author is not present.
    for ind, i in enumerate(author_list):
        if author.lower() == i.lower():
            return ind + 1

    return "NULL"
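
# Example (hypothetical inputs):
#   author_position("jiaxuan you", ["Alice Smith", "Jiaxuan You"])  ->  2
#   author_position("Jiaxuan You", ["Alice Smith"])                 ->  "NULL"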

def co_author_frequency(author, author_list, co_authors):
    # Increment the count of every co-author in `author_list` other than `author` itself.
    for ind, i in enumerate(author_list):
        if author.lower() == i.lower():
            continue
        if i in co_authors:
            co_authors[i] += 1
        else:
            co_authors[i] = 1

    return co_authors
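
# Example (hypothetical names), counts accumulate across calls:
#   co_author_frequency("Alice", ["Alice", "Bob"], {})          ->  {"Bob": 1}
#   co_author_frequency("Alice", ["Alice", "Bob"], {"Bob": 1})  ->  {"Bob": 2}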

def co_author_filter(co_authors, limit=5):
    # Keep only the `limit` most frequent co-authors, ordered by descending count.
    co_author_list = []
    for k, v in co_authors.items():
        co_author_list.append([k, v])

    co_author_list.sort(reverse=True, key=lambda p: p[1])
    co_author_list = co_author_list[:limit]
    co_author_list = [c[0] for c in co_author_list]

    return co_author_list
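
# Example (hypothetical counts):
#   co_author_filter({"A": 3, "B": 1, "C": 7}, limit=2)  ->  ["C", "A"]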

def fetch_author_info(author):
    # Query arXiv for the author's papers and accumulate co-author counts along the way.
    client = arxiv.Client()
    papers_info = []
    co_authors = dict()
    print("{} Fetching Author Info: {}".format(show_time(), author))
    search = arxiv.Search(
        query="au:{}".format(author),
        max_results=10
    )
    for result in tqdm(client.results(search), desc="Processing Author Papers", unit="Paper"):
        # Skip results where the queried author does not actually appear in the author string.
        if author not in ', '.join(a.name for a in result.authors):
            continue
        author_list = [a.name for a in result.authors]
        # author_pos = author_position(author, author_list)
        co_authors = co_author_frequency(author, author_list, co_authors)
        paper_info = {
            'url': result.entry_id,
            "title": result.title,
            "abstract": result.summary,
            "authors": ', '.join(a.name for a in result.authors),
            "published": str(result.published).split(" ")[0],
            "updated": str(result.updated).split(" ")[0],
            'primary_cat': result.primary_category,
            'cats': result.categories,
            # "author_pos": author_pos
        }
        papers_info.append(paper_info)

    # papers_info.sort(reverse=False, key=lambda p: p["author_pos"])
    co_authors = co_author_filter(co_authors, limit=5)
    print(text_wrap("Num of Papers:"), len(papers_info))
    print(text_wrap("Num of Co-authors:"), len(co_authors))

    return papers_info, co_authors

def bfs(author_list, node_limit=20):
    # Breadth-first expansion of the co-authorship graph: `author_list` grows in place as
    # each visited author's top co-authors are appended, until `node_limit` is reached.
    graph = []
    node_feat = dict()
    edge_feat = dict()
    visit = []
    for author in author_list:
        if author in visit:
            continue
        papers_info, co_authors = fetch_author_info(author)
        if len(node_feat) <= node_limit:
            author_list.extend(co_authors)
            for co_au in co_authors:
                if (author, co_au) in graph or (co_au, author) in graph:
                    continue
                graph.append((author, co_au))

        visit.append(author)
        node_feat[author] = papers_info

    return graph, node_feat, edge_feat
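
# `graph` is a list of (author, co_author) edges, `node_feat` maps each visited author to
# their fetched paper records, and `edge_feat` is currently returned empty as a placeholder.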

if __name__ == '__main__':
    start_author = ["Jiaxuan You"]
    bfs(author_list=start_author, node_limit=20)
@@ -0,0 +1,32 @@
from construct_relation_graph import *
from basic_tool import tools


class SingleAgent:
    def __init__(self):
        self.tools = tools()

    def get_profile(self, name):
        """Get the user's profile from their arXiv papers."""
        profile = self.tools.get_user_profile(name)
        return profile

    def get_recent_paper_info(self, num, domain):
        """Fetch information on recent arXiv papers in the given domain."""
        info = self.tools.get_recent_paper(num, domain)
        return info

    def generate_ideas(self, profile, papers, domain):
        """Given arXiv papers, generate ideas based on the user profile."""
        ideas = self.tools.idea_generation(profile, papers, domain)
        return ideas


class MultiAgent:
    def __init__(self):
        self.tools = tools()

    def get_relation_graph(self, name, max_node):
        """Obtain author-author relations through co-authorship."""
        start_author = [name]
        graph, node_feat, edge_feat = bfs(author_list=start_author, node_limit=max_node)
        return graph, node_feat, edge_feat
@@ -0,0 +1,14 @@
from function_class import *

single_agent = SingleAgent()
# print(single_agent.get_profile("Jiaxuan You"))
# print(single_agent.get_recent_paper_info(10, "Machine Learning"))
# domain = "Machine Learning"
# profile = single_agent.get_profile("Jiaxuan You")
# papers = single_agent.get_recent_paper_info(10, "Machine Learning")
#
# ideas = single_agent.generate_ideas(profile, papers, domain)
# print(ideas)

multi_agent = MultiAgent()
print(multi_agent.get_relation_graph("Jiaxuan You", 20))
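
# Hypothetical follow-up (assumes the call above succeeds): unpack the result instead of
# printing the raw tuple.
#   graph, node_feat, edge_feat = multi_agent.get_relation_graph("Jiaxuan You", 20)
#   print(len(graph), "edges;", len(node_feat), "authors with paper features")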