From 24d380f2b75fbe84e7881e3dc6b364472ceee5ab Mon Sep 17 00:00:00 2001
From: Thomas Luechtefeld
Date: Sat, 4 May 2024 14:42:09 -0400
Subject: [PATCH] break out some sync functions for faster syncing of autolabel

---
 .gitignore       |   6 +-
 sysrev/client.py | 260 ++++++++++++++++++++++++++---------------------
 2 files changed, 144 insertions(+), 122 deletions(-)

diff --git a/.gitignore b/.gitignore
index 0bbb151..e01fdb2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,4 +23,8 @@
 __pycache__/
 scratch.py
 .env
-.sr
\ No newline at end of file
+.sr
+
+# code2flow output
+out.gv
+out.png
\ No newline at end of file
diff --git a/sysrev/client.py b/sysrev/client.py
index b664485..9f4f8d1 100644
--- a/sysrev/client.py
+++ b/sysrev/client.py
@@ -26,6 +26,94 @@ def transform_label(self, label_type, label_value):
         else:
             raise ValueError("Invalid label type")
 
+
+class Client():
+
+    def __init__(self, api_key, base_url="https://www.sysrev.com"):
+        self.api_key = api_key
+        self.base_url = base_url
+
+    def sync(self, project_id):
+        Synchronizer().sync(self, project_id)
+
+    def get_project_info(self, project_id):
+        endpoint = f"{self.base_url}/api-json/project-info"
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+        response = requests.get(endpoint, headers=headers, params={"project-id": project_id})
+        return response.json()
+
+    def get_labels(self, project_id):
+        raw_labels = self.get_project_info(project_id)['result']['project']['labels']
+        labels = [{"label_id": label_id} | raw_labels[label_id] for label_id in raw_labels.keys()]
+        return labels
+
+    def set_labels(self, project_id, article_id, label_ids, label_values, label_types, confirm=False, change=False, resolve=False):
+        endpoint = f"{self.base_url}/api-json/set-labels"
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+
+        assert len(label_ids) == len(label_values) == len(label_types), "Length of label_ids, label_values, and label_types should be the same."
+
+        # construct label_values_dict
+        tf = LabelTransformer()
+        label_values_dict = {label_ids[i]: tf.transform_label(label_types[i], label_values[i]) for i in range(len(label_ids))}
+
+        # Constructing the data payload as per the server's expectation
+        data = {"project-id": project_id, "article-id": article_id, "label-values": label_values_dict}
+        data.update({ "confirm?": confirm, "change?": change, "resolve?": resolve })
+
+        # Sending a POST request to the server
+        response = requests.post(endpoint, json=data, headers=headers)
+        return response.json()
+
+    def get_project_articles(self, project_id, offset=0, limit=10, sort_by=None, sort_dir=None):
+        endpoint = f"{self.base_url}/api-json/project-articles"
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+        body = {"project-id": project_id, "n-offset": offset, "n-count": limit}
+
+        # Add optional sorting keys if provided
+        if sort_by: body["sort-by"] = sort_by
+        if sort_dir: body["sort-dir"] = sort_dir
+
+        # Make the POST request with the simplified body
+        response = requests.post(endpoint, headers=headers, json=body)
+        return response.json()
+
+    def fetch_all_articles(self, project_id, limit=10, sort_by=None, sort_dir=None):
+        offset = 0
+        while True:
+            result = self.get_project_articles(project_id, offset=offset, limit=limit, sort_by=sort_by, sort_dir=sort_dir)
+            articles = result.get('result', [])
+            if not articles:
+                break  # Stop iteration if no articles are left
+            yield from articles  # Yield each article in the current batch
+            offset += len(articles)
+
+    def get_article_info(self, project_id, article_id):
+        endpoint = f"{self.base_url}/api-json/article-info/{article_id}"
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+        body = {"project-id": project_id}
+        response = requests.get(endpoint, headers=headers, json=body)
+        return response.json()['result']
+
+    def upload_jsonlines(self, file_path, project_id):
+        url = f"{self.base_url}/api-json/import-files/{project_id}"
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+
+        # Prepare the file for upload
+        with open(file_path, 'rb') as f:
+            files = {'file': (file_path.split('/')[-1], f, 'application/octet-stream')}
+            # Let requests handle "Content-Type"
+            response = requests.post(url, headers=headers, files=files)
+
+        return response
+
+    def get_article_file(self, project_id, article_id, hash):
+        url = f"{self.base_url}/api-json/files/{project_id}/article/{article_id}/download/{hash}"
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+        # Fetch and return the raw response; callers can read the bytes from response.content
+        response = requests.get(url, headers=headers)
+        return response
+
 class Synchronizer:
 
     def create_sqlite_db(self):
@@ -93,33 +181,29 @@
         conn.commit()
         conn.close()
 
-    # TODO - this could be made more efficient by checking sqlite state and updating the sysrev api
-    def sync(self, client, project_id):
-
-        if not pathlib.Path('.sr/sr.sqlite').exists():
-            self.create_sqlite_db()
-
-        project_info = client.get_project_info(project_id)
-
-        labels = client.get_labels(project_id)
-        labels_df = pd.DataFrame(labels)
-        labels_df['definition'] = labels_df['definition'].apply(json.dumps)
-
-        n_articles = project_info['result']['project']['stats']['articles']
-        articles = [resp for resp in tqdm.tqdm(client.fetch_all_articles(project_id), total=n_articles)]
-
-        article_labels = [a['labels'] for a in articles if a['labels'] is not None]
-        article_labels = [lbl for lbls in article_labels for lbl in lbls]
-        article_label_df = pd.DataFrame(article_labels)
-        article_label_df['answer'] = article_label_df['answer'].apply(json.dumps)
-
-        article_data = [{k: v for k, v in a.items() if k != 'labels'} for a in articles]
-        article_data_df = pd.DataFrame(article_data)
-        article_data_df['notes'] = article_data_df['notes'].apply(json.dumps)
-        article_data_df['resolve'] = article_data_df['resolve'].apply(json.dumps)
-
+    def write_df(self, df, name, db_path='.sr/sr.sqlite'):
+        """
+        Writes the given DataFrame to a SQLite database.
+
+        Parameters:
+            df (pandas.DataFrame): The DataFrame to be written to the database.
+            name (str): The name of the table in which the DataFrame should be stored.
+            db_path (str): Path to the SQLite database file.
+        """
+        # Connect to the SQLite database
+        conn = sqlite3.connect(db_path)
+
+        try:
+            df.columns = df.columns.str.replace('-', '_')
+            df = df.loc[:, ~df.columns.duplicated()]
+            if not df.empty:
+                df.to_sql(name, conn, if_exists='replace', index=False)
+        finally:
+            conn.close()
+
+    def sync_article_info(self, client: Client, project_id, article_ids):
         article_info = []
-        for article_id in tqdm.tqdm(article_data_df['article-id'], total=n_articles):
+        for article_id in tqdm.tqdm(article_ids, total=len(article_ids)):
             article_info.append(client.get_article_info(project_id, article_id))
 
         full_texts = pd.DataFrame([{**ft} for a in article_info for ft in a['article'].get('full-texts', []) ])
@@ -137,106 +221,40 @@
         csl_citations['issued'] = csl_citations['issued'].apply(json.dumps)
         csl_citations['author'] = csl_citations['author'].apply(json.dumps)
 
-        # write everything to .sr/sr.sqlite
-        conn = sqlite3.connect('.sr/sr.sqlite')
-
-        def write_df(df,name):
-            # replace any - with _ in column names and remove duplicates
-            df.columns = df.columns.str.replace('-', '_')
-            df = df.loc[:,~df.columns.duplicated()]
-            df.to_sql(name, conn, if_exists='replace', index=False) if not df.empty else None
-
-
-        # Writing data to tables
-        write_df(labels_df,'labels')
-        write_df(article_label_df,'article_label')
-        write_df(article_data_df,'article_data')
-        write_df(full_texts,'full_texts')
-        write_df(auto_labels,'auto_labels')
-        write_df(csl_citations,'csl_citations')
-
-        conn.close()
-class Client():
+        self.write_df(full_texts,'full_texts')
+        self.write_df(auto_labels,'auto_labels')
+        self.write_df(csl_citations,'csl_citations')
 
-    def __init__(self, api_key, base_url="https://www.sysrev.com"):
-        self.api_key = api_key
-        self.base_url = base_url
+    def sync_labels(self, client, project_id):
+        labels = client.get_labels(project_id)
+        labels_df = pd.DataFrame(labels)
+        labels_df['definition'] = labels_df['definition'].apply(json.dumps)
+        self.write_df(labels_df,'labels')
 
-    def sync(self, project_id):
-        Synchronizer().sync(self, project_id)
-
-    def get_project_info(self, project_id):
-        endpoint = f"{self.base_url}/api-json/project-info"
-        headers = {"Authorization": f"Bearer {self.api_key}"}
-        response = requests.get(endpoint, headers=headers, params={"project-id": project_id})
-        return response.json()
-
-    def get_labels(self, project_id):
-        raw_labels = self.get_project_info(project_id)['result']['project']['labels']
-        labels = [{"label_id": label_id} | raw_labels[label_id] for label_id in raw_labels.keys()]
-        return labels
-
-    def set_labels(self, project_id, article_id, label_ids, label_values, label_types, confirm=False, change=False, resolve=False):
-        endpoint = f"{self.base_url}/api-json/set-labels"
-        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+    # TODO - this could be made more efficient by checking sqlite state and updating the sysrev api
+    def sync(self, client, project_id):
 
-        assert len(label_ids) == len(label_values) == len(label_types), "Length of label_ids, label_values, and label_types should be the same."
+        if not pathlib.Path('.sr/sr.sqlite').exists():
+            self.create_sqlite_db()
+
+        project_info = client.get_project_info(project_id)
+        self.sync_labels(client, project_id)
 
-        # construct label_values_dict
-        tf = LabelTransformer()
-        label_values_dict = {label_ids[i]: tf.transform_label(label_types[i], label_values[i]) for i in range(len(label_ids))}
+        n_articles = project_info['result']['project']['stats']['articles']
+        articles = [resp for resp in tqdm.tqdm(client.fetch_all_articles(project_id), total=n_articles)]
 
-        # Constructing the data payload as per the server's expectation
-        data = {"project-id": project_id, "article-id": article_id, "label-values": label_values_dict}
-        data.update({ "confirm?": confirm, "change?": change, "resolve?": resolve })
+        article_labels = [a['labels'] for a in articles if a['labels'] is not None]
+        article_labels = [lbl for lbls in article_labels for lbl in lbls]
+        article_label_df = pd.DataFrame(article_labels)
+        article_label_df['answer'] = article_label_df['answer'].apply(json.dumps)
 
-        # Sending a POST request to the server
-        response = requests.post(endpoint, json=data, headers=headers)
-        return response.json()
-
-    def get_project_articles(self, project_id, offset=0, limit=10, sort_by=None, sort_dir=None):
-        endpoint = f"{self.base_url}/api-json/project-articles"
-        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
-        body = {"project-id": project_id, "n-offset": offset, "n-count": limit}
+        article_data = [{k: v for k, v in a.items() if k != 'labels'} for a in articles]
+        article_data_df = pd.DataFrame(article_data)
+        article_data_df['notes'] = article_data_df['notes'].apply(json.dumps)
+        article_data_df['resolve'] = article_data_df['resolve'].apply(json.dumps)
 
-        # Add optional sorting keys if provided
-        if sort_by: body["sort-by"] = sort_by
-        if sort_dir: body["sort-dir"] = sort_dir
+        self.sync_article_info(client, project_id, article_data_df['article-id'])
 
-        # Make the POST request with the simplified body
-        response = requests.post(endpoint, headers=headers, json=body)
-        return response.json()
-
-    def fetch_all_articles(self, project_id, limit=10, sort_by=None, sort_dir=None):
-        offset = 0
-        while True:
-            result = self.get_project_articles(project_id, offset=offset, limit=limit, sort_by=sort_by, sort_dir=sort_dir)
-            articles = result.get('result', [])
-            if not articles:
-                break  # Stop iteration if no articles are left
-            yield from articles  # Yield each article in the current batch
-            offset += len(articles)
-
-    def get_article_info(self, project_id, article_id):
-        endpoint = f"{self.base_url}/api-json/article-info/{article_id}"
-        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
-        body = {"project-id": project_id,}
-        response = requests.get(endpoint, headers=headers, json=body)
-        return response.json()['result']
-
-    def upload_jsonlines(self, file_path, project_id):
-        url = f"{self.base_url}/api-json/import-files/{project_id}"
-        headers = {"Authorization": f"Bearer {self.api_key}"}
-
-        # Prepare the file for upload
-        with open(file_path, 'rb') as f:
-            files = {'file': (file_path.split('/')[-1], f, 'application/octet-stream')}
-            # Let requests handle "Content-Type"
-            response = requests.post(url, headers=headers, files=files)
-
-        return response
-
-    def get_article_file(self, project_id, article_id, hash):
-        url = f"{self.base_url}/api-json/files/{project_id}/article/{article_id}/download/{hash}"
-        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
-        
\ No newline at end of file
+        # Writing data to tables
+        self.write_df(article_label_df,'article_label')
+        self.write_df(article_data_df,'article_data')
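\ No newline at end of file

Below is a minimal usage sketch of the refactored sync API (not part of the diff). The API key, project id, and article id are placeholders, and it assumes `Client` and `Synchronizer` are importable from `sysrev.client` as defined in this patch.

    from sysrev.client import Client, Synchronizer

    client = Client(api_key="YOUR_API_KEY")  # placeholder API key

    # Full sync: creates .sr/sr.sqlite if needed and rebuilds all tables
    # (labels, article_label, article_data, full_texts, auto_labels, csl_citations).
    client.sync(project_id=100)  # placeholder project id

    # The broken-out helpers allow faster partial syncs, e.g. refreshing
    # auto-label data without re-fetching the whole project:
    sync = Synchronizer()
    sync.sync_labels(client, project_id=100)                             # labels table only
    sync.sync_article_info(client, project_id=100, article_ids=[12345])  # full_texts, auto_labels, csl_citations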