break out some sync functions for faster syncing of autolabel
tomlue committed May 4, 2024
1 parent 83838bf commit 24d380f
Showing 2 changed files with 139 additions and 122 deletions.
6 changes: 5 additions & 1 deletion .gitignore
@@ -23,4 +23,8 @@ __pycache__/

 scratch.py
 .env
-.sr
+.sr
+
+# code2flow output
+out.gv
+out.png
255 changes: 134 additions & 121 deletions sysrev/client.py
@@ -26,6 +26,91 @@ def transform_label(self, label_type, label_value):
         else:
             raise ValueError("Invalid label type")
 
+
+class Client():
+
+    def __init__(self, api_key, base_url="https://www.sysrev.com"):
+        self.api_key = api_key
+        self.base_url = base_url
+
+    def sync(self, project_id):
+        Synchronizer().sync(self, project_id)
+
+    def get_project_info(self, project_id):
+        endpoint = f"{self.base_url}/api-json/project-info"
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+        response = requests.get(endpoint, headers=headers, params={"project-id": project_id})
+        return response.json()
+
+    def get_labels(self, project_id):
+        raw_labels = self.get_project_info(project_id)['result']['project']['labels']
+        labels = [{"label_id": label_id} | raw_labels[label_id] for label_id in raw_labels.keys()]
+        return labels
+
+    def set_labels(self, project_id, article_id, label_ids, label_values, label_types, confirm=False, change=False, resolve=False):
+        endpoint = f"{self.base_url}/api-json/set-labels"
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+
+        assert len(label_ids) == len(label_values) == len(label_types), "Length of label_ids, label_values, and label_types should be the same."
+
+        # construct label_values_dict
+        tf = LabelTransformer()
+        label_values_dict = {label_ids[i]: tf.transform_label(label_types[i], label_values[i]) for i in range(len(label_ids))}
+
+        # Constructing the data payload as per the server's expectation
+        data = {"project-id": project_id, "article-id": article_id, "label-values": label_values_dict}
+        data.update({ "confirm?": confirm, "change?": change, "resolve?": resolve })
+
+        # Sending a POST request to the server
+        response = requests.post(endpoint, json=data, headers=headers)
+        return response.json()
+
+    def get_project_articles(self, project_id, offset=0, limit=10, sort_by=None, sort_dir=None):
+        endpoint = f"{self.base_url}/api-json/project-articles"
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+        body = {"project-id": project_id, "n-offset": offset, "n-count": limit}
+
+        # Add optional sorting keys if provided
+        if sort_by: body["sort-by"] = sort_by
+        if sort_dir: body["sort-dir"] = sort_dir
+
+        # Make the POST request with the simplified body
+        response = requests.post(endpoint, headers=headers, json=body)
+        return response.json()
+
+    def fetch_all_articles(self, project_id, limit=10, sort_by=None, sort_dir=None):
+        offset = 0
+        while True:
+            result = self.get_project_articles(project_id, offset=offset, limit=limit, sort_by=sort_by, sort_dir=sort_dir)
+            articles = result.get('result', [])
+            if not articles:
+                break  # Stop iteration if no articles are left
+            yield from articles  # Yield each article in the current batch
+            offset += len(articles)
+
+    def get_article_info(self, project_id, article_id):
+        endpoint = f"{self.base_url}/api-json/article-info/{article_id}"
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+        body = {"project-id": project_id,}
+        response = requests.get(endpoint, headers=headers, json=body)
+        return response.json()['result']
+
+    def upload_jsonlines(self, file_path, project_id):
+        url = f"{self.base_url}/api-json/import-files/{project_id}"
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+
+        # Prepare the file for upload
+        with open(file_path, 'rb') as f:
+            files = {'file': (file_path.split('/')[-1], f, 'application/octet-stream')}
+            # Let requests handle "Content-Type"
+            response = requests.post(url, headers=headers, files=files)
+
+        return response
+
+    def get_article_file(self, project_id, article_id, hash):
+        url = f"{self.base_url}/api-json/files/{project_id}/article/{article_id}/download/{hash}"
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+
 
 class Synchronizer:
 
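For orientation, a minimal usage sketch of the Client class as relocated above. The API key and project id are placeholders, not values from this commit; method names and the hyphenated response keys are taken from the diff itself:

    from sysrev.client import Client

    client = Client(api_key="YOUR_SYSREV_TOKEN")  # placeholder credential
    client.sync(12345)  # placeholder project id; mirrors the project into .sr/sr.sqlite

    # Articles can also be paged lazily through the generator:
    for article in client.fetch_all_articles(12345, limit=50):
        print(article["article-id"])  # hyphenated keys, as used elsewhere in this diff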
@@ -93,33 +178,28 @@ def create_sqlite_db(self):
         conn.commit()
         conn.close()
 
-    # TODO - this could be made more efficient by checking sqlite state and updating the sysrev api
-    def sync(self, client, project_id):
-
-        if not pathlib.Path('.sr/sr.sqlite').exists():
-            self.create_sqlite_db()
-
-        project_info = client.get_project_info(project_id)
-
-        labels = client.get_labels(project_id)
-        labels_df = pd.DataFrame(labels)
-        labels_df['definition'] = labels_df['definition'].apply(json.dumps)
-
-        n_articles = project_info['result']['project']['stats']['articles']
-        articles = [resp for resp in tqdm.tqdm(client.fetch_all_articles(project_id), total=n_articles)]
-
-        article_labels = [a['labels'] for a in articles if a['labels'] is not None]
-        article_labels = [lbl for lbls in article_labels for lbl in lbls]
-        article_label_df = pd.DataFrame(article_labels)
-        article_label_df['answer'] = article_label_df['answer'].apply(json.dumps)
-
-        article_data = [{k: v for k, v in a.items() if k != 'labels'} for a in articles]
-        article_data_df = pd.DataFrame(article_data)
-        article_data_df['notes'] = article_data_df['notes'].apply(json.dumps)
-        article_data_df['resolve'] = article_data_df['resolve'].apply(json.dumps)
 
+    def write_df(self, df, name, db_path='.sr/sr.sqlite'):
+        """
+        Writes the given DataFrame to a SQLite database.
+        Parameters:
+            df (pandas.DataFrame): The DataFrame to be written to the database.
+            name (str): The name of the table in which the DataFrame should be stored.
+            db_path (str): Path to the SQLite database file.
+        """
+        # Connect to the SQLite database
+        conn = sqlite3.connect(db_path)
+
+        try:
+            df.columns = df.columns.str.replace('-', '_')
+            df = df.loc[:, ~df.columns.duplicated()]
+            df.to_sql(name, conn, if_exists='replace', index=False) if not df.empty else None
+        finally:
+            conn.close()
+
+    def sync_article_info(self, client:Client, project_id, article_ids):
         article_info = []
-        for article_id in tqdm.tqdm(article_data_df['article-id'], total=n_articles):
+        for article_id in tqdm.tqdm(article_ids, total=len(article_ids)):
            article_info.append(client.get_article_info(project_id, article_id))
 
         full_texts = pd.DataFrame([{**ft} for a in article_info for ft in a['article'].get('full-texts', []) ])
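The new write_df helper centralizes what sync previously did inline: it rewrites hyphenated column names to underscores, drops duplicate columns, and replaces the target table wholesale (if_exists='replace'), skipping empty frames. A small illustrative sketch, assuming the .sr directory already exists (the frame contents are invented):

    import pandas as pd
    from sysrev.client import Synchronizer

    df = pd.DataFrame({"article-id": [1, 2], "answer": ['["include"]', '["exclude"]']})
    Synchronizer().write_df(df, "article_label")  # stored in .sr/sr.sqlite with column article_id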
@@ -137,106 +217,39 @@ def sync(self, client, project_id):
         csl_citations['issued'] = csl_citations['issued'].apply(json.dumps)
         csl_citations['author'] = csl_citations['author'].apply(json.dumps)
 
-        # write everything to .sr/sr.sqlite
-        conn = sqlite3.connect('.sr/sr.sqlite')
-
-        def write_df(df,name):
-            # replace any - with _ in column names and remove duplicates
-            df.columns = df.columns.str.replace('-', '_')
-            df = df.loc[:,~df.columns.duplicated()]
-            df.to_sql(name, conn, if_exists='replace', index=False) if not df.empty else None
-
-
-        # Writing data to tables
-        write_df(labels_df,'labels')
-        write_df(article_label_df,'article_label')
-        write_df(article_data_df,'article_data')
-        write_df(full_texts,'full_texts')
-        write_df(auto_labels,'auto_labels')
-        write_df(csl_citations,'csl_citations')
-
-        conn.close()
-class Client():
+        self.write_df(full_texts,'full_texts')
+        self.write_df(auto_labels,'auto_labels')
+        self.write_df(csl_citations,'csl_citations')
 
-    def __init__(self, api_key, base_url="https://www.sysrev.com"):
-        self.api_key = api_key
-        self.base_url = base_url
+    def sync_labels(self, client, project_id):
+        labels = client.get_labels(project_id)
+        labels_df = pd.DataFrame(labels)
+        labels_df['definition'] = labels_df['definition'].apply(json.dumps)
+        self.write_df(labels_df,'labels')
 
-    def sync(self, project_id):
-        Synchronizer().sync(self, project_id)
-
-    def get_project_info(self, project_id):
-        endpoint = f"{self.base_url}/api-json/project-info"
-        headers = {"Authorization": f"Bearer {self.api_key}"}
-        response = requests.get(endpoint, headers=headers, params={"project-id": project_id})
-        return response.json()
-
-    def get_labels(self, project_id):
-        raw_labels = self.get_project_info(project_id)['result']['project']['labels']
-        labels = [{"label_id": label_id} | raw_labels[label_id] for label_id in raw_labels.keys()]
-        return labels
-
-    def set_labels(self, project_id, article_id, label_ids, label_values, label_types, confirm=False, change=False, resolve=False):
-        endpoint = f"{self.base_url}/api-json/set-labels"
-        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+    # TODO - this could be made more efficient by checking sqlite state and updating the sysrev api
+    def sync(self, client, project_id):
 
-        assert len(label_ids) == len(label_values) == len(label_types), "Length of label_ids, label_values, and label_types should be the same."
+        if not pathlib.Path('.sr/sr.sqlite').exists():
+            self.create_sqlite_db()
 
+        project_info = client.get_project_info(project_id)
 
-        # construct label_values_dict
-        tf = LabelTransformer()
-        label_values_dict = {label_ids[i]: tf.transform_label(label_types[i], label_values[i]) for i in range(len(label_ids))}
+        n_articles = project_info['result']['project']['stats']['articles']
+        articles = [resp for resp in tqdm.tqdm(client.fetch_all_articles(project_id), total=n_articles)]
 
-        # Constructing the data payload as per the server's expectation
-        data = {"project-id": project_id, "article-id": article_id, "label-values": label_values_dict}
-        data.update({ "confirm?": confirm, "change?": change, "resolve?": resolve })
+        article_labels = [a['labels'] for a in articles if a['labels'] is not None]
+        article_labels = [lbl for lbls in article_labels for lbl in lbls]
+        article_label_df = pd.DataFrame(article_labels)
+        article_label_df['answer'] = article_label_df['answer'].apply(json.dumps)
 
-        # Sending a POST request to the server
-        response = requests.post(endpoint, json=data, headers=headers)
-        return response.json()
-
-    def get_project_articles(self, project_id, offset=0, limit=10, sort_by=None, sort_dir=None):
-        endpoint = f"{self.base_url}/api-json/project-articles"
-        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
-        body = {"project-id": project_id, "n-offset": offset, "n-count": limit}
+        article_data = [{k: v for k, v in a.items() if k != 'labels'} for a in articles]
+        article_data_df = pd.DataFrame(article_data)
+        article_data_df['notes'] = article_data_df['notes'].apply(json.dumps)
+        article_data_df['resolve'] = article_data_df['resolve'].apply(json.dumps)
 
-        # Add optional sorting keys if provided
-        if sort_by: body["sort-by"] = sort_by
-        if sort_dir: body["sort-dir"] = sort_dir
+        self.sync_article_info(client, project_id, article_data_df['article-id'])
 
-        # Make the POST request with the simplified body
-        response = requests.post(endpoint, headers=headers, json=body)
-        return response.json()
-
-    def fetch_all_articles(self, project_id, limit=10, sort_by=None, sort_dir=None):
-        offset = 0
-        while True:
-            result = self.get_project_articles(project_id, offset=offset, limit=limit, sort_by=sort_by, sort_dir=sort_dir)
-            articles = result.get('result', [])
-            if not articles:
-                break  # Stop iteration if no articles are left
-            yield from articles  # Yield each article in the current batch
-            offset += len(articles)
-
-    def get_article_info(self, project_id, article_id):
-        endpoint = f"{self.base_url}/api-json/article-info/{article_id}"
-        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
-        body = {"project-id": project_id,}
-        response = requests.get(endpoint, headers=headers, json=body)
-        return response.json()['result']
-
-    def upload_jsonlines(self, file_path, project_id):
-        url = f"{self.base_url}/api-json/import-files/{project_id}"
-        headers = {"Authorization": f"Bearer {self.api_key}"}
-
-        # Prepare the file for upload
-        with open(file_path, 'rb') as f:
-            files = {'file': (file_path.split('/')[-1], f, 'application/octet-stream')}
-            # Let requests handle "Content-Type"
-            response = requests.post(url, headers=headers, files=files)
-
-        return response
-
-    def get_article_file(self, project_id, article_id, hash):
-        url = f"{self.base_url}/api-json/files/{project_id}/article/{article_id}/download/{hash}"
-        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
 
+        # Writing data to tables
+        self.write_df(article_label_df,'article_label')
+        self.write_df(article_data_df,'article_data')
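Per the commit title, the payoff of breaking these functions out is that label data can be refreshed without re-crawling every article. A sketch of the faster autolabel path this enables (placeholder key and project id; assumes a prior full sync already created .sr/sr.sqlite, since sync_labels does not call create_sqlite_db):

    from sysrev.client import Client, Synchronizer

    client = Client(api_key="YOUR_SYSREV_TOKEN")  # placeholder credential
    Synchronizer().sync_labels(client, 12345)     # rewrites only the labels table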
