
Commit

add API loader
Signed-off-by: Yotam Perlitz <[email protected]>
perlitz committed Jan 21, 2025
1 parent acd254f commit b2e3247
Showing 1 changed file with 183 additions and 9 deletions.
192 changes: 183 additions & 9 deletions src/unitxt/loaders.py
@@ -33,6 +33,7 @@

import fnmatch
import itertools
import json
import os
import tempfile
from abc import abstractmethod
@@ -41,6 +42,7 @@
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union

import pandas as pd
import requests
from datasets import IterableDatasetDict
from datasets import load_dataset as hf_load_dataset
from huggingface_hub import HfApi
@@ -570,15 +572,15 @@ def prepare(self):

def lazy_verify(self):
super().verify()
-        assert (
-            self.endpoint_url is not None
-        ), f"Please set the {self.endpoint_url_env} environmental variable"
-        assert (
-            self.aws_access_key_id is not None
-        ), f"Please set {self.aws_access_key_id_env} environmental variable"
-        assert (
-            self.aws_secret_access_key is not None
-        ), f"Please set {self.aws_secret_access_key_env} environmental variable"
+        assert self.endpoint_url is not None, (
+            f"Please set the {self.endpoint_url_env} environment variable"
+        )
+        assert self.aws_access_key_id is not None, (
+            f"Please set {self.aws_access_key_id_env} environment variable"
+        )
+        assert self.aws_secret_access_key is not None, (
+            f"Please set {self.aws_secret_access_key_env} environment variable"
+        )
if self.streaming:
raise NotImplementedError("LoadFromKaggle cannot load with streaming.")

@@ -922,3 +924,175 @@ def load_data(self):
self._map_wildcard_path_to_full_paths()
self.path = self._download_data()
return super().load_data()

#     url: str

#     _requirements_list: List[str] = ["opendatasets"]
#     data_classification_policy = ["public"]

#     def verify(self):
#         super().verify()
#         if not os.path.isfile("kaggle.json"):
#             raise MissingKaggleCredentialsError(
#                 "Please obtain kaggle credentials https://christianjmills.com/posts/kaggle-obtain-api-key-tutorial/ and save them to local ./kaggle.json file"
#             )

#         if self.streaming:
#             raise NotImplementedError("LoadFromKaggle cannot load with streaming.")

#     def prepare(self):
#         super().prepare()
#         from opendatasets import download

#         self.downloader = download

#     def load_iterables(self):
#         with TemporaryDirectory() as temp_directory:
#             self.downloader(self.url, temp_directory)
#             return hf_load_dataset(temp_directory, streaming=False)

# class LoadFromAPI(Loader):
#     """Loads data from an API."""

#     urls: Dict[str, str]
#     chunksize: int = 100000
#     loader_limit: Optional[int] = None
#     streaming: bool = False

#     def _maybe_set_classification_policy(self):
#         self.set_default_data_classification(["proprietary"], "when loading from API")

#     def load_iterables(self):
#         self.api_key = os.getenv("SQL_API_KEY", None)
#         if not self.api_key:
#             raise ValueError(
#                 "The environment variable 'SQL_API_KEY' must be set to use the RemoteDatabaseConnector."
#             )

#         self.base_headers = {
#             "Content-Type": "application/json",
#             "accept": "application/json",
#             "Authorization": f"Bearer {self.api_key}",
#         }

#         iterables = {}
#         for split_name, url in self.urls.items():
#             response = requests.get(
#                 url,
#                 headers=self.base_headers,
#                 verify=True,
#             )

#             iterables[split_name] = pd.DataFrame(
#                 json.loads(response.text)["embeddings"]
#             )

#         return iterables


class LoadFromAPI(Loader):
"""Loads data from from API.
This loader is designed to fetch data from an API endpoint,
handling authentication through an API key. It supports
customizable chunk sizes and limits for data retrieval.
Args:
urls (Dict[str, str]):
A dictionary mapping split names to their respective API URLs.
chunksize (int, optional):
The size of data chunks to fetch in each request. Defaults to 100,000.
loader_limit (int, optional):
Limits the number of records to load. Applied per split. Defaults to None.
streaming (bool, optional):
Determines if data should be streamed. Defaults to False.
api_key_env_var (str, optional):
The name of the environment variable holding the API key.
Defaults to "SQL_API_KEY".
headers (Dict[str, Any], optional):
Additional headers to include in API requests. Defaults to None.
data_field (str, optional):
The name of the field in the API response that contains the data.
Defaults to "data".
method (str, optional):
The HTTP method to use for API requests. Defaults to "GET".
"""

urls: Dict[str, str]
chunksize: int = 100000
loader_limit: Optional[int] = None
streaming: bool = False
api_key_env_var: str = "SQL_API_KEY"
headers: Optional[Dict[str, Any]] = None
data_field: str = "data"
method: str = "GET"

# class level shared cache:
_loader_cache = LRUCache(max_size=settings.loader_cache_size)

def _maybe_set_classification_policy(self):
self.set_default_data_classification(["proprietary"], "when loading from API")

def load_iterables(self) -> Dict[str, Iterable]:
api_key = os.getenv(self.api_key_env_var, None)
if not api_key:
raise ValueError(
f"The environment variable '{self.api_key_env_var}' must be set to use the LoadFromAPI loader."
)

base_headers = {
"Content-Type": "application/json",
"accept": "application/json",
"Authorization": f"Bearer {api_key}",
}
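        # User-supplied headers take precedence: dict.update overwrites
        # any default header that self.headers also defines.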
if self.headers:
base_headers.update(self.headers)

iterables = {}
for split_name, url in self.urls.items():
if self.get_limit() is not None:
self.log_limited_loading()

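            # Only GET and POST are supported; the POST branch sends an
            # empty JSON body, so an endpoint expecting a request payload
            # would need a subclass that overrides load_iterables.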
if self.method == "GET":
response = requests.get(
url,
headers=base_headers,
verify=True,
)
elif self.method == "POST":
response = requests.post(
url,
headers=base_headers,
verify=True,
json={},
)
else:
raise ValueError(f"Method {self.method} not supported")

response.raise_for_status()

data = json.loads(response.text)
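            # The parsed payload is expected to be a JSON object whose
            # data_field key ("data" by default) holds the list of records;
            # when data_field is empty, the whole payload is used as-is.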

if self.data_field:
if self.data_field not in data:
raise ValueError(
f"Data field '{self.data_field}' not found in API response."
)
data = data[self.data_field]

if self.get_limit() is not None:
data = data[: self.get_limit()]

iterables[split_name] = data

return iterables

def process(self) -> MultiStream:
self._maybe_set_classification_policy()
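        # Splits are cached at the class level, keyed by str(self), so a
        # loader with an identical configuration reuses previously fetched
        # data instead of calling the API again.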
iterables = self.__class__._loader_cache.get(str(self), None)
if iterables is None:
iterables = self.load_iterables()
self.__class__._loader_cache.max_size = settings.loader_cache_size
self.__class__._loader_cache[str(self)] = iterables
return MultiStream.from_iterables(iterables, copying=True)
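
A minimal usage sketch of the new loader (the endpoint URL, the environment
variable name, and the response payload shape below are illustrative
assumptions, not part of this commit):

import os

from unitxt.loaders import LoadFromAPI

# Hypothetical endpoint returning {"data": [{"question": ..., "answer": ...}, ...]}
os.environ["MY_API_KEY"] = "my-secret-token"  # assumed variable name

loader = LoadFromAPI(
    urls={"test": "https://example.com/api/qa"},
    api_key_env_var="MY_API_KEY",
    data_field="data",
    method="GET",
    loader_limit=100,  # fetch at most 100 records per split
)
multi_stream = loader.process()  # MultiStream with a "test" split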
