diff --git a/elm/db_wiz.py b/elm/db_wiz.py
new file mode 100644
index 00000000..3b287726
--- /dev/null
+++ b/elm/db_wiz.py
@@ -0,0 +1,136 @@
+# -*- coding: utf-8 -*-
+"""
+ELM database wizard
+"""
+import os
+import psycopg2
+import pandas as pd
+
+from elm.base import ApiBase
+
+
+class DataBaseWizard(ApiBase):
+    """Interface to ask OpenAI LLMs about energy research."""
+
+    MODEL_ROLE = ("You are a data engineer pulling data from a relational "
+                  "database using SQL and writing python code to plot the "
+                  "output based on user queries.")
+    """High level model role, somewhat redundant to MODEL_INSTRUCTION"""
+
+    def __init__(self, connection_string, model=None, token_budget=3500):
+        """
+        Parameters
+        ----------
+        connection_string : str
+            Connection string for the PostgreSQL database to query, as
+            accepted by ``psycopg2.connect``.
+        model : str
+            GPT model name, default is the DEFAULT_MODEL global var
+        token_budget : int
+            Number of tokens that can be embedded in the prompt. Note that the
+            default budget for GPT-3.5-Turbo is 4096, but you want to subtract
+            some tokens to account for the response budget.
+        """
+
+        super().__init__(model)
+
+        self.query = None
+        self.sql = None
+        self.df = None
+        self.py = None
+
+        self.connection_string = connection_string
+        self.connection = psycopg2.connect(self.connection_string)
+        self.token_budget = token_budget
+
+        fpcache = './database_manual.txt'
+
+        if os.path.exists(fpcache):
+            with open(fpcache, 'r') as f:
+                self.database_describe = f.read()
+        else:
+            raise FileNotFoundError('Could not find the expert database '
+                                    f'description file: {fpcache}')
+
+    # Getting sql from a generic query
+    def get_sql_for(self, query):
+        """Take the raw user query and ask the LLM for a SQL query that will
+        get data to support a response
+        """
+        e_query = ('{}\n\nPlease create a SQL query that will answer this '
+                   'user question: "{}"\n\n'
+                   'Return all columns from the database. '
+                   'All the tables are in the schema "loads". '
+                   'Please only return the SQL query with '
+                   'no commentary or preface.'
+                   .format(self.database_describe, query))
+        out = super().chat(e_query, temperature=0)
+        print(out)
+        return out
+
+    def run_sql(self, sql):
+        """Takes a SQL query that can support a user prompt, runs SQL query
+        based on the db connection (self.connection), returns dataframe
+        response."""
+        query = sql
+        print(query)
+        # Move connection or cursor setup to __init__ and test so that you
+        # aren't re-initializing it with each call.
+        with self.connection.cursor() as cursor:
+            cursor.execute(query)
+            data = cursor.fetchall()
+            column_names = [desc[0] for desc in cursor.description]
+            df = pd.DataFrame(data, columns=column_names)
+        return df
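+
+    # Illustrative intermediate artifacts for the two methods above (the
+    # example user query is hypothetical):
+    #     sql = wiz.get_sql_for('Total residential kwh by year')  # SQL str
+    #     df = wiz.run_sql(sql)                                   # DataFrame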
+
+    def get_py_code(self, query, df):
+        """Get python code to respond to query"""
+        e_query = ('Great it worked! I have made a dataframe from the output '
+                   'of the SQL query you gave me. '
+                   'Here is the dataframe head: \n{}\n'
+                   'Here is the dataframe tail: \n{}\n'
+                   'Here is the dataframe description: \n{}\n'
+                   'Here is the dataframe datatypes: \n{}\n'
+                   'Now please write python code using matplotlib to plot '
+                   'the data in the dataframe based on '
+                   'the original user query: "{}"'
+                   .format(df.head(), df.tail(), df.describe(),
+                           df.dtypes, query))
+        out = super().chat(e_query, temperature=0)
+
+        # extract the python code from the markdown-fenced LLM response
+        full_response = out
+        full_response = full_response[full_response.find('```python') + 9:]
+        full_response = full_response[:full_response.find('```')]
+        py = full_response
+        return py
+
+    # pylint: disable=unused-argument
+    def run_py_code(self, py, df):
+        """Run the python code with ``exec`` to plot the queried data
+
+        Caution using this method! Exec can be dangerous. Need to do more work
+        to make this safe.
+        """
+        try:
+            # pylint: disable=exec-used
+            exec(py)
+        except Exception:
+            print(py)
+
+    def chat(self, query):
+        """Answers a query by asking the LLM for a SQL query, running that
+        query against the database, and then executing LLM-generated python
+        code to plot the resulting dataframe.
+
+        Parameters
+        ----------
+        query : str
+            Question being asked of DataBaseWizard
+        """
+
+        self.query = query
+        self.sql = self.get_sql_for(query)
+        self.df = self.run_sql(self.sql)
+        self.py = self.get_py_code(query=query, df=self.df)
+        self.run_py_code(self.py, self.df)
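+
+
+# Minimal usage sketch (the connection string below is hypothetical and must
+# point at a reachable postgres database with the "loads" schema):
+#     wiz = DataBaseWizard('postgresql://user:pass@host:5432/db',
+#                          model='gpt-4')
+#     wiz.chat('Plot annual residential load for the moderate scenario')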
diff --git a/elm/experts.py b/elm/experts.py
new file mode 100644
index 00000000..de39f46b
--- /dev/null
+++ b/elm/experts.py
@@ -0,0 +1,183 @@
+"""
+ELM mixture of experts
+"""
+import os
+import sys
+from glob import glob
+
+import openai
+import pandas as pd
+import streamlit as st
+
+from elm.base import ApiBase
+from elm.wizard import EnergyWizard
+from elm.db_wiz import DataBaseWizard
+
+model = 'gpt-4'
+
+# NREL-Azure endpoint. You can also use just the openai endpoint.
+# NOTE: embedding values are different between OpenAI and Azure models!
+openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
+openai.api_key = os.getenv("AZURE_OPENAI_KEY")
+openai.api_type = 'azure'
+openai.api_version = os.getenv('AZURE_OPENAI_VERSION')
+
+EnergyWizard.EMBEDDING_MODEL = 'text-embedding-ada-002-2'
+EnergyWizard.EMBEDDING_URL = ('https://stratus-embeddings-south-central.'
+                              'openai.azure.com/openai/deployments/'
+                              'text-embedding-ada-002-2/embeddings?'
+                              f'api-version={openai.api_version}')
+EnergyWizard.URL = ('https://stratus-embeddings-south-central.'
+                    'openai.azure.com/openai/deployments/'
+                    f'{model}/chat/completions?'
+                    f'api-version={openai.api_version}')
+EnergyWizard.HEADERS = {"Content-Type": "application/json",
+                        "Authorization": f"Bearer {openai.api_key}",
+                        "api-key": f"{openai.api_key}"}
+
+EnergyWizard.MODEL_ROLE = ('You are an energy research assistant. Use the '
+                           'articles below to answer the question. If the '
+                           'articles do not provide enough information to '
+                           'answer the question, say "I do not know."')
+EnergyWizard.MODEL_INSTRUCTION = EnergyWizard.MODEL_ROLE
+
+DataBaseWizard.URL = (
+    f'https://stratus-embeddings-south-central.openai.azure.com/'
+    f'openai/deployments/{model}/chat/'
+    f'completions?api-version={openai.api_version}')
+DataBaseWizard.HEADERS = {"Content-Type": "application/json",
+                          "Authorization": f"Bearer {openai.api_key}",
+                          "api-key": f"{openai.api_key}"}
+
+st.set_option('deprecation.showPyplotGlobalUse', False)
+
+
+@st.cache_data
+def get_corpus():
+    """Get the corpus of text data with embeddings."""
+    corpus = sorted(glob('./embed/*.json'))
+    corpus = [pd.read_json(fp) for fp in corpus]
+    corpus = pd.concat(corpus, ignore_index=True)
+
+    return corpus
+
+
+@st.cache_resource
+def get_wizard(model=model):
+    """Get the energy wizard object.
+
+    Parameters
+    ----------
+    model : str
+        State which model to use for the energy wizard.
+
+    Returns
+    -------
+    wizard : EnergyWizard
+        Returns the energy wizard object for use in chat responses.
+    """
+
+    # Getting corpus of data. If no corpus, throw an error for the user.
+    try:
+        corpus = get_corpus()
+    except Exception:
+        print("Error: Have you run 'retrieve_docs.py'?")
+        st.header("Error")
+        st.write("Error: Have you run 'retrieve_docs.py'?")
+        sys.exit(0)
+
+    wizard = EnergyWizard(corpus, ref_col='ref', model=model)
+    return wizard
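+
+
+# Wiring sketch (illustrative): both wizards are built up front and handed
+# to the mixture, e.g.
+#     moe = MixtureOfExperts(db_wiz=DataBaseWizard(conn_string, model=model),
+#                            txt_wiz=get_wizard(), model=model)
+# where conn_string is a user-supplied postgres connection string.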
+ """ + + messages = [{"role": "system", "content": self.MODEL_ROLE}, + {"role": "user", "content": query}] + response_message = '' + kwargs = dict(model=self.model, + messages=messages, + temperature=temperature, + stream=stream) + + response = self._client.chat.completions.create(**kwargs) + + print(response) + + if stream: + for chunk in response: + chunk_msg = chunk.choices[0].delta.content or "" + response_message += chunk_msg + print(chunk_msg, end='') + + else: + response_message = response["choices"][0]["message"]["content"] + + message_placeholder = st.empty() + full_response = "" + + if '1' in response_message: + out = self.wizard_chat.chat(query, + debug=True, stream=True, + token_budget=6000, temperature=0.0, + print_references=True, convo=False, + return_chat_obj=True) + + for response in out[0]: + full_response += response.choices[0].delta.content or "" + message_placeholder.markdown(full_response + "▌") + + elif '2' in response_message: + out = self.wizard_db.chat(query, + debug=True, stream=True, + token_budget=6000, temperature=0.0, + print_references=True, convo=False, + return_chat_obj=True) + + st.pyplot(fig=out, clear_figure=False) + + else: + response_message = 'Error cannot find data in report or database.' + + return full_response diff --git a/examples/db_wizard/retrieve_docs_general.py b/examples/db_wizard/retrieve_docs_general.py new file mode 100644 index 00000000..e59e76a4 --- /dev/null +++ b/examples/db_wizard/retrieve_docs_general.py @@ -0,0 +1,95 @@ +"""Retrieve and embed docs for a dbwizard example""" +import os +import asyncio +import pandas as pd +import logging +import openai +import time +from rex import init_logger + +from elm.pdf import PDFtoTXT +from elm.embed import ChunkAndEmbed + + +logger = logging.getLogger(__name__) +init_logger(__name__, log_level='DEBUG') +init_logger('elm', log_level='INFO') + + +# NREL-Azure endpoint. You can also use just the openai endpoint. +# NOTE: embedding values are different between OpenAI and Azure models! +openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT") +openai.api_key = os.getenv("AZURE_OPENAI_KEY") +openai.api_type = 'azure' +openai.api_version = '2023-03-15-preview' + +ChunkAndEmbed.EMBEDDING_MODEL = 'text-embedding-ada-002-2' +ChunkAndEmbed.EMBEDDING_URL = ('https://stratus-embeddings-south-central.' + 'openai.azure.com/openai/deployments/' + 'text-embedding-ada-002-2/embeddings?' + f'api-version={openai.api_version}') +ChunkAndEmbed.HEADERS = {"Content-Type": "application/json", + "Authorization": f"Bearer {openai.api_key}", + "api-key": f"{openai.api_key}"} + +PDF_DIR = './pdfs/' +TXT_DIR = './txt/' +EMBED_DIR = './embed/' + +URL = ('https://www.osti.gov/api/v1/records?' 
+
+
+if __name__ == '__main__':
+    os.makedirs(PDF_DIR, exist_ok=True)
+    os.makedirs(TXT_DIR, exist_ok=True)
+    os.makedirs(EMBED_DIR, exist_ok=True)
+
+    fns = os.listdir(PDF_DIR)
+
+    for fn in fns:
+        if fn.endswith('.pdf'):
+            print(fn)
+            fp = os.path.join(PDF_DIR, fn)
+            txt_fp = os.path.join(TXT_DIR, fn.replace('.pdf', '.txt'))
+            embed_fp = os.path.join(EMBED_DIR, fn.replace('.pdf', '.json'))
+
+            assert fp.endswith('.pdf')
+            assert os.path.exists(fp)
+
+            if os.path.exists(txt_fp):
+                with open(txt_fp, 'r') as f:
+                    text = f.read()
+            else:
+                pdf_obj = PDFtoTXT(fp)
+                text = pdf_obj.clean_poppler(layout=True)
+                if pdf_obj.is_double_col():
+                    text = pdf_obj.clean_poppler(layout=False)
+                text = pdf_obj.clean_headers(char_thresh=0.6, page_thresh=0.8,
+                                             split_on='\n',
+                                             iheaders=[0, 1, 3, -3, -2, -1])
+                with open(txt_fp, 'w') as f:
+                    f.write(text)
+                logger.info(f'Saved: {txt_fp}')
+
+            if not os.path.exists(embed_fp):
+                # Placeholder metadata tag; ideally pull the real title and
+                # authors from the OSTI record for this PDF
+                tag = f"Title: {fn.replace('.pdf', '')} \n Authors: NREL"
+                obj = ChunkAndEmbed(text, tag=tag, tokens_per_chunk=500,
+                                    overlap=1)
+                embeddings = asyncio.run(obj.run_async(rate_limit=3e4))
+                if any(e is None for e in embeddings):
+                    raise RuntimeError('Embeddings are None!')
+                else:
+                    df = pd.DataFrame({'text': obj.text_chunks.chunks,
+                                       'embedding': embeddings,
+                                       'osti_id': 1})  # placeholder id
+                    df.to_json(embed_fp, indent=2)
+                    logger.info('Saved: {}'.format(embed_fp))
+                time.sleep(5)
+
+    logger.info('Finished!')
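+
+# Each ./embed/*.json written above has "text", "embedding", and "osti_id"
+# columns; get_corpus() in elm/experts.py concatenates these files into the
+# corpus consumed by EnergyWizard.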
+ ''' + + st.title(opening_message) + + if "messages" not in st.session_state: + st.session_state.messages = [] + else: + if st.button('Clear Chat'): + # Clearing Messages + st.session_state.messages = [] + for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.markdown(message["content"]) + + # Clearing Wizard + wizard.clear() + wizard = DataBaseWizard(model=model, connection_string=conn_string) + + for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.markdown(message["content"]) + + msg = "Type your question here" + if prompt := st.chat_input(msg): + st.chat_message("user").markdown(prompt) + st.session_state.messages.append({"role": "user", "content": prompt}) + + with st.chat_message("assistant"): + + message_placeholder = st.empty() + full_response = "" + + out = wizard.chat(prompt, + debug=True, stream=True, token_budget=6000, + temperature=0.0, print_references=True, + convo=False, return_chat_obj=True) + + st.pyplot(fig=out, clear_figure=False) diff --git a/examples/db_wizard/run_experts_app.py b/examples/db_wizard/run_experts_app.py new file mode 100644 index 00000000..1a40fefa --- /dev/null +++ b/examples/db_wizard/run_experts_app.py @@ -0,0 +1,52 @@ +"""Run the mixture of experts streamlit app""" +import streamlit as st +from elm.experts import MixtureOfExperts + +model = 'gpt-4' +# User defined connection string +conn_string = '' + +if __name__ == '__main__': + wizard = MixtureOfExperts(model=model, connection_string=conn_string) + + msg = ("""Multi-Modal Wizard Demonstration!\nI am a multi-modal AI + demonstration. I have access to NREL technical reports regarding the + LA100 study and access to several LA100 databases. If you ask me a + question, I will attempt to answer it using the reports or the + database. Below are some examples of queries that have been shown to + work. + \n - Describe chapter 2 of the LA100 report. + \n - What are key findings of the LA100 report? + \n - What enduse consumes the most electricity? + \n - During the year 2020 which geographic regions consumed the + most electricity? + """) + + st.title(msg) + + if "messages" not in st.session_state: + st.session_state.messages = [] + + for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.markdown(message["content"]) + + msg = "Type your question here" + if prompt := st.chat_input(msg): + st.chat_message("user").markdown(prompt) + st.session_state.messages.append({"role": "user", "content": prompt}) + + with st.chat_message("assistant"): + + message_placeholder = st.empty() + full_response = "" + + out = wizard.chat(query=prompt, + debug=True, stream=True, token_budget=6000, + temperature=0.0, print_references=True, + convo=False, return_chat_obj=True) + + message_placeholder.markdown(full_response) + + st.session_state.messages.append({"role": "assistant", + "content": full_response}) diff --git a/examples/energy_wizard/database_manual.txt b/examples/energy_wizard/database_manual.txt new file mode 100644 index 00000000..31b8f36d --- /dev/null +++ b/examples/energy_wizard/database_manual.txt @@ -0,0 +1,3 @@ +The table "blk_annual_demand" has six columns: "load_scenario", "year", "block_fips", "tract_fips", "geography_id", and "kwh". The "load_scenario", "year", "block_fips", and "tract_fips" columns are of type "text", while the "geography_id" column is of type "character varying". The "kwh" column is of type "double precision". 
diff --git a/examples/energy_wizard/database_manual.txt b/examples/energy_wizard/database_manual.txt
new file mode 100644
index 00000000..31b8f36d
--- /dev/null
+++ b/examples/energy_wizard/database_manual.txt
@@ -0,0 +1,3 @@
+The table "blk_annual_demand" has six columns: "load_scenario", "year", "block_fips", "tract_fips", "geography_id", and "kwh". The "load_scenario", "year", "block_fips", and "tract_fips" columns are of type "text", while the "geography_id" column is of type "character varying". The "kwh" column is of type "double precision". The table contains data on annual electricity demand (in kilowatt-hours) for different geographic areas, identified by their block and tract FIPS codes, under different load scenarios. The load scenarios represent varying levels of grid loading, including moderate, high, and stress in ascending order. The years covered span from 2020 to 2045 in 5-year increments.
+The table "lc_annual_demand_enduse" has nine columns: "load_scenario", "year", "geography_id", "scenario_year", "load_center", "sector", "enduse", "kwh", and "kwh_w_dlosses". It appears to be a database of annual electricity demand broken down by load scenario, year, geography, scenario year, load center, sector, end use, and two different measures of electricity consumption (kwh and kwh_w_dlosses).
+The table "lc_annual_gas_demand" has seven columns: "load_scenario" (text), "year" (text), "geography_id" (bigint), "scenario_year" (text), "load_center" (bigint), "sector" (text), and "btu" (double precision). It appears to be a record of annual gas demand for different load scenarios, years, geographies, scenario years, load centers, and sectors. The "btu" column likely represents the amount of gas demanded in British Thermal Units.
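+An example query against this schema (illustrative only): SELECT year, SUM(kwh) AS total_kwh FROM loads.blk_annual_demand WHERE load_scenario = 'moderate' GROUP BY year ORDER BY year;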
diff --git a/examples/energy_wizard/db_description.txt b/examples/energy_wizard/db_description.txt
new file mode 100644
index 00000000..19acaceb
--- /dev/null
+++ b/examples/energy_wizard/db_description.txt
@@ -0,0 +1,4 @@
+You have been given access to the database schema {"lc_day_profile_demand_enduse": [{"column": "load_scenario", "type": "text"}, {"column": "year", "type": "text"}, {"column": "geography_id", "type": "bigint"}, {"column": "scenario_year", "type": "text"}, {"column": "load_center", "type": "bigint"}, {"column": "timestamp", "type": "timestamp without time zone"}, {"column": "timestamp_alias", "type": "timestamp without time zone"}, {"column": "week_type", "type": "text"}, {"column": "day_type", "type": "text"}, {"column": "hour_type", "type": "text"}, {"column": "sector", "type": "text"}, {"column": "enduse", "type": "text"}, {"column": "kwh", "type": "double precision"}, {"column": "kwh_w_dlosses", "type": "double precision"}]}.
+ The first ten lines of the database are [["moderate", "2025", 1, "moderate_2025", 1, "2012-04-18T05:00:00", "2025-04-16T13:00:00", "min", null, null, "com", "plug_and_process", 2916.49093456103, 3149.8102093259126], ["moderate", "2025", 1, "moderate_2025", 1, "2012-04-18T05:00:00", "2025-04-16T13:00:00", "min", null, null, "com", "refrigeration", 54.7219096497455, 59.09966242172514], ["moderate", "2025", 1, "moderate_2025", 1, "2012-04-18T05:00:00", "2025-04-16T13:00:00", "min", null, null, "ev", "bus", 1760.17283482198, 1900.9866616077384], ["moderate", "2025", 1, "moderate_2025", 1, "2012-04-18T05:00:00", "2025-04-16T13:00:00", "min", null, null, "ev", "dcfc", 0.0, 0.0], ["moderate", "2025", 1, "moderate_2025", 1, "2012-04-18T05:00:00", "2025-04-16T13:00:00", "min", null, null, "ev", "l1l2", 284.711061946903, 307.48794690265527], ["moderate", "2025", 1, "moderate_2025", 1, "2012-04-18T05:00:00", "2025-04-16T13:00:00", "min", null, null, "gap", "other", 201.712384639549, 217.84937541071292], ["moderate", "2025", 1, "moderate_2025", 1, "2012-04-18T05:00:00", "2025-04-16T13:00:00", "min", null, null, "ind", "other", 1196.36519278737, 1292.0744082103597], ["moderate", "2025", 1, "moderate_2025", 1, "2012-04-18T05:00:00", "2025-04-16T13:00:00", "min", null, null, "misc", "other", 0.0, 0.0], ["moderate", "2025", 1, "moderate_2025", 1, "2012-04-18T05:00:00", "2025-04-16T13:00:00", "min", null, null, "res", "appliances", 101.25093281548, 109.3510074407184], ["moderate", "2025", 1, "moderate_2025", 1, "2012-04-18T05:00:00", "2025-04-16T13:00:00", "min", null, null, "res", "cooling", 260.147432639897, 280.95922725108875]].
+ Each column of text contains the following unique values {"load_scenario": "[('stress',), ('high',), ('moderate',)]", "year": "[('2030',), ('2045',), ('2040',), ('2035',), ('2025',), ('2020',)]", "scenario_year": "[('high_2020',), ('high_2025',), ('moderate_2035',), ('moderate_2030',), ('stress_2020',), ('stress_2040',), ('moderate_2045',), ('moderate_2025',), ('moderate_2040',), ('high_2035',), ('stress_2035',), ('moderate_2020',), ('stress_2030',), ('stress_2045',), ('stress_2025',), ('high_2045',), ('high_2030',), ('high_2040',)]", "week_type": "[('min',), ('winter',), ('peak',), ('fall',)]", "day_type": "[(None,), ('min',), ('winter',), ('peak',), ('fall',)]", "hour_type": "[(None,), ('min',), ('winter',), ('peak',), ('fall',)]", "sector": "[('misc',), ('ev',), ('com',), ('res',), ('wtr',), ('gap',), ('ind',)]", "enduse": "[('fans_and_pumps',), ('hot_water',), ('heating',), ('refrigeration',), ('plug_and_process',), ('lighting',), ('cooling',), ('pool',), ('bus',), ('dcfc',), ('appliances',), ('municipal_water',), ('l1l2',), ('other',)]"}.
+ The table name is loads.lc_day_profile_demand_enduse.
\ No newline at end of file