-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsetup.py
112 lines (85 loc) · 3.35 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Import API keys (kept in a .gitignore'd local file so secrets stay out of version control)
from keys import weaviate_key, weaviate_url, openai_key, spotify_id, spotify_secret
# Weaviate packages (v4-style client: collections API + wvc config helpers)
import weaviate
import weaviate.classes as wvc
# Regular web-based packages
import requests
import json
import urllib.parse
import urllib
# Pandas for reading CSV
import pandas as pd
# TQDM for progress bar on data upload
from tqdm.auto import tqdm
# tiktoken for accurate count of openAI tokenisation (counting tokens)
import tiktoken
# Shared tokenizer used by limit_tokens; cl100k_base is presumably the scheme used by text2vec — TODO confirm
enc = tiktoken.get_encoding("cl100k_base") # I think this is the encoding scheme for text2vec
def create_collection(client, collection_name="lyrics"):
    """
    Ensure a collection exists on the Weaviate instance.

    Creates the collection with the OpenAI text2vec vectoriser and OpenAI
    generative module if it is missing; otherwise leaves it untouched.

    Args:
        client: a connected Weaviate client.
        collection_name: name of the collection to create or reuse.
    """
    # Guard clause: nothing to do when the collection is already present.
    if client.collections.exists(collection_name):
        print(f"Collection {collection_name} already exists, continuing.")
        return
    client.collections.create(
        name=collection_name,
        vectorizer_config=wvc.config.Configure.Vectorizer.text2vec_openai(),
        generative_config=wvc.config.Configure.Generative.openai()
    )
    print(f"Collection {collection_name} created.")
def read_data(fname = "all_lyrics.csv"):
    """
    Load the locally stored song-lyrics CSV into a pandas DataFrame.

    Default dataset is from: https://www.kaggle.com/datasets/elizzyliu/song-lyrics

    Args:
        fname: path to the CSV file on disk.

    Returns:
        pandas.DataFrame with one row per song.
    """
    return pd.read_csv(fname)
def limit_tokens(x, max_tokens):
    """
    Cap a string at max_tokens tokens under the module-level cl100k_base encoder.

    Non-string values (NaN, numbers, None) are passed through unchanged.
    Strings are encoded, sliced to at most max_tokens tokens, and decoded
    back to text; slicing past the end is a no-op, so short strings survive
    an encode/decode round-trip intact.

    Args:
        x: candidate value from a DataFrame cell.
        max_tokens: maximum number of tokens to keep.

    Returns:
        The (possibly truncated) string, or x unchanged when not a string.
    """
    if not isinstance(x, str):
        return x
    # Slicing handles both the "already short enough" and "too long" cases.
    return enc.decode(enc.encode(x)[:max_tokens])
def add_data_to_client(client, collection_name, df, columns = None, max_tokens = 8192, chunk_size = 1000):
    """
    Add rows of a pandas DataFrame to a Weaviate collection so they can be vectorised.

    Each row becomes one object whose properties are the selected columns,
    with string values truncated to max_tokens tokens via limit_tokens.
    Objects are uploaded with insert_many in batches of chunk_size.

    Args:
        client: a connected Weaviate client.
        collection_name: name of the target collection.
        df: DataFrame whose rows are uploaded.
        columns: subset of df columns to upload; None or empty means all
            columns (None default avoids the mutable-default pitfall while
            staying backward-compatible with callers passing []).
        max_tokens: per-field token cap passed to limit_tokens.
        chunk_size: number of objects per insert_many batch.

    Raises:
        ValueError: if any requested column is missing from df.
    """
    if not columns:
        columns = list(df.columns)
    elif any(col not in df.columns for col in columns):
        raise ValueError("Columns not found in dataframe.")
    # Build the list of property dicts, one per row, with per-field truncation.
    print("Adding data to client...")
    data = []
    for i in tqdm(range(len(df))):
        row = df.iloc[i]
        data.append({c: limit_tokens(row[c], max_tokens) for c in columns})
    collection = client.collections.get(collection_name)
    # Upload in chunks. BUG FIX: the loop previously started at 3000 (a
    # leftover resume offset from an interrupted run), silently skipping the
    # first 3000 rows; it now starts at 0 and covers every row.
    for start in range(0, len(data), chunk_size):
        end = min(start + chunk_size, len(data))
        print(f"Adding data from {start} to {end}...")
        collection.data.insert_many(data[start:end])
if __name__ == "__main__":
    # Connect to the hosted Weaviate cluster; the OpenAI key is forwarded in
    # a header so the cluster can call the vectoriser on our behalf.
    client = weaviate.connect_to_weaviate_cloud(
        cluster_url = weaviate_url,
        auth_credentials=weaviate.auth.AuthApiKey(weaviate_key),
        headers={
            "X-OpenAI-Api-Key": openai_key # Replace with your inference API key
        }
    )
    # Ensure the collection exists, read the lyrics CSV, then upload the rows
    # for vectorisation.
    create_collection(client, collection_name="lyrics")
    lyrics_df = read_data(fname = "all_lyrics.csv")
    # max_tokens kept below the 8192 model limit for safety in case
    # limit_tokens is not perfect — even with tiktoken it doesn't seem to
    # count exactly.
    add_data_to_client(client, "lyrics", lyrics_df, max_tokens = 7000)