diff --git a/.github/workflows/workflow_orders_on_pr.yml b/.github/workflows/workflow_orders_on_pr.yml
index 85752dbda..acec72a6d 100644
--- a/.github/workflows/workflow_orders_on_pr.yml
+++ b/.github/workflows/workflow_orders_on_pr.yml
@@ -7,7 +7,6 @@ on:
     paths:
       - '**'
       - '!*.md'
-
 jobs:
   Lint:
diff --git a/application/pages/1_LANGCHAIN_2_text_to_sql.py b/application/pages/1_LANGCHAIN_2_text_to_sql.py
index be1db8859..4ee8e87a7 100644
--- a/application/pages/1_LANGCHAIN_2_text_to_sql.py
+++ b/application/pages/1_LANGCHAIN_2_text_to_sql.py
@@ -14,89 +14,33 @@
 # limitations under the License.
 #
-import os
+import json
+import time
 import utils
+import requests
 import streamlit as st
 from streaming import StreamHandler
-from langchain.chat_models import ChatOpenAI
-from langchain.chains import ConversationChain
-from langchain_community.vectorstores import FAISS
-from langchain_community.embeddings import HuggingFaceEmbeddings
-
 st.set_page_config(page_title="SQL_Chatbot", page_icon="💬")
 st.header("SQL Chatbot")
 st.write("Allows users to interact with the LLM")
-def generate_prompt(question, schema):
-    prompt = """### Instructions:
-Your task is convert a question into a SQL query, given a MySQL database schema.
-Adhere to these rules:
-- **Deliberately go through the question and database schema word by word** to appropriately answer the question
-- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
-- When creating a ratio, always cast the numerator as float
-- Use LIKE instead of ilike
-- Only generate the SQL query, no additional text is required
-- Generate SQL queries for MySQL database
-
-### Input:
-Generate a SQL query that answers the question `{question}`.
-This query will run on a database whose schema is represented in this string:
-{schema}
-
-### Response:
-Based on your instructions, here is the SQL query I have generated to answer the question `{question}`:
-```sql
-""".format(
-        question=question, schema=schema
-    )
-
-    return prompt
-
-
-def rag_retrival(retriever, query):
-    matched_tables = []
-    matched_documents = retriever.get_relevant_documents(query=query)
-
-    for document in matched_documents:
-        page_content = document.page_content
-        matched_tables.append(page_content)
-    return matched_tables
-
-
 class Basic:
     def __init__(self):
-        self.openai_model = "sqlcoder-7b-2"
+        self.server_url = "http://127.0.0.1:8080"
         self.history_messages = utils.enable_chat_history("basic_chat")

-    def setup_chain(self):
-        llm = ChatOpenAI(
-            openai_api_base="http://localhost:8000/v1",
-            model_name=self.openai_model,
-            openai_api_key="not_needed",
-            streaming=True,
-            max_tokens=512,
-        )
-        chain = ConversationChain(llm=llm, verbose=True)
-        return chain
-
-    def setup_db_retriever(
-        self,
-        db=os.path.join(os.path.abspath(os.path.dirname(__file__)), "retriever.db"),
-        emb_model_name="defog/sqlcoder-7b-2",
-        top_k_table=1,
-    ):
-        embeddings = HuggingFaceEmbeddings(model_name=emb_model_name)
-        db = FAISS.load_local(db, embeddings, allow_dangerous_deserialization=True)
-        retriever = db.as_retriever(
-            search_type="mmr", search_kwargs={"k": top_k_table, "lambda_mult": 1}
-        )
-        return retriever
+    def _post_parse_response(self, response):
+        if response.status_code == 200:
+            text = response.text
+            json_data = json.loads(text)
+            return json_data
+        else:
+            print("Error Code: ", response.status_code)
+            return None

     def main(self):
-        chain = self.setup_chain()
-        db_retriever = self.setup_db_retriever()
         for message in self.history_messages:  # Display the prior chat messages
             with st.chat_message(message["role"]):
                 st.write(message["content"])
@@ -109,10 +53,24 @@ def main(self):
         with st.chat_message("assistant"):
             with st.spinner("Thinking..."):
                 st_cb = StreamHandler(st.empty())
-                schema = rag_retrival(db_retriever, user_query)
-                user_query = generate_prompt(user_query, schema)
-                response = chain.run(user_query, callbacks=[st_cb])
-                self.history_messages.append({"role": "assistant", "content": response})
+                response_rag = requests.post(
+                    f"{self.server_url}/v1/retrieve_tables", json={"query": user_query}
+                )
+                json_data_rag = self._post_parse_response(response_rag)
+                matched_tables = json_data_rag["matched_tables"]
+
+                response_sql = requests.post(
+                    f"{self.server_url}/v1/generate_sql_code",
+                    json={"query": user_query, "schema": matched_tables},
+                )
+                json_data_sql = self._post_parse_response(response_sql)
+                sql_answer = json_data_sql["sql_code"]["content"]
+                self.history_messages.append({"role": "assistant", "content": sql_answer})
+
+                print(sql_answer)
+                for token in sql_answer.split():
+                    time.sleep(1)
+                    st_cb.on_llm_new_token(token + " ")

 if __name__ == "__main__":
diff --git a/application/pages/1_LANGCHAIN_3_gluten_udf_codegen.py b/application/pages/1_LANGCHAIN_3_gluten_udf_codegen.py
new file mode 100644
index 000000000..6fead4ab6
--- /dev/null
+++ b/application/pages/1_LANGCHAIN_3_gluten_udf_codegen.py
@@ -0,0 +1,156 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import requests
+import streamlit as st
+from code_editor import code_editor
+import json
+
+st.set_page_config(page_title="Gluten_Coder_Chatbot_V2", page_icon="💬")
+st.header("Gluten Coder Chatbot")
+st.write("Convert code to a Gluten/Velox UDF with the LLM")
+
+code_editor_btns_config = [
+    {
+        "name": "Copy",
+        "feather": "Copy",
+        "hasText": True,
+        "alwaysOn": True,
+        "commands": [
+            "copyAll",
+            [
+                "infoMessage",
+                {"text": "Copied to clipboard!", "timeout": 2500, "classToggle": "show"},
+            ],
+        ],
+        "style": {"top": "0rem", "right": "0.4rem"},
+    },
+    {
+        "name": "Run",
+        "feather": "Play",
+        "primary": True,
+        "hasText": True,
+        "showWithIcon": True,
+        "commands": ["submit"],
+        "style": {"bottom": "0.44rem", "right": "0.4rem"},
+    },
+]
+
+info_bar = {
+    "name": "input code",
+    "css": "\nbackground-color: #bee1e5;\n\nbody > #root .ace-streamlit-dark~& {\n background-color: #444455;\n}\n\n.ace-streamlit-dark~& span {\n color: #fff;\n opacity: 0.6;\n}\n\nspan {\n color: #000;\n opacity: 0.5;\n}\n\n.code_editor-info.message {\n width: inherit;\n margin-right: 75px;\n order: 2;\n text-align: center;\n opacity: 0;\n transition: opacity 0.7s ease-out;\n}\n\n.code_editor-info.message.show {\n opacity: 0.6;\n}\n\n.ace-streamlit-dark~& .code_editor-info.message.show {\n opacity: 0.5;\n}\n",
+    "style": {
+        "order": "1",
+        "display": "flex",
+        "flexDirection": "row",
+        "alignItems": "center",
+        "width": "100%",
+        "height": "2.5rem",
+        "padding": "0rem 0.6rem",
+        "padding-bottom": "0.2rem",
+        "margin-bottom": "-1px",
+        "borderRadius": "8px 8px 0px 0px",
+        "zIndex": "9993",
+    },
+    "info": [{"name": "Your code", "style": {"width": "800px"}}],
+}
+
+
+class Basic:
+    def __init__(self):
+        self.server_url = "http://127.0.0.1:8000"
+
+    def _post_parse_response(self, response):
+        if response.status_code == 200:
+            text = response.text
+            json_data = json.loads(text)
+            return json_data
+        else:
+            print("Error Code: ", response.status_code)
+            return None
+
+    def main(self):
+        step = 1
+
+        response_dict = code_editor(
+            "",
+            height=(8, 20),
+            lang="scala",
+            theme="dark",
+            shortcuts="vscode",
+            focus=False,
+            buttons=code_editor_btns_config,
+            info=info_bar,
+            props={"style": {"borderRadius": "0px 0px 8px 8px"}},
+            options={"wrap": True},
+        )
+        code_to_convert = response_dict["text"]
+
+        if bool(code_to_convert):
+            print(code_to_convert)
+
+            with st.chat_message(name="assistant", avatar="🧑‍💻"):
+                st.write(f"Step {step}: Convert the code into C++")
+                step += 1
+            with st.spinner("Converting your code to C++..."):
+                data = {"code": code_to_convert}
+                response = requests.post(self.server_url + "/v1/convert_to_cpp", json=data)
+                json_data = self._post_parse_response(response)
+                cpp_code_res = json_data["answer"]
+                cpp_code = json_data["cpp_code"]
+            with st.chat_message("ai"):
+                st.markdown(cpp_code_res)
+
+            with st.chat_message(name="assistant", avatar="🧑‍💻"):
+                st.write(f"Step {step}: Analyze the keywords that may need to be queried")
+                step += 1
+            with st.spinner("Analyzing the code..."):
+                data = {"cpp_code": cpp_code}
+                response = requests.post(self.server_url + "/v1/generate_keywords", json=data)
"/v1/generate_keywords", json=data) + json_data = self._post_parse_response(response) + keywords = json_data["velox_keywords"] + with st.chat_message("ai"): + st.markdown("\n".join(keywords)) + + with st.chat_message(name="assistant", avatar="πŸ§‘β€πŸ’»"): + st.write(f"Step {step}: Retrieve related knowledge from velox documentations") + step += 1 + with st.spinner("Retrieve reference from velox document and code..."): + data = {"velox_keywords": keywords} + response = requests.post(self.server_url + "/v1/retrieve_doc", json=data) + json_data = self._post_parse_response(response) + related_docs = json_data["related_docs"] + with st.chat_message("ai"): + st.write(related_docs) + + with st.chat_message(name="assistant", avatar="πŸ§‘β€πŸ’»"): + st.write(f"Step {step}: Based on the previous analysis, rewrite velox based UDF") + step += 1 + with st.spinner("Converting the C++ code to velox based udf..."): + data = { + "velox_keywords": keywords, + "code": code_to_convert, + "related_docs": related_docs, + } + response = requests.post(self.server_url + "/v1/get_gluten_udf", json=data) + json_data = self._post_parse_response(response) + udf_answer = json_data["udf_answer"] + with st.chat_message("ai"): + st.markdown(udf_answer) + + +if __name__ == "__main__": + obj = Basic() + obj.main() diff --git a/application/pages/gluten_coder/README.md b/application/pages/gluten_coder/README.md new file mode 100644 index 000000000..d04efe796 --- /dev/null +++ b/application/pages/gluten_coder/README.md @@ -0,0 +1,49 @@ +# Gluten UDF converter AI chatbot + +[![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://langchain-chatbot.streamlit.app/) +## Introduction +### Gluten +[Gluten](https://github.com/apache/incubator-gluten) is a new middle layer to offload Spark SQL queries to native engines. Gluten can benefit from high scalability of Spark SQL framework and high performance of native libraries. + +The basic rule of Gluten's design is that we would reuse spark's whole control flow and as many JVM code as possible but offload the compute-intensive data processing part to native code. Here is what Gluten does: + +- Transform Spark's whole stage physical plan to Substrait plan and send to native +- Offload performance-critical data processing to native library +- Define clear JNI interfaces for native libraries +- Switch available native backends easily +- Reuse SparkοΏ½s distributed control flow +- Manage data sharing between JVM and native +- Extensible to support more native accelerators + + +### Ability of this chatbot +The objective of this chatbot application is to assist users by seamlessly transforming their user-defined functions (UDFs), originally designed for Vanilla Spark, into C++ code that adheres to the code standards of Gluten and Velox. This is achieved through the utilization of a Language Learning Model (LLM), which automates the conversion process, ensuring compatibility and enhancing performance within these advanced execution frameworks. + +The conversion process is streamlined into the following steps: + +1. The chatbot identifies and comprehends the logic of the original Spark UDF code, then translates it into an initial C++ code draft. +2. Utilizing the preliminary C++ code, the Language Learning Model (LLM) identifies key terms to construct queries. These queries are related to Velox's existing function implementations and data types. The LLM then outputs the query results in JSON format. +3. 
diff --git a/application/pages/gluten_coder/__init__.py b/application/pages/gluten_coder/__init__.py
new file mode 100644
index 000000000..854e39ad4
--- /dev/null
+++ b/application/pages/gluten_coder/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/application/pages/gluten_coder/coder.py b/application/pages/gluten_coder/coder.py
new file mode 100644
index 000000000..b118159c9
--- /dev/null
+++ b/application/pages/gluten_coder/coder.py
@@ -0,0 +1,136 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import json
+from functools import lru_cache
+
+from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
+from langchain_community.vectorstores.faiss import FAISS
+
+from config import emb_model_path, index_path
+from prompt_config import (
+    rag_suffix,
+    convert_to_cpp_temp,
+    example_temp,
+    example_related_queries,
+    generate_search_query_prompt,
+    example_scala_code,
+)
+
+import re
+
+
+@lru_cache()
+def get_embedding(model_path):
+    embedding = HuggingFaceEmbeddings(
+        model_name=model_path,
+    )
+    return embedding
+
+
+@lru_cache()
+def get_vector_store(emb_model_path, index_path):
+    embeddings = get_embedding(emb_model_path)
+    db = FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
+    return db
+
+
+@lru_cache(maxsize=10)
+def retrieve_reference(code_query):
+    # code_query is a single query string, or a tuple of keyword queries
+    # (a tuple rather than a list so that lru_cache can hash the argument).
+    results = ""
+    db = get_vector_store(emb_model_path, index_path)
+
+    if isinstance(code_query, str):
+        retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 3, "lambda_mult": 1})
+        matched_documents = retriever.get_relevant_documents(query=code_query)
+        for document in matched_documents:
+            results += document.page_content + "\n"
+    elif isinstance(code_query, tuple):
+        for query in code_query:
+            retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 1, "lambda_mult": 1})
+            query = query.lower()
+            query = query.replace("velox", " ")
+            matched_documents = retriever.get_relevant_documents(query=query)
+            for document in matched_documents:
+                results += document.page_content + "\n"
+
+    return results
+
+
+def extract_code(text):
+    # Pull the first fenced code block out of the model's reply.
+    print(text)
+    pattern = r"```c\+\+|```cpp|```C\+\+|```"
+    parts = re.split(pattern, text)
+    return parts[1]
+
+
+def generate_to_cpp_code(llm, code_text):
+    convert_to_cpp_prompt = convert_to_cpp_temp.format(code_text)
+    res = llm.invoke(convert_to_cpp_prompt)
+    cpp_code = extract_code(res.content)
+    with open("record.txt", "w") as f:
+        f.write(res.content)
+
+    return res.content, cpp_code
+
+
+def generate_keywords(llm, code_text):
+    keywords_prompt = generate_search_query_prompt.substitute(cpp_code=code_text)
+    print(keywords_prompt)
+    res = llm.invoke(keywords_prompt)
+    json_str = res.content
+    json_str = json_str.replace("```", "")
+    json_str = json_str.replace("json", "")
+    keywords = json.loads(json_str)["Queries"]
+    return keywords
+
+
+def generate_velox_udf(llm, cpp_code, rag_queries=example_related_queries, rag_text=""):
+    # Include the retrieved Velox reference exactly once, via the template's
+    # $reference slot (appending rag_suffix again here would duplicate it).
+    reference = ""
+    if bool(rag_text):
+        reference = rag_suffix.format(rag_text)
+
+    prompt = example_temp.substitute(reference=reference, queries=rag_queries, cpp_code=cpp_code)
+    print("------velox udf------")
+    print(prompt)
+    res = llm.invoke(prompt)
+
+    with open("record.txt", "a") as f:
+        print("-" * 70)
+        f.write(res.content)
+
+    return res.content
+
+
+if __name__ == "__main__":
+    openai_model = "deepseek-coder:33b-instruct"
+    from langchain_community.chat_models.openai import ChatOpenAI
+
+    llm = ChatOpenAI(
+        openai_api_base="http://localhost:11434/v1",
+        model_name=openai_model,
+        openai_api_key="not_needed",
+        # max_tokens=2048,
+        streaming=True,
+    )
+
+    keywords = (
+        "Velox string functions",
+        "Velox UDF argument handling",
+        "Velox default values for UDFs",
+    )
+    res = retrieve_reference(keywords)
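+    # NOTE: this ad-hoc check assumes an OpenAI-compatible endpoint (such as
+    # Ollama serving deepseek-coder:33b-instruct) on localhost:11434 and the
+    # FAISS index configured in config.py; it only exercises retrieval.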
diff --git a/application/pages/gluten_coder/config.py b/application/pages/gluten_coder/config.py
new file mode 100644
index 000000000..df7dd6287
--- /dev/null
+++ b/application/pages/gluten_coder/config.py
@@ -0,0 +1,27 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+
+# Please input your path
+model_base_path = ""
+vs_path = ""
+
+
+emb_model_name = "deepseek-coder-33b-instruct"
+emb_model_path = os.path.join(model_base_path, emb_model_name)
+
+index_name = "velox-doc_deepseek-coder-33b-instruct"
+index_path = os.path.join(vs_path, index_name)
diff --git a/application/pages/gluten_coder/prompt_config.py b/application/pages/gluten_coder/prompt_config.py
new file mode 100644
index 000000000..bfce1c156
--- /dev/null
+++ b/application/pages/gluten_coder/prompt_config.py
@@ -0,0 +1,342 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import string
+
+
+example_scala_code = """
+@Description(
+        name = "norm_str",
+        value = "_FUNC_(input, [defaultValue], [dirtyValues ...]) trims input and " +
+                "normalize null, empty or dirtyValues to defVal. \n",
+        extended = "preset defaultValue is 'N-A' and preset dirtyValues are {'null', 'unknown', 'unknow', 'N-A'},\n" +
+                "the third NULL argument will clear the preset dirtyValues list."
+)
+public class UDFNormalizeString extends GenericUDF {
+
+
+    public final static String DEFAULT_VALUE = "N-A";
+
+    @SuppressWarnings("SpellCheckingInspection")
+    public final static List<String> DEFAULT_NULL_VALUES = Arrays.asList("null", "unknown", "unknow", DEFAULT_VALUE);
+
+    private transient String defaultValue;
+    private transient Set<String> nullValues;
+
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
+
+        if (arguments.length == 0) {
+            throw new UDFArgumentLengthException("norm_str() expects at least one argument.");
+        }
+
+        defaultValue = DEFAULT_VALUE;
+        if (arguments.length >= 2) {
+
+            // the default value must be passed in as a constant
+            if (!ObjectInspectorUtils.isConstantObjectInspector(arguments[1])) {
+                throw new UDFArgumentTypeException(1, "norm_str() expects a constant value as default.");
+            }
+
+            // read the constant default value
+            Object writable = ObjectInspectorUtils.getWritableConstantValue(arguments[1]);
+            defaultValue = (writable == null ? null : writable.toString());
+        }
+
+        nullValues = new HashSet<>(DEFAULT_NULL_VALUES);
+        for (int i = 2; i < arguments.length; i++) {
+
+            if (!ObjectInspectorUtils.isConstantObjectInspector(arguments[i])) {
+                throw new UDFArgumentTypeException(i, "norm_str() expects constant values as dirty values");
+            }
+
+            Object writable = ObjectInspectorUtils.getWritableConstantValue(arguments[i]);
+
+            if (writable == null) {
+                // only a NULL passed as the third argument may clear the preset dirty values
+                if (i != 2) {
+                    throw new UDFArgumentException(
+                            "Only the third null argument will clear the default null values of norm_str().");
+                }
+                nullValues.clear();
+            } else {
+                nullValues.add(writable.toString().trim().toLowerCase());
+            }
+        }
+
+        return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+    }
+
+    @Override
+    public Object evaluate(DeferredObject[] arguments) throws HiveException {
+        assert arguments.length > 0;
+
+        Object inputObject = arguments[0].get();
+
+        if (inputObject == null) {
+            return defaultValue;
+        }
+
+        String input = inputObject.toString().trim();
+
+        if (input.length() == 0 || nullValues.contains(input.toLowerCase())) {
+            return defaultValue;
+        }
+
+        return input;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return getStandardDisplayString("norm_str", children);
+    }
+}
+
+"""
+
+demo_sample_code = """
+@Description(
+        name = "norm_str",
+        value = "_FUNC_(input, [defaultValue], [dirtyValues ...]) trims input and " +
+                "normalize null, empty or dirtyValues to defVal. \n",
+        extended = "preset defaultValue is 'N-A' and preset dirtyValues are {'null', 'unknown', 'unknow', 'N-A'},\n" +
+                "the third NULL argument will clear the preset dirtyValues list."
+)
+public class UDFNormalizeString extends GenericUDF {
+
+
+    public final static String DEFAULT_VALUE = "N-A";
+
+    @SuppressWarnings("SpellCheckingInspection")
+    public final static List<String> DEFAULT_NULL_VALUES = Arrays.asList("null", "unknown", "unknow", DEFAULT_VALUE);
+
+    private transient String defaultValue;
+    private transient Set<String> nullValues;
+
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
+
+        if (arguments.length == 0) {
+            throw new UDFArgumentLengthException("norm_str() expects at least one argument.");
+        }
+
+        defaultValue = DEFAULT_VALUE;
+        if (arguments.length >= 2) {
+
+            // the default value must be passed in as a constant
+            if (!ObjectInspectorUtils.isConstantObjectInspector(arguments[1])) {
+                throw new UDFArgumentTypeException(1, "norm_str() expects a constant value as default.");
+            }
+
+            // read the constant default value
+            Object writable = ObjectInspectorUtils.getWritableConstantValue(arguments[1]);
+            defaultValue = (writable == null ? null : writable.toString());
+        }
+
+        nullValues = new HashSet<>(DEFAULT_NULL_VALUES);
+        for (int i = 2; i < arguments.length; i++) {
+
+            if (!ObjectInspectorUtils.isConstantObjectInspector(arguments[i])) {
+                throw new UDFArgumentTypeException(i, "norm_str() expects constant values as dirty values");
+            }
+
+            Object writable = ObjectInspectorUtils.getWritableConstantValue(arguments[i]);
+
+            if (writable == null) {
+                // only a NULL passed as the third argument may clear the preset dirty values
+                if (i != 2) {
+                    throw new UDFArgumentException(
+                            "Only the third null argument will clear the default null values of norm_str().");
+                }
+                nullValues.clear();
+            } else {
+                nullValues.add(writable.toString().trim().toLowerCase());
+            }
+        }
+
+        return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
+    }
+
+    @Override
+    public Object evaluate(DeferredObject[] arguments) throws HiveException {
+        assert arguments.length > 0;
+
+        Object inputObject = arguments[0].get();
+
+        if (inputObject == null) {
+            return defaultValue;
+        }
+
+        String input = inputObject.toString().trim();
+
+        if (input.length() == 0 || nullValues.contains(input.toLowerCase())) {
+            return defaultValue;
+        }
+
+        return input;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return getStandardDisplayString("norm_str", children);
+    }
+}
+"""
+
+
+convert_to_cpp_temp = """
+
+Convert the following code into a C++ function or class:
+```
+{}
+```
+"""
+
+example_temp = string.Template(
+    """
+Your task is to refer to the example Velox UDF and rewrite the code I provide into a Velox UDF.
+The following code is an example Velox UDF:
+```
+#include <velox/expression/VectorFunction.h>
+#include <velox/type/Type.h>
+#include "udf/Udf.h"
+
+namespace {
+using namespace facebook::velox;
+
+template <TypeKind kind>
+class PlusConstantFunction : public exec::VectorFunction {
+ public:
+  explicit PlusConstantFunction(int32_t addition) : addition_(addition) {}
+
+  void apply(
+      const SelectivityVector& rows,
+      std::vector<VectorPtr>& args,
+      const TypePtr& /* outputType */,
+      exec::EvalCtx& context,
+      VectorPtr& result) const override {
+    using nativeType = typename TypeTraits<kind>::NativeType;
+    VELOX_CHECK_EQ(args.size(), 1);
+
+    auto& arg = args[0];
+
+    // The argument may be flat or constant.
+    VELOX_CHECK(arg->isFlatEncoding() || arg->isConstantEncoding());
+
+    BaseVector::ensureWritable(rows, createScalarType<kind>(), context.pool(), result);
+
+    auto* flatResult = result->asFlatVector<nativeType>();
+    auto* rawResult = flatResult->mutableRawValues();
+
+    flatResult->clearNulls(rows);
+
+    if (arg->isConstantEncoding()) {
+      auto value = arg->as<ConstantVector<nativeType>>()->valueAt(0);
+      rows.applyToSelected([&](auto row) { rawResult[row] = value + addition_; });
+    } else {
+      auto* rawInput = arg->as<FlatVector<nativeType>>()->rawValues();
+
+      rows.applyToSelected([&](auto row) { rawResult[row] = rawInput[row] + addition_; });
+    }
+  }
+
+ private:
+  const int32_t addition_;
+};
+
+static std::vector<std::shared_ptr<exec::FunctionSignature>> integerSignatures() {
+  // integer -> integer
+  return {exec::FunctionSignatureBuilder().returnType("integer").argumentType("integer").build()};
+}
+
+static std::vector<std::shared_ptr<exec::FunctionSignature>> bigintSignatures() {
+  // bigint -> bigint
+  return {exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()};
+}
+
+} // namespace
+
+const int kNumMyUdf = 2;
+gluten::UdfEntry myUdf[kNumMyUdf] = {{"myudf1", "integer"}, {"myudf2", "bigint"}};
+
+DEFINE_GET_NUM_UDF {
+  return kNumMyUdf;
+}
+
+DEFINE_GET_UDF_ENTRIES {
+  for (auto i = 0; i < kNumMyUdf; ++i) {
+    udfEntries[i] = myUdf[i];
+  }
+}
+
+DEFINE_REGISTER_UDF {
+  facebook::velox::exec::registerVectorFunction(
+      "myudf1", integerSignatures(), std::make_unique<PlusConstantFunction<TypeKind::INTEGER>>(5));
+  facebook::velox::exec::registerVectorFunction(
+      "myudf2", bigintSignatures(), std::make_unique<PlusConstantFunction<TypeKind::BIGINT>>(5));
+  LOG(INFO) << "registered myudf1, myudf2";
+}
+
+```
+
+$reference
+
+Think step by step:
+1. Understand the code I provided and the example Velox UDF.
+2. Consider whether the code to be rewritten involves classes or functions that are already implemented in Velox, such as ${queries}
+3. Rewrite the code as a Velox UDF
+
+Please convert the code below:
+```
+$cpp_code
+```
+"""
+)
+
+example_related_queries = (
+    "velox string functions, velox string normalization, and velox string case conversion"
+)
+
+generate_search_query_prompt = string.Template(
+    """
+Your task is to rewrite the code I provided as a new function in the Velox project. Based on the code I provided, write the 3 search keywords needed to gather information from the Velox code or documentation.
+For example, if you find that the code needs to calculate the distance between two strings, you should find out whether Velox already has functions that handle string types, such as hamming_distance, that can be called directly.
+
+## Rule
+- Don't ask about things you already know or about the Velox UDF specification
+- The main purpose is to find functions, type definitions, etc. that already exist in Velox and can be used directly
+
+Here is code:
+```
+$cpp_code
+```
+
+Only respond in the following JSON format:
+```
+{
+"Queries":[
+"<query1>",
+"<query2>",
+"<query3>"
+]
+}
+```
+"""
+)
+
+rag_suffix = """
+Some Velox code for reference:
+```
+{}
+```
+"""
diff --git a/application/pages/gluten_coder/server.py b/application/pages/gluten_coder/server.py
new file mode 100644
index 000000000..b71acf1c7
--- /dev/null
+++ b/application/pages/gluten_coder/server.py
@@ -0,0 +1,126 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+
+from fastapi import APIRouter, FastAPI, Request
+from langchain_community.chat_models import ChatOpenAI
+from starlette.middleware.cors import CORSMiddleware
+
+from coder import generate_to_cpp_code, generate_keywords, retrieve_reference, generate_velox_udf
+
+app = FastAPI()
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+class GlutenUdfGeneratorAPIRouter(APIRouter):
+    def __init__(self):
+        super().__init__()
+        self.openai_model = "deepseek-coder:33b-instruct"
+        self.coder_llm = ChatOpenAI(
+            openai_api_base="http://localhost:8000/v1",
+            model_name=self.openai_model,
+            openai_api_key="not_needed",
+            streaming=False,
+        )
+        self.general_llm = self.coder_llm
+        # self.general_llm = ChatOpenAI(
+        #     openai_api_base="http://localhost:8000/v1",
+        #     model_name="mistral-7b-instruct-v0.2",
+        #     openai_api_key="not_needed",
+        #     streaming=False,
+        # )
+
+    def get_cpp_code(self, code_to_convert):
+        answer, cpp_code = generate_to_cpp_code(self.coder_llm, code_to_convert)
+        return answer, cpp_code
+
+    def gen_keywords(self, cpp_code):
+        velox_keywords = generate_keywords(self.general_llm, cpp_code)
+        return velox_keywords
+
+    def retrieve(self, velox_keywords):
+        rag_source = retrieve_reference(tuple(velox_keywords))
+        return rag_source
+
+    def get_udf(self, code_to_convert, velox_keywords, rag_source):
+        udf_answer = generate_velox_udf(
+            self.coder_llm,
+            code_to_convert,
+            rag_queries=",".join(velox_keywords),
+            rag_text=rag_source,
+        )
+        return udf_answer
+
+
+router = GlutenUdfGeneratorAPIRouter()
+
+
+@router.post("/v1/convert_to_cpp")
+async def convert_to_cpp(request: Request):
+    params = await request.json()
+    print(f"[GlutenUDFConverter - chat] POST request: /v1/convert_to_cpp, params:{params}")
+    code = params["code"]
+    answer, cpp_code = router.get_cpp_code(code)
+    print(f"[GlutenUDFConverter - chat] answer: {answer}, cpp_code: {cpp_code}")
+    return {"answer": answer, "cpp_code": cpp_code}
+
+
+@router.post("/v1/generate_keywords")
+async def keywords(request: Request):
+    params = await request.json()
+    print(f"[GlutenUDFConverter - chat] POST request: /v1/generate_keywords, params:{params}")
+    code = params["cpp_code"]
+    velox_keywords = router.gen_keywords(code)
+    print(f"[GlutenUDFConverter - chat] velox_keywords: {velox_keywords}")
+    return {"velox_keywords": velox_keywords}
+
+
+@router.post("/v1/retrieve_doc")
+async def retrieve_doc(request: Request):
+    params = await request.json()
+    print(f"[GlutenUDFConverter - chat] POST request: /v1/retrieve_doc, params:{params}")
+    velox_keywords = params["velox_keywords"]
+    related_docs = router.retrieve(velox_keywords)
+    print(f"[GlutenUDFConverter - chat] related_docs: {related_docs}")
+    return {"related_docs": related_docs}
+
+
+@router.post("/v1/get_gluten_udf")
+async def get_gluten_udf(request: Request):
+    params = await request.json()
+    print(f"[GlutenUDFConverter - chat] POST request: /v1/get_gluten_udf, params:{params}")
+    velox_keywords = params["velox_keywords"]
+    code = params["code"]
+    related_docs = params["related_docs"]
+    udf_answer = router.get_udf(code, velox_keywords, related_docs)
+    print(f"[GlutenUDFConverter - chat] udf_answer: {udf_answer}")
+    return {"udf_answer": udf_answer}
+
+
+app.include_router(router)
+
+if __name__ == "__main__":
+    import uvicorn
+
+    fastapi_port = os.getenv("FASTAPI_PORT", "8000")
+    uvicorn.run(app, host="0.0.0.0", port=int(fastapi_port))
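+
+# NOTE: with these defaults the app listens on port 8000 while also calling an
+# LLM endpoint at http://localhost:8000/v1; if both run on one machine, set
+# FASTAPI_PORT (and self.server_url in 1_LANGCHAIN_3_gluten_udf_codegen.py)
+# so the two services do not collide on the same port.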
diff --git a/application/pages/text2sql/README.md b/application/pages/text2sql/README.md
new file mode 100644
index 000000000..373ed0d02
--- /dev/null
+++ b/application/pages/text2sql/README.md
@@ -0,0 +1,33 @@
+# Text2SQL AI chatbot
+
+[![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://langchain-chatbot.streamlit.app/)
+## Introduction
+### Text2SQL
+Text2SQL is a natural language processing (NLP) task that aims to convert a natural language question into a SQL query. The goal of Text2SQL is to enable users to interact with databases using natural language, without the need for specialized knowledge of SQL. Text2SQL is a challenging task for NLP systems, as it requires the system to understand the context of the question, the relationships between the words, and the structure of the SQL query.
+
+### The Text2SQL AI chatbot
+This chatbot is an implementation of the Text2SQL task using the LLM-on-Ray service for SQL generation and LangChain for RAG. The chatbot is built using the Streamlit library, which allows for easy deployment and sharing. The chatbot uses a pre-trained [defog/sqlcoder-7b-2](https://huggingface.co/defog/sqlcoder-7b-2) model to generate the SQL query, which is then executed on a remote SQL database. The chatbot is designed to be user-friendly and easy to use, with clear instructions and error messages.
+
+### Ability of this chatbot
+The objective of this chatbot application is to assist users by seamlessly transforming their natural-language requests, written for their own database, into SQL code that adheres to that database's standards.
+
+The conversion process is streamlined into the following steps:
+
+1. The chatbot identifies and comprehends the logic of the original natural-language question, then translates it into an initial SQL draft.
+2. Using the customer's database description, the Large Language Model (LLM) identifies key terms with which to construct queries. These queries are related to the database schema, tables, columns, and data types. The LLM outputs the query results in JSON format.
+3. With the key items identified from the LLM's output, the chatbot retrieves the database documentation stored in a vector database (FAISS) to find relevant information.
+4. Drawing from the information in the database documentation, the chatbot generates the final SQL code, tailored to meet the specifications of the user's database.
+
+### Configuration
+Currently, we use the model [defog/sqlcoder-7b-2](https://huggingface.co/defog/sqlcoder-7b-2).
+Deployment can be done using LLM-on-Ray with the following command:
+```
+llm_on_ray-serve --config_file llm_on_ray/inference/models/sqlcoder-7b-2.yaml
+```
+
+Before launching the Streamlit application, point application/pages/text2sql/server.py at the FAISS index for your database documentation:
+
+```
+# Provide the path to the FAISS index for database documentation.
+# (the `db` argument of setup_db_retriever; defaults to application/pages/text2sql/retriever.db)
+```
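+
+### Quick test
+
+A minimal smoke test, assuming server.py from this directory is already running on its default port 8080 (see FASTAPI_PORT) and the sqlcoder-7b-2 endpoint is up; the endpoint names and payload keys below mirror server.py:
+
+```
+import requests
+
+base = "http://127.0.0.1:8080"
+question = "How many orders were placed last month?"  # example question
+
+# retrieve the matching table schemas from the FAISS index
+tables = requests.post(f"{base}/v1/retrieve_tables", json={"query": question}).json()["matched_tables"]
+# generate the SQL for the question against those schemas
+sql = requests.post(
+    f"{base}/v1/generate_sql_code",
+    json={"query": question, "schema": tables},
+).json()["sql_code"]["content"]
+print(sql)
+```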
diff --git a/application/pages/text2sql/__init__.py b/application/pages/text2sql/__init__.py
new file mode 100644
index 000000000..854e39ad4
--- /dev/null
+++ b/application/pages/text2sql/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/application/pages/text2sql/prompt.py b/application/pages/text2sql/prompt.py
new file mode 100644
index 000000000..ea1af0f9f
--- /dev/null
+++ b/application/pages/text2sql/prompt.py
@@ -0,0 +1,41 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+def generate_prompt(question, schema):
+    prompt = """### Instructions:
+Your task is to convert a question into a SQL query, given a MySQL database schema.
+Adhere to these rules:
+- **Deliberately go through the question and database schema word by word** to appropriately answer the question
+- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
+- When creating a ratio, always cast the numerator as float
+- Use LIKE instead of ilike
+- Only generate the SQL query, no additional text is required
+- Generate SQL queries for MySQL database
+
+### Input:
+Generate a SQL query that answers the question `{question}`.
+This query will run on a database whose schema is represented in this string:
+{schema}
+
+### Response:
+Based on your instructions, here is the SQL query I have generated to answer the question `{question}`:
+```sql
+""".format(
+        question=question, schema=schema
+    )
+
+    return prompt
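+
+# Example (hypothetical schema):
+#   generate_prompt("How many users signed up in 2023?",
+#                   "CREATE TABLE users (id INT, created_at DATETIME)")
+# returns the full instruction prompt, ending with an open ```sql fence so
+# that the model's completion is the bare SQL query.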
diff --git a/application/pages/retriever.db/index.faiss b/application/pages/text2sql/retriever.db/index.faiss
similarity index 100%
rename from application/pages/retriever.db/index.faiss
rename to application/pages/text2sql/retriever.db/index.faiss
diff --git a/application/pages/retriever.db/index.pkl b/application/pages/text2sql/retriever.db/index.pkl
similarity index 100%
rename from application/pages/retriever.db/index.pkl
rename to application/pages/text2sql/retriever.db/index.pkl
diff --git a/application/pages/text2sql/server.py b/application/pages/text2sql/server.py
new file mode 100644
index 000000000..405449a40
--- /dev/null
+++ b/application/pages/text2sql/server.py
@@ -0,0 +1,116 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+
+from fastapi import APIRouter, FastAPI, Request
+from langchain_community.chat_models import ChatOpenAI
+from starlette.middleware.cors import CORSMiddleware
+
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
+
+from prompt import generate_prompt
+
+app = FastAPI()
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+class SQLGeneratorAPIRouter(APIRouter):
+    def __init__(self):
+        super().__init__()
+        self.openai_model = "sqlcoder-7b-2"
+
+        self.llm = ChatOpenAI(
+            openai_api_base="http://localhost:8000/v1",
+            model_name=self.openai_model,
+            openai_api_key="not_needed",
+            streaming=True,
+            max_tokens=512,
+        )
+        self.embeddings = self.setup_emb_model()
+        self.db_retriever = self.setup_db_retriever(self.embeddings)
+
+    def setup_emb_model(self, emb_model_name="defog/sqlcoder-7b-2"):
+        embeddings = HuggingFaceEmbeddings(model_name=emb_model_name)
+        tokenizer = embeddings.client.tokenizer
+        tokenizer.pad_token = tokenizer.eos_token
+        return embeddings
+
+    def setup_db_retriever(
+        self,
+        embeddings,
+        db=os.path.join(os.path.abspath(os.path.dirname(__file__)), "retriever.db"),
+        top_k_table=1,
+    ):
+        db = FAISS.load_local(db, embeddings, allow_dangerous_deserialization=True)
+        retriever = db.as_retriever(
+            search_type="mmr", search_kwargs={"k": top_k_table, "lambda_mult": 1}
+        )
+        return retriever
+
+    def retrieve(self, query):
+        matched_tables = []
+        matched_documents = self.db_retriever.get_relevant_documents(query=query)
+        for document in matched_documents:
+            page_content = document.page_content
+            matched_tables.append(page_content)
+        return matched_tables
+
+    def generate_sql_code(self, query, schema):
+        prompt = generate_prompt(query, schema)
+        res = self.llm.invoke(prompt)
+        return res
+
+
+router = SQLGeneratorAPIRouter()
+
+
+@router.post("/v1/retrieve_tables")
+async def retrieve_tables(request: Request):
+    params = await request.json()
+    print(f"[SQLGenerator - chat] POST request: /v1/retrieve_tables, params:{params}")
+    query = params["query"]
+    matched_tables = router.retrieve(query)
+    print(f"[SQLGenerator - chat] matched_tables: {matched_tables}")
+    return {"matched_tables": matched_tables}
+
+
+@router.post("/v1/generate_sql_code")
+async def generate_sql_code(request: Request):
+    params = await request.json()
+    print(f"[SQLGenerator - chat] POST request: /v1/generate_sql_code, params:{params}")
+    query = params["query"]
+    schema = params["schema"]
+    sql_code = router.generate_sql_code(query, schema)
+    print(f"[SQLGenerator - chat] sql_code: {sql_code}")
+    return {"sql_code": sql_code}
+
+
+app.include_router(router)
+
+if __name__ == "__main__":
+    import uvicorn
+
+    fastapi_port = os.getenv("FASTAPI_PORT", "8080")
+    uvicorn.run(app, host="0.0.0.0", port=int(fastapi_port))
diff --git a/application/requirements.txt b/application/requirements.txt
index 233dfbbd1..93f1432de 100644
--- a/application/requirements.txt
+++ b/application/requirements.txt
@@ -1,9 +1,11 @@
-langchain==0.1.0
+langchain==0.1.16
 langchain_community==0.0.32
 faiss-cpu==1.8.0
-openai==0.27.8
+openai==1.17.0
 streamlit==1.30.0
 duckduckgo-search==3.8.3
 pypdf==3.17.0
-sentence-transformers==2.2.2
+sentence-transformers==2.2.2  # install separately with `pip install --no-deps` (per-line --no-deps is not valid requirements.txt syntax)
 docarray==0.32.1
+streamlit-code-editor
+torchvision==0.16.2  # install separately with `pip install --no-deps`
\ No newline at end of file