Skip to content

Commit

Permalink
linting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Jul 2, 2024
1 parent c574d39 commit 44a1810
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 273 deletions.
102 changes: 31 additions & 71 deletions elm/db_wiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,11 @@
ELM energy wizard
"""
import os
import copy
import numpy as np
import json
import psycopg2
from datetime import date, datetime
import pandas as pd

from elm.base import ApiBase
from elm.wizard import EnergyWizard


class DataBaseWizard(ApiBase):
"""Interface to ask OpenAI LLMs about energy research."""
Expand All @@ -21,7 +17,7 @@ class DataBaseWizard(ApiBase):
"output based on user queries.")
"""High level model role, somewhat redundant to MODEL_INSTRUCTION"""

def __init__(self, connection_string, model=None, token_budget=3500, ref_col=None):
def __init__(self, connection_string, model=None, token_budget=3500):
"""
Parameters
----------
Expand All @@ -34,12 +30,15 @@ def __init__(self, connection_string, model=None, token_budget=3500, ref_col=Non
Number of tokens that can be embedded in the prompt. Note that the
default budget for GPT-3.5-Turbo is 4096, but you want to subtract
some tokens to account for the response budget.
ref_col : None | str
Optional column label in the corpus that provides a reference text
string for each chunk of text.
"""

super().__init__(model)

self.query = None
self.sql = None
self.df = None
self.py = None

self.connection_string = connection_string
self.connection = psycopg2.connect(self.connection_string)
self.token_budget = token_budget
Expand All @@ -49,7 +48,6 @@ def __init__(self, connection_string, model=None, token_budget=3500, ref_col=Non
if os.path.exists(fpcache):
with open(fpcache, 'r') as f:
self.database_describe = f.read()

else:
print('Error no expert database file')

Expand All @@ -62,7 +60,8 @@ def get_sql_for(self, query):
'user question: "{}"\n\n'
'Return all columns from the database. '
'All the tables are in the schema "loads"'
'Please only return the SQL query with no commentary or preface.'
'Please only return the SQL query with '
'no commentary or preface.'
.format(self.database_describe, query))
out = super().chat(e_query, temperature=0)
print(out)
Expand All @@ -74,8 +73,8 @@ def run_sql(self, sql):
response."""
query = sql
print(query)
# Move Connection or cursor to init and test so that you aren't re-intializing
# it with each instance.
# TODO: test that the connection created in __init__ is reused here so
# that it is not re-initialized on every call.
with self.connection.cursor() as cursor:
cursor.execute(query)
data = cursor.fetchall()
Expand All @@ -84,93 +83,54 @@ def run_sql(self, sql):
return df

def get_py_code(self, query, df):
    """Ask the LLM for matplotlib code that plots the SQL query result.

    Parameters
    ----------
    query : str
        Original user question; embedded verbatim in the prompt.
    df : pandas.DataFrame
        Dataframe built from the SQL output (see ``run_sql``). Its
        ``head()``, ``tail()``, ``describe()`` and ``dtypes`` are embedded
        in the prompt so the model knows the shape of the data.

    Returns
    -------
    py : str
        Text found after the first "```python" fence in the LLM response,
        up to the closing "```" fence when there is one. If the response
        contains no "```python" fence the full response is returned
        unchanged.
    """
    e_query = ('Great it worked! I have made a dataframe from the output '
               'of the SQL query you gave me. '
               'Here is the dataframe head: \n{}\n'
               'Here is the dataframe tail: \n{}\n'
               'Here is the dataframe description: \n{}\n'
               'Here is the dataframe datatypes: \n{}\n'
               'Now please write python code using matplotlib to plot '
               'the data in the dataframe based on '
               'the original user query: "{}"'
               .format(df.head(), df.tail(), df.describe(),
                       df.dtypes, query))
    out = super().chat(e_query, temperature=0)

    # Pull the python code out of the markdown-fenced response.
    # str.find() returns -1 when the marker is missing, so both lookups are
    # checked explicitly; slicing with an unchecked -1 would silently drop
    # characters from the response.
    fence = '```python'
    start = out.find(fence)
    if start == -1:
        # No fenced block: hand back the whole response untouched.
        return out

    py = out[start + len(fence):]
    end = py.find('```')
    if end != -1:
        py = py[:end]
    return py

def run_py_code(self, py, df):
    """Run the python code with ``exec`` to plot the queried data

    Caution using this method! Exec can be dangerous. Need to do more work
    to make this safe.

    Parameters
    ----------
    py : str
        Python source to execute (e.g. the output of ``get_py_code``).
    df : pandas.DataFrame
        Query result; exposed to the executed code under the name ``df``.

    Returns
    -------
    plt : object | None
        Whatever the executed code bound to the name ``plt`` (typically
        ``matplotlib.pyplot`` when the code imports it under that alias).
        None if the code raised or never defined ``plt``.
    """
    # Run in one explicit namespace: a copy of the module globals plus
    # ``df``. A single dict acts as both globals and locals for the
    # generated code, so helper functions it defines can see its own
    # top-level names, and ``plt`` can be read back afterwards. The copy
    # keeps the generated code from overwriting this module's globals.
    namespace = dict(globals())
    namespace['df'] = df

    try:
        # SECURITY: ``py`` is LLM-generated and runs with full interpreter
        # privileges. Do not call this on untrusted input without a
        # sandbox.
        # pylint: disable=exec-used
        exec(py, namespace)
    except Exception as err:
        # Best effort: show the offending code and the reason it failed
        # instead of crashing the chat session.
        print(py)
        print('Error running generated code: {}'.format(err))
        return None

    return namespace.get('plt')

def chat(self, query, **_kwargs):
    """Answer a query by generating SQL, running it against the database,
    and plotting the result with LLM-generated python code.

    The intermediate products are stored on the instance as ``query``,
    ``sql``, ``df`` and ``py``.

    Parameters
    ----------
    query : str
        Question being asked of the database wizard.
    **_kwargs
        Ignored. Accepted so that callers that pass EnergyWizard-style
        options (``debug``, ``stream``, ``temperature``, ``convo``,
        ``token_budget``, ``print_references``, ``return_chat_obj``) do not
        raise a TypeError.

    Returns
    -------
    out : object | None
        Whatever ``run_py_code`` returns for the generated plotting code.
    """
    self.query = query
    self.sql = self.get_sql_for(query)
    self.df = self.run_sql(self.sql)
    self.py = self.get_py_code(query=query, df=self.df)
    # Hand the result back so callers can render it
    # (e.g. ``st.pyplot(fig=out)``).
    return self.run_py_code(self.py, self.df)
110 changes: 39 additions & 71 deletions elm/experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from glob import glob
import pandas as pd
import sys
import copy
import numpy as np


from elm.base import ApiBase
Expand Down Expand Up @@ -43,16 +41,17 @@
'answer the question, say "I do not know."')
EnergyWizard.MODEL_INSTRUCTION = EnergyWizard.MODEL_ROLE

DataBaseWizard.URL = (f'https://stratus-embeddings-south-central.openai.azure.com/'
f'openai/deployments/{model}/chat/'
f'completions?api-version={openai.api_version}')
DataBaseWizard.URL = (
f'https://stratus-embeddings-south-central.openai.azure.com/'
f'openai/deployments/{model}/chat/'
f'completions?api-version={openai.api_version}')
DataBaseWizard.HEADERS = {"Content-Type": "application/json",
"Authorization": f"Bearer {openai.api_key}",
"api-key": f"{openai.api_key}",
}
"Authorization": f"Bearer {openai.api_key}",
"api-key": f"{openai.api_key}"}

st.set_option('deprecation.showPyplotGlobalUse', False)


@st.cache_data
def get_corpus():
"""Get the corpus of text data with embeddings."""
Expand All @@ -64,22 +63,21 @@ def get_corpus():


@st.cache_resource
def get_wizard(model = model):
def get_wizard(model=model):
"""Get the energy wizard object.
Parameters
----------
model : str
State which model to use for the energy wizard.
Returns
-------
response : str
GPT output / answer.
wizard : EnergyWizard
Returns the energy wizard object for use in chat responses.
"""

Parameters
----------
model : str
State which model to use for the energy wizard.
Returns
-------
wizard : EnergyWizard
Returns the energy wizard object for use in chat responses.
"""

# Getting Corpus of data. If no corpus throw error for user.
try:
Expand All @@ -93,8 +91,9 @@ def get_wizard(model = model):
wizard = EnergyWizard(corpus, ref_col='ref', model=model)
return wizard


class MixtureOfExperts(ApiBase):
"""Interface to ask OpenAI LLMs about energy
"""Interface to ask OpenAI LLMs about energy
research either from a database or report."""

"""Parameters
Expand All @@ -118,63 +117,33 @@ class MixtureOfExperts(ApiBase):
"a database and creating a figure.")
"""High level model role, somewhat redundant to MODEL_INSTRUCTION"""

def __init__(self, db_wiz, txt_wiz, model=None):
    """
    Parameters
    ----------
    db_wiz : DataBaseWizard
        Wizard used for database questions; stored as ``wizard_db`` and
        queried through its ``chat`` method.
    txt_wiz : EnergyWizard
        Wizard used for report/text questions; stored as ``wizard_chat``
        and queried through its ``chat`` method.
    model : None | str
        Model name forwarded to the ``ApiBase`` initializer.
    """
    # The wizards are injected rather than built here, so this class does
    # not need a database connection string or a text corpus of its own.
    self.wizard_db = db_wiz
    self.wizard_chat = txt_wiz
    super().__init__(model)

def chat(self, query,
debug=True,
stream=True,
temperature=0,
convo=False,
token_budget=None,
new_info_threshold=0.7,
print_references=False,
return_chat_obj=False):
temperature=0):
"""Answers a query by doing a semantic search of relevant text with
embeddings and then sending engineered query to the LLM.
Parameters
----------
query : str
Question being asked of EnergyWizard
debug : bool
Flag to return extra diagnostics on the engineered question.
stream : bool
Flag to print subsequent chunks of the response in a streaming
fashion
temperature : float
GPT model temperature, a measure of response entropy from 0 to 1. 0
is more reliable and nearly deterministic; 1 will give the model
more creative freedom and may not return as factual of results.
convo : bool
Flag to perform semantic search with full conversation history
(True) or just the single query (False). Call EnergyWizard.clear()
to reset the chat history.
token_budget : int
Option to override the class init token budget.
new_info_threshold : float
New text added to the engineered query must contain at least this
much new information. This helps prevent (for example) the table of
contents being added multiple times.
print_references : bool
Flag to print references if EnergyWizard is initialized with a
valid ref_col.
return_chat_obj : bool
Flag to only return the ChatCompletion from OpenAI API.
Returns
-------
response : str
GPT output / answer.
query : str
If debug is True, the engineered query asked of GPT will also be
returned here
references : list
If debug is True, the list of references (strs) used in the
engineered prompt is returned here
"""

messages = [{"role": "system", "content": self.MODEL_ROLE},
Expand All @@ -198,31 +167,30 @@ def chat(self, query,
else:
response_message = response["choices"][0]["message"]["content"]


message_placeholder = st.empty()
full_response = ""

if '1' in response_message:
out = self.wizard_chat.chat(query,
debug=True, stream=True, token_budget=6000,
temperature=0.0, print_references=True,
convo=False, return_chat_obj=True)

debug=True, stream=True,
token_budget=6000, temperature=0.0,
print_references=True, convo=False,
return_chat_obj=True)

for response in out[0]:
full_response += response.choices[0].delta.content or ""
message_placeholder.markdown(full_response + "▌")


elif '2' in response_message:
elif '2' in response_message:
out = self.wizard_db.chat(query,
debug=True, stream=True, token_budget=6000,
temperature=0.0, print_references=True,
convo=False, return_chat_obj=True)
debug=True, stream=True,
token_budget=6000, temperature=0.0,
print_references=True, convo=False,
return_chat_obj=True)

st.pyplot(fig = out, clear_figure = False)
st.pyplot(fig=out, clear_figure=False)

else:
else:
response_message = 'Error cannot find data in report or database.'


return full_response
return full_response
Loading

0 comments on commit 44a1810

Please sign in to comment.