forked from run-llama/llama-hub
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support ArangoDB simple loader (run-llama#900)
* Support arango db simple loader * Fix issues, add tests. And update requirements
- Loading branch information
1 parent
fb4d2fb
commit 7639084
Showing
7 changed files
with
224 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# ArangoDB Loader | ||
|
||
This loader loads documents from ArangoDB. The user specifies an ArangoDB instance to | ||
initialize the reader. They then specify the collection name and query params to | ||
fetch the relevant docs. | ||
|
||
## Usage | ||
|
||
Here's an example usage of the SimpleArangoDBReader. | ||
|
||
```python | ||
from llama_index import download_loader | ||
import os | ||
|
||
SimpleArangoDBReader = download_loader('SimpleArangoDBReader') | ||
|
||
host = "<host>" | ||
username = "<username>" | ||
password = "<password>" | ||
db_name = "<db_name>" | ||
collection_name = "<collection_name>" | ||
# query_dict is passed into db.collection.find() | ||
query_dict = {} | ||
# Attributes of interest to load, by default ["text"] | ||
field_names = ["title", "description"] | ||
reader = SimpleArangoDBReader(host) # or pass ArangoClient | ||
documents = reader.load_data( | ||
username, | ||
password, | ||
db_name, | ||
collection_name, | ||
query_dict=query_dict, | ||
field_names=field_names, | ||
) | ||
``` | ||
|
||
This loader is designed to be used as a way to load data into [LlamaIndex](https://github.com/run-llama/llama_index/tree/main/llama_index) and/or subsequently used as a Tool in a [LangChain](https://github.com/hwchase17/langchain) Agent. See [here](https://github.com/run-llama/llama-hub/tree/main/llama_hub) for examples. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from llama_hub.arango_db.base import ( | ||
SimpleArangoDBReader, | ||
) | ||
|
||
__all__ = ["SimpleArangoDBReader"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
"""ArangoDB client.""" | ||
|
||
from typing import Any, Dict, Iterator, List, Optional, Union, cast | ||
|
||
from llama_index.readers.base import BaseReader | ||
from llama_index.readers.schema.base import Document | ||
|
||
|
||
class SimpleArangoDBReader(BaseReader):
    """Simple ArangoDB reader.

    Concatenates selected fields of each ArangoDB document into a Document
    used by LlamaIndex.

    Args:
        host: (Union[str, List[str]]) list of urls or url for connecting to the db.
            Defaults to "http://127.0.0.1:8529" when omitted.
        client: (Any) ArangoDB client. When given, it is used instead of
            creating a new client from ``host``.
    """

    def __init__(
        self, host: Optional[Union[str, List[str]]] = None, client: Optional[Any] = None
    ) -> None:
        """Initialize with parameters.

        Raises:
            ImportError: if the ``arango`` package cannot be imported.
        """
        # Imported lazily so the module can be imported without python-arango;
        # the dependency is only required once a reader is instantiated.
        try:
            from arango import ArangoClient
        except ImportError as err:
            raise ImportError(
                "`arango` package not found, please run `pip install python-arango`"
            ) from err

        host = host or "http://127.0.0.1:8529"
        self.client = cast(ArangoClient, client or ArangoClient(hosts=host))

    def _flatten(self, texts: List[Any]) -> List[Any]:
        """Flatten one level of nesting: list items are spliced into the result."""
        result: List[Any] = []
        for text in texts:
            result += text if isinstance(text, list) else [text]
        return result

    @staticmethod
    def _pick_fields(item: Dict[str, Any], names: List[str]) -> List[Any]:
        """Return ``item[name]`` for every name, raising ValueError if one is missing."""
        try:
            return [item[name] for name in names]
        except KeyError as err:
            raise ValueError(
                f"{err.args[0]} field not found in arangodb document."
            ) from err

    def lazy_load(
        self,
        username: str,
        password: str,
        db_name: str,
        collection_name: str,
        field_names: Optional[List[str]] = None,
        separator: str = " ",
        query_dict: Optional[Dict] = None,
        max_docs: Optional[int] = None,
        metadata_names: Optional[List[str]] = None,
    ) -> Iterator[Document]:
        """Lazy load data from ArangoDB.

        Args:
            username (str): for credentials.
            password (str): for credentials.
            db_name (str): name of the database.
            collection_name (str): name of the collection.
            field_names (Optional[List[str]]): names of the fields to be
                concatenated. List-valued fields contribute each of their
                elements. Defaults to ["text"]
            separator (str): separator to be used between fields.
                Defaults to " "
            query_dict (Optional[Dict]): query to filter documents. Read more
                at [docs](https://docs.python-arango.com/en/main/specs.html#arango.collection.StandardCollection.find)
                Defaults to empty dict (None is treated as an empty dict)
            max_docs (Optional[int]): maximum number of documents to load.
                Defaults to None (no limit)
            metadata_names (Optional[List[str]]): names of the fields to be added
                to the metadata attribute of the Document. Defaults to None

        Yields:
            Document: one document per item returned by the query.

        Raises:
            ValueError: if a requested field is missing from a fetched item.
        """
        # Resolve the None sentinels here instead of using mutable defaults.
        field_names = ["text"] if field_names is None else field_names
        filters = {} if query_dict is None else query_dict

        db = self.client.db(name=db_name, username=username, password=password)
        collection = db.collection(collection_name)
        cursor = collection.find(filters=filters, limit=max_docs)
        for item in cursor:
            # Flatten BEFORE stringifying so that list-valued fields are joined
            # element by element rather than rendered as a list repr.
            values = self._flatten(self._pick_fields(item, field_names))
            text = separator.join(str(value) for value in values)

            if metadata_names is None:
                yield Document(text=text)
            else:
                metadata = dict(
                    zip(metadata_names, self._pick_fields(item, metadata_names))
                )
                yield Document(text=text, metadata=metadata)

    def load_data(
        self,
        username: str,
        password: str,
        db_name: str,
        collection_name: str,
        field_names: Optional[List[str]] = None,
        separator: str = " ",
        query_dict: Optional[Dict] = None,
        max_docs: Optional[int] = None,
        metadata_names: Optional[List[str]] = None,
    ) -> List[Document]:
        """Load data from the ArangoDB.

        Eagerly collects everything produced by ``lazy_load`` into a list.

        Args:
            username (str): for credentials.
            password (str): for credentials.
            db_name (str): name of the database.
            collection_name (str): name of the collection.
            field_names (Optional[List[str]]): names of the fields to be
                concatenated. Defaults to ["text"]
            separator (str): separator to be used between fields.
                Defaults to " "
            query_dict (Optional[Dict]): query to filter documents. Read more
                at [docs](https://docs.python-arango.com/en/main/specs.html#arango.collection.StandardCollection.find)
                Defaults to empty dict
            max_docs (Optional[int]): maximum number of documents to load.
                Defaults to None (no limit)
            metadata_names (Optional[List[str]]): names of the fields to be added
                to the metadata attribute of the Document. Defaults to None

        Returns:
            List[Document]: A list of documents.

        Raises:
            ValueError: if a requested field is missing from a fetched item.
        """
        return list(
            self.lazy_load(
                username,
                password,
                db_name,
                collection_name,
                field_names,
                separator,
                query_dict,
                max_docs,
                metadata_names,
            )
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
python-arango |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ olefile | |
chromadb | ||
snowflake-sqlalchemy | ||
selenium | ||
python-arango | ||
|
||
# hotfix | ||
psutil | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from unittest.mock import MagicMock, patch | ||
import pytest | ||
|
||
from llama_hub.arango_db import SimpleArangoDBReader | ||
|
||
|
||
@pytest.fixture
def mock_arangodb_client():
    """Yield a patched ``arango.ArangoClient`` whose db/collection return canned students."""
    student_rows = [
        {"_key": "1", "name": "Alice", "age": 20},
        {"_key": "2", "name": "Bob", "age": 21},
        {"_key": "3", "name": "Mark", "age": 20},
    ]

    # Build the chain from the innermost call outwards:
    # client.db(...) -> db.collection(...) -> collection.find(...) -> rows
    collection_mock = MagicMock()
    collection_mock.find.return_value = student_rows

    db_mock = MagicMock()
    db_mock.collection.return_value = collection_mock

    with patch("arango.ArangoClient") as client_mock:
        client_mock.db.return_value = db_mock
        yield client_mock
|
||
|
||
def test_load_students(mock_arangodb_client):
    """Each mocked student becomes one Document with "name age" as its text."""
    load_kwargs = {
        "username": "usr",
        "password": "pass",
        "db_name": "school",
        "collection_name": "students",
        "field_names": ["name", "age"],
    }
    reader = SimpleArangoDBReader(client=mock_arangodb_client)

    documents = reader.load_data(**load_kwargs)

    # Comparing the full list checks both the count and each text in order.
    loaded_texts = [document.text for document in documents]
    assert loaded_texts == ["Alice 20", "Bob 21", "Mark 20"]