Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Db wiz #21

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions elm/db_wiz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# -*- coding: utf-8 -*-
"""
ELM energy wizard
"""
import os
import psycopg2
import pandas as pd

from elm.base import ApiBase


class DataBaseWizard(ApiBase):
    """Interface to ask OpenAI LLMs about energy research.

    Turns a natural-language user query into a SQL query, runs it against
    a postgres database, and then asks the LLM for matplotlib code to plot
    the resulting dataframe.
    """

    MODEL_ROLE = ("You are a data engineer pulling data from a relational "
                  "database using SQL and writing python code to plot the "
                  "output based on user queries.")
    """High level model role, somewhat redundant to MODEL_INSTRUCTION"""

    def __init__(self, connection_string, model=None, token_budget=3500):
        """
        Parameters
        ----------
        connection_string : str
            psycopg2 connection string used to open the database
            connection, e.g. "host=... dbname=... user=... password=..."
        model : str
            GPT model name, default is the DEFAULT_MODEL global var
        token_budget : int
            Number of tokens that can be embedded in the prompt. Note that the
            default budget for GPT-3.5-Turbo is 4096, but you want to subtract
            some tokens to account for the response budget.
        """

        super().__init__(model)

        # Most recent user query / generated SQL / query result / generated
        # python code. These are populated by chat().
        self.query = None
        self.sql = None
        self.df = None
        self.py = None

        self.connection_string = connection_string
        self.connection = psycopg2.connect(self.connection_string)
        self.token_budget = token_budget

        # Plain-text description of the database schema that is prepended to
        # every SQL-generation prompt.
        fpcache = './database_manual.txt'

        if os.path.exists(fpcache):
            with open(fpcache, 'r') as f:
                self.database_describe = f.read()
        else:
            print('Error no expert database file')
            # Fall back to an empty schema description instead of leaving the
            # attribute unset, which previously raised AttributeError on the
            # first call to get_sql_for().
            self.database_describe = ''

    # Getting sql from a generic query
    def get_sql_for(self, query):
        """Take the raw user query and ask the LLM for a SQL query that will
        get data to support a response

        Parameters
        ----------
        query : str
            Natural-language question from the user.

        Returns
        -------
        str
            SQL query text returned by the LLM.
        """
        # NOTE: fixed a missing separator between the "loads" sentence and
        # the following sentence (previously rendered as `"loads"Please`).
        e_query = ('{}\n\nPlease create a SQL query that will answer this '
                   'user question: "{}"\n\n'
                   'Return all columns from the database. '
                   'All the tables are in the schema "loads". '
                   'Please only return the SQL query with '
                   'no commentary or preface.'
                   .format(self.database_describe, query))
        out = super().chat(e_query, temperature=0)
        print(out)
        return out

    def run_sql(self, sql):
        """Takes a SQL query that can support a user prompt, runs SQL query
        based on the db connection (self.connection), returns dataframe
        response.

        Parameters
        ----------
        sql : str
            SQL query to execute.

        Returns
        -------
        pd.DataFrame
            Query results, with column names taken from the cursor
            description.
        """
        print(sql)
        # NOTE(review): the connection is opened once in __init__ and reused
        # here; consider testing/reconnecting on failure so a dropped
        # connection does not break every subsequent query.
        with self.connection.cursor() as cursor:
            cursor.execute(sql)
            data = cursor.fetchall()
            column_names = [desc[0] for desc in cursor.description]
        return pd.DataFrame(data, columns=column_names)

    def get_py_code(self, query, df):
        """Get python code to respond to query

        Parameters
        ----------
        query : str
            Original natural-language user question.
        df : pd.DataFrame
            Dataframe produced by run_sql().

        Returns
        -------
        str
            Python (matplotlib) source code extracted from the LLM response.
        """
        e_query = ('Great it worked! I have made a dataframe from the output '
                   'of the SQL query you gave me. '
                   'Here is the dataframe head: \n{}\n'
                   'Here is the dataframe tail: \n{}\n'
                   'Here is the dataframe description: \n{}\n'
                   'Here is the dataframe datatypes: \n{}\n'
                   'Now please write python code using matplotlib to plot '
                   'the data in the dataframe based on '
                   'the original user query: "{}"'
                   .format(df.head(), df.tail(), df.describe(),
                           df.dtypes, query))
        out = super().chat(e_query, temperature=0)

        # Extract code from a markdown ```python ... ``` fence. The previous
        # implementation assumed the fence was always present: when missing,
        # str.find() returned -1 and the `+ 9` slice silently chopped the
        # first 8 characters off the response. Fall back to the raw response
        # when no fence is found.
        start = out.find('```python')
        if start >= 0:
            out = out[start + len('```python'):]
            end = out.find('```')
            if end >= 0:
                out = out[:end]
        return out

    # pylint: disable=unused-argument
    def run_py_code(self, py, df):
        """Run the python code with ``exec`` to plot the queried data

        Caution using this method! Exec can be dangerous. Need to do more work
        to make this safe.

        Parameters
        ----------
        py : str
            Python plotting code from get_py_code(). SECURITY: this is
            LLM-generated code executed verbatim via ``exec`` - do not use
            with untrusted model output.
        df : pd.DataFrame
            Dataframe being plotted (referenced by name in the generated
            code).
        """
        try:
            # pylint: disable=exec-used
            exec(py)
        except Exception:
            # Best effort: show the traceback and the failing code instead
            # of crashing the chat session (previously only the code was
            # printed, hiding the actual error).
            import traceback
            traceback.print_exc()
            print(py)

    def chat(self, query):
        """Answer a user query by generating SQL, running it against the
        database, and executing LLM-generated matplotlib code on the result.

        Parameters
        ----------
        query : str
            Question being asked of EnergyWizard
        """

        self.query = query
        self.sql = self.get_sql_for(query)
        self.df = self.run_sql(self.sql)
        self.py = self.get_py_code(query=query, df=self.df)
        self.run_py_code(self.py, self.df)
183 changes: 183 additions & 0 deletions elm/experts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
"""
ELM mixture of experts
"""
import streamlit as st
import os
import openai
from glob import glob
import pandas as pd
import sys


from elm.base import ApiBase
from elm.wizard import EnergyWizard
from elm.db_wiz import DataBaseWizard

# Chat-completion model deployment name used for both wizards below.
model = 'gpt-4'

# NREL-Azure endpoint. You can also use just the openai endpoint.
# NOTE: embedding values are different between OpenAI and Azure models!
openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
openai.api_key = os.getenv("AZURE_OPENAI_KEY")
openai.api_type = 'azure'
openai.api_version = os.getenv('AZURE_OPENAI_VERSION')

# Point the text wizard at the Azure embedding + chat deployments.
# NOTE(review): api_version comes from the environment; if the
# AZURE_OPENAI_VERSION variable is unset these URLs will contain "None".
EnergyWizard.EMBEDDING_MODEL = 'text-embedding-ada-002-2'
EnergyWizard.EMBEDDING_URL = ('https://stratus-embeddings-south-central.'
                              'openai.azure.com/openai/deployments/'
                              'text-embedding-ada-002-2/embeddings?'
                              f'api-version={openai.api_version}')
EnergyWizard.URL = ('https://stratus-embeddings-south-central.'
                    'openai.azure.com/openai/deployments/'
                    f'{model}/chat/completions?'
                    f'api-version={openai.api_version}')
# Both auth header styles are sent; Azure uses "api-key", OpenAI uses
# the "Authorization: Bearer" header.
EnergyWizard.HEADERS = {"Content-Type": "application/json",
                        "Authorization": f"Bearer {openai.api_key}",
                        "api-key": f"{openai.api_key}"}

EnergyWizard.MODEL_ROLE = ('You are a energy research assistant. Use the '
                           'articles below to answer the question. If '
                           'articles do not provide enough information to '
                           'answer the question, say "I do not know."')
EnergyWizard.MODEL_INSTRUCTION = EnergyWizard.MODEL_ROLE

# Point the database wizard at the same Azure chat deployment.
DataBaseWizard.URL = (
    f'https://stratus-embeddings-south-central.openai.azure.com/'
    f'openai/deployments/{model}/chat/'
    f'completions?api-version={openai.api_version}')
DataBaseWizard.HEADERS = {"Content-Type": "application/json",
                          "Authorization": f"Bearer {openai.api_key}",
                          "api-key": f"{openai.api_key}"}

# The db wizard plots via the global pyplot figure; silence streamlit's
# deprecation warning about st.pyplot() without an explicit figure.
st.set_option('deprecation.showPyplotGlobalUse', False)


@st.cache_data
def get_corpus():
    """Get the corpus of text data with embeddings."""
    frames = [pd.read_json(fp) for fp in sorted(glob('./embed/*.json'))]
    return pd.concat(frames, ignore_index=True)


@st.cache_resource
def get_wizard(model=model):
    """Get the energy wizard object.

    Parameters
    ----------
    model : str
        State which model to use for the energy wizard.

    Returns
    -------
    wizard : EnergyWizard
        Returns the energy wizard object for use in chat responses.
    """

    # Getting Corpus of data. If no corpus throw error for user.
    try:
        corpus = get_corpus()
    except Exception:
        print("Error: Have you run 'retrieve_docs.py'?")
        st.header("Error")
        st.write("Error: Have you run 'retrieve_docs.py'?")
        # Exit non-zero: this is a failure path, and exiting with status 0
        # (as before) signalled a clean shutdown to the caller.
        sys.exit(1)

    wizard = EnergyWizard(corpus, ref_col='ref', model=model)
    return wizard


class MixtureOfExperts(ApiBase):
    """Interface to ask OpenAI LLMs about energy
    research either from a database or report."""

    MODEL_ROLE = ("You are an expert given a query. Which of the "
                  "following best describes the query? Please "
                  "answer with just the number and nothing else."
                  "1. This is a query best answered by a text-based report."
                  "2. This is a query best answered by pulling data from "
                  "a database and creating a figure.")
    """High level model role, somewhat redundant to MODEL_INSTRUCTION"""

    def __init__(self, db_wiz, txt_wiz, model=None):
        """
        Parameters
        ----------
        db_wiz : DataBaseWizard
            Wizard used for queries answered from the database (routes to a
            figure).
        txt_wiz : EnergyWizard
            Wizard used for queries answered from text reports.
        model : str
            GPT model name, default is the DEFAULT_MODEL global var
        """
        self.wizard_db = db_wiz
        self.wizard_chat = txt_wiz
        super().__init__(model)

    def chat(self, query,
             stream=True,
             temperature=0):
        """Answers a query by doing a semantic search of relevant text with
        embeddings and then sending engineered query to the LLM.

        Parameters
        ----------
        query : str
            Question being asked of EnergyWizard
        stream : bool
            Flag to print subsequent chunks of the response in a streaming
            fashion
        temperature : float
            GPT model temperature, a measure of response entropy from 0 to 1. 0
            is more reliable and nearly deterministic; 1 will give the model
            more creative freedom and may not return as factual of results.

        Returns
        -------
        response : str
            GPT output / answer.
        """

        # Ask the router LLM whether the query is best answered by a
        # text report ("1") or a database figure ("2").
        messages = [{"role": "system", "content": self.MODEL_ROLE},
                    {"role": "user", "content": query}]
        response_message = ''
        kwargs = dict(model=self.model,
                      messages=messages,
                      temperature=temperature,
                      stream=stream)

        response = self._client.chat.completions.create(**kwargs)

        print(response)

        if stream:
            for chunk in response:
                chunk_msg = chunk.choices[0].delta.content or ""
                response_message += chunk_msg
                print(chunk_msg, end='')

        else:
            # Fix: the original used legacy dict-style indexing
            # (response["choices"][0]["message"]["content"]) which fails on
            # the v1 client object already used attribute-style in the
            # streaming branch above.
            response_message = response.choices[0].message.content

        message_placeholder = st.empty()
        full_response = ""

        if '1' in response_message:
            out = self.wizard_chat.chat(query,
                                        debug=True, stream=True,
                                        token_budget=6000, temperature=0.0,
                                        print_references=True, convo=False,
                                        return_chat_obj=True)

            for response in out[0]:
                full_response += response.choices[0].delta.content or ""
                message_placeholder.markdown(full_response + "▌")

        elif '2' in response_message:
            # Fix: DataBaseWizard.chat() only accepts the query argument;
            # the extra EnergyWizard-style kwargs previously raised
            # TypeError before any plotting happened.
            out = self.wizard_db.chat(query)

            st.pyplot(fig=out, clear_figure=False)

        else:
            # Fix: surface the routing failure to the user. The original
            # assigned the error to response_message but returned the empty
            # full_response, so callers never saw the error.
            full_response = 'Error cannot find data in report or database.'
            message_placeholder.markdown(full_response)

        return full_response
Loading
Loading