feat(directSqlPrompt): use connector directly if flag is set (#731)
* feat: config plot libraries (#705)

* In this commit, I introduced a new configuration parameter in our application settings that allows users to define their preferred data visualization library (matplotlib, seaborn, or plotly).
With this update, I've eliminated the need for the user to specify in every prompt which library to use, thereby simplifying their interaction with the application and increasing its versatility.

* This commit adds a configuration parameter for users to set their preferred data visualization library (matplotlib, seaborn, or plotly), simplifying interactions and enhancing the application's versatility.

* viz_library_type' in test_generate_python_code_prompt.py, resolved failing tests

---------

Co-authored-by: sabatino.severino <qrxqfspfibrth6nxywai2qifza6jmskt222howzew43risnx4kva>
Co-authored-by: Gabriele Venturi <[email protected]>

* build: use ruff for formatting

* feat: add add_message method to the agent

* Release v1.4.3

* feat: workspace env (#717)

* fix(chart): charts to save to save_chart_path

* refactor sourcery changes

* 'Refactored by Sourcery'

* refactor chart save code

* fix: minor leftovers

* feat(workspace_env): add workspace env to store cache, temp chart and config

* add error handling and comments

---------

Co-authored-by: Sourcery AI <>

* fix: hallucinations was plotting when not asked

* Release v1.4.4

* feat(sqlConnector): add direct config run sql at runtime

* feat(DirectSqlConnector): add sql test cases

* fix: minor leftovers

* fix(orders): check examples of different tables

* 'Refactored by Sourcery'

* chore(sqlprompt): add description only when we have it

---------

Co-authored-by: Sab Severino <[email protected]>
Co-authored-by: Gabriele Venturi <[email protected]>
Co-authored-by: Sourcery AI <>
3 people authored Nov 7, 2023
1 parent 3fa5625 commit e3c6b79
Showing 27 changed files with 670 additions and 87 deletions.
59 changes: 59 additions & 0 deletions examples/sql_direct_config.py
@@ -0,0 +1,59 @@
"""Example of using PandasAI with PostgreSQL connectors and direct SQL."""

from pandasai import SmartDatalake
from pandasai.llm import OpenAI
from pandasai.connectors import PostgreSQLConnector
from pandasai.smart_dataframe import SmartDataframe


# With a PostgreSQL database
order = PostgreSQLConnector(
    config={
        "host": "localhost",
        "port": 5432,
        "database": "testdb",
        "username": "postgres",
        "password": "123456",
        "table": "orders",
    }
)

order_details = PostgreSQLConnector(
    config={
        "host": "localhost",
        "port": 5432,
        "database": "testdb",
        "username": "postgres",
        "password": "123456",
        "table": "order_details",
    }
)

products = PostgreSQLConnector(
    config={
        "host": "localhost",
        "port": 5432,
        "database": "testdb",
        "username": "postgres",
        "password": "123456",
        "table": "products",
    }
)


llm = OpenAI("OPEN_API_KEY")


order_details_smart_df = SmartDataframe(
    order_details,
    config={"llm": llm, "direct_sql": True},
    description="Contains user order details",
)


df = SmartDatalake(
    [order_details_smart_df, order, products],
    config={"llm": llm, "direct_sql": True},
)
response = df.chat("return orders with count of distinct products")
print(response)
3 changes: 2 additions & 1 deletion mkdocs.yml
@@ -42,7 +42,8 @@ nav:
       - Documents Building: building_docs.md
       - License: license.md
 extra:
-  version: "1.4.3"
+  version: "1.4.4"
+
 plugins:
   - search
   - mkdocstrings:
5 changes: 5 additions & 0 deletions pandasai/assets/prompt_templates/default_instructions.tmpl
@@ -0,0 +1,5 @@
Analyze the data, using the provided dataframes (`dfs`).
1. Prepare: Preprocessing and cleaning data if necessary
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)
{viz_library_type}
39 changes: 39 additions & 0 deletions pandasai/assets/prompt_templates/direct_sql_connector.tmpl
@@ -0,0 +1,39 @@
You are provided with the following samples of sql tables data:

<Tables>
{tables}
</Tables>

<conversation>
{conversation}
</conversation>

You are provided with the following function that executes the sql query:
<Function>
def execute_sql_query(sql_query: str) -> pd.DataFrame
"""This method connects to the database, executes the sql query and returns the dataframe"""
</Function>

This is the initial python function. Do not change the params.

```python
# TODO import all the dependencies required
import pandas as pd

def analyze_data() -> dict:
    """
    Analyze the data, using the provided dataframes (`dfs`).
    1. Prepare: generate sql query to get data for analysis (grouping, filtering, aggregating, etc.)
    2. Process: execute the query using the execute_sql_query function available to you, which returns a dataframe
    3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)
    {viz_library_type}
    At the end, return a dictionary of:
    {output_type_hint}
    """
```

Take a deep breath and reason step-by-step. Act as a senior data analyst.
In the answer, you must never write the "technical" names of the tables.
Based on the last message in the conversation:

- return the updated analyze_data function wrapped within `python `
2 changes: 0 additions & 2 deletions pandasai/assets/prompt_templates/generate_python_code.tmpl
@@ -6,8 +6,6 @@ You are provided with the following pandas DataFrames:
 {conversation}
 </conversation>
 
-{viz_library_type}
-
 This is the initial python function. Do not change the params. Given the context, use the right dataframes.
 ```python
 {current_code}
1 change: 1 addition & 0 deletions pandasai/assets/prompt_templates/viz_library.tmpl
@@ -0,0 +1 @@
If the user requests to create a chart, utilize the Python {library} library to generate high-quality graphics that will be saved directly to a file.
17 changes: 17 additions & 0 deletions pandasai/connectors/databricks.py
@@ -63,3 +63,20 @@ def __repr__(self):
            f"host={self._config.host} port={self._config.port} "
            f"database={self._config.database} httpPath={str(self._config.httpPath)}"
        )

    def equals(self, other):
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.token,
                self._config.host,
                self._config.port,
                self._config.httpPath,
            ) == (
                other._config.dialect,
                other._config.token,
                other._config.host,
                other._config.port,
                other._config.httpPath,
            )
        return False
15 changes: 15 additions & 0 deletions pandasai/connectors/snowflake.py
@@ -90,3 +90,18 @@ def __repr__(self):
            f"database={self._config.database} schema={str(self._config.dbSchema)} "
            f"table={self._config.table}>"
        )

    def equals(self, other):
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.account,
                self._config.username,
                self._config.password,
            ) == (
                other._config.dialect,
                other._config.account,
                other._config.username,
                other._config.password,
            )
        return False
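
The `equals` methods added to the connectors compare only connection-level fields and leave the table name out, so two connectors pointing at different tables of the same database compare equal. The following is a minimal sketch of that idea, not the pandasai classes: `DemoConfig`, `DemoConnector` and `share_one_connection` are hypothetical names, and how the library actually calls `equals` is not shown in this diff.

```python
from dataclasses import dataclass


@dataclass
class DemoConfig:
    """Hypothetical stand-in for a connector config (not the pandasai class)."""

    dialect: str
    account: str
    username: str
    password: str
    table: str


class DemoConnector:
    """Hypothetical connector mirroring the tuple-based equals() in the diff."""

    def __init__(self, config: DemoConfig):
        self._config = config

    def equals(self, other) -> bool:
        # Compare connection-level fields only; the table name is ignored,
        # so connectors on different tables of one database still match.
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.account,
                self._config.username,
                self._config.password,
            ) == (
                other._config.dialect,
                other._config.account,
                other._config.username,
                other._config.password,
            )
        return False


def share_one_connection(connectors) -> bool:
    """Assumed helper: True when every connector equals the first one."""
    first = connectors[0]
    return all(first.equals(c) for c in connectors[1:])


orders = DemoConnector(DemoConfig("snowflake", "acct", "user", "pw", "orders"))
products = DemoConnector(DemoConfig("snowflake", "acct", "user", "pw", "products"))
other_account = DemoConnector(DemoConfig("snowflake", "acct2", "user", "pw", "orders"))

print(share_one_connection([orders, products]))       # True
print(share_one_connection([orders, other_account]))  # False
```

A check of this kind would be one plausible way to decide whether a single connection can serve every table in a datalake before running SQL directly.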
42 changes: 42 additions & 0 deletions pandasai/connectors/sql.py
@@ -5,6 +5,8 @@
import re
import os
import pandas as pd

from pandasai.exceptions import MaliciousQueryError
from .base import BaseConnector, SQLConnectorConfig, SqliteConnectorConfig
from .base import BaseConnectorConfig
from sqlalchemy import create_engine, text, select, asc
@@ -360,6 +362,46 @@ def column_hash(self):
    def fallback_name(self):
        return self._config.table

    def equals(self, other):
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.driver,
                self._config.host,
                self._config.port,
                self._config.username,
                self._config.password,
            ) == (
                other._config.dialect,
                other._config.driver,
                other._config.host,
                other._config.port,
                other._config.username,
                other._config.password,
            )
        return False

    def _is_sql_query_safe(self, query: str):
        infected_keywords = [
            r"\bINSERT\b",
            r"\bUPDATE\b",
            r"\bDELETE\b",
            r"\bDROP\b",
            r"\bEXEC\b",
            r"\bALTER\b",
            r"\bCREATE\b",
        ]

        return not any(
            re.search(keyword, query, re.IGNORECASE) for keyword in infected_keywords
        )

    def execute_direct_sql_query(self, sql_query):
        if not self._is_sql_query_safe(sql_query):
            raise MaliciousQueryError("Malicious query is generated in code")

        return pd.read_sql(sql_query, self._connection)


class SqliteConnector(SQLConnector):
    """
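
The keyword check in `_is_sql_query_safe` can be tried in isolation. The sketch below is a standalone rewrite for experimentation, not the connector code itself: `BLOCKED_KEYWORDS` and `is_sql_query_safe` are illustrative module-level names, while the patterns are the ones from the diff.

```python
import re

# Same patterns as in the diff; \b anchors each keyword at word boundaries.
BLOCKED_KEYWORDS = [
    r"\bINSERT\b",
    r"\bUPDATE\b",
    r"\bDELETE\b",
    r"\bDROP\b",
    r"\bEXEC\b",
    r"\bALTER\b",
    r"\bCREATE\b",
]


def is_sql_query_safe(query: str) -> bool:
    """Return False if any blocked keyword appears as a whole word."""
    return not any(
        re.search(keyword, query, re.IGNORECASE) for keyword in BLOCKED_KEYWORDS
    )


# A plain read passes.
print(is_sql_query_safe("SELECT * FROM orders"))               # True
# Matching is case-insensitive.
print(is_sql_query_safe("drop table orders"))                  # False
# Word boundaries keep column names such as created_at from matching CREATE.
print(is_sql_query_safe("SELECT created_at FROM orders"))      # True
# A keyword inside a string literal is still rejected (false positive).
print(is_sql_query_safe("SELECT 'please update me' AS note"))  # False
# Statements outside the list are not caught.
print(is_sql_query_safe("TRUNCATE TABLE orders"))              # True
```

The last two calls show the limits of a deny-list: it can reject harmless queries and miss statements that are not listed, so it probably works best alongside database-level restrictions such as a read-only role.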
17 changes: 17 additions & 0 deletions pandasai/exceptions.py
@@ -155,10 +155,27 @@ class UnSupportedLogicUnit(Exception):
        Exception (Exception): UnSupportedLogicUnit
    """


class InvalidWorkspacePathError(Exception):
    """
    Raised when the workspace environment variable exists but the path is invalid
    Args:
        Exception (Exception): InvalidWorkspacePathError
    """


class InvalidConfigError(Exception):
    """
    Raised when a config value is not applicable
    Args:
        Exception (Exception): InvalidConfigError
    """


class MaliciousQueryError(Exception):
    """
    Raised when a malicious query is generated
    Args:
        Exception (Exception): MaliciousQueryError
    """
17 changes: 16 additions & 1 deletion pandasai/helpers/code_manager.py
@@ -28,9 +28,15 @@
 
 class CodeExecutionContext:
     _prompt_id: uuid.UUID = None
+    _can_direct_sql: bool = False
     _skills_manager: SkillsManager = None
 
-    def __init__(self, prompt_id: uuid.UUID, skills_manager: SkillsManager):
+    def __init__(
+        self,
+        prompt_id: uuid.UUID,
+        skills_manager: SkillsManager,
+        _can_direct_sql: bool = False,
+    ):
         """
         Additional Context for code execution
         Args:
@@ -39,6 +45,7 @@ def __init__(self, prompt_id: uuid.UUID, skills_manager: SkillsManager):
         """
         self._skills_manager = skills_manager
         self._prompt_id = prompt_id
+        self._can_direct_sql = _can_direct_sql
 
     @property
     def prompt_id(self):
@@ -48,6 +55,10 @@ def prompt_id(self):
     def skills_manager(self):
         return self._skills_manager
 
+    @property
+    def can_direct_sql(self):
+        return self._can_direct_sql
+
 
 class CodeManager:
     _dfs: List
@@ -283,6 +294,10 @@ def execute_code(self, code: str, context: CodeExecutionContext) -> Any:
 
         analyze_data = environment.get("analyze_data")
 
+        if context.can_direct_sql:
+            environment["execute_sql_query"] = self._dfs[0].get_query_exec_func()
+            return analyze_data()
+
         return analyze_data(self._get_originals(dfs))
 
     def _get_samples(self, dfs):
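
In the `execute_code` hunk, `execute_sql_query` is placed into the environment after `analyze_data` has already been defined. That works because a Python function resolves global names when it is called, not when it is defined. A minimal sketch of the mechanism, assuming a hard-coded `generated_code` string and a hypothetical `fake_execute_sql_query` in place of the connector's real query function:

```python
# Stand-in for code returned by the LLM (an assumption for this sketch).
generated_code = '''
def analyze_data():
    rows = execute_sql_query("SELECT 1")
    return {"type": "number", "value": len(rows)}
'''


def fake_execute_sql_query(sql_query):
    """Hypothetical replacement for the connector's query function."""
    return [(1,)]


# Define analyze_data inside an isolated namespace.
environment = {}
exec(generated_code, environment)
analyze_data = environment.get("analyze_data")

# Inject the helper afterwards: analyze_data looks the name up in
# `environment` (its globals) only when it runs.
environment["execute_sql_query"] = fake_execute_sql_query

result = analyze_data()
print(result)  # {'type': 'number', 'value': 1}
```

Calling `analyze_data()` before the injection would raise a `NameError`, which is why the assignment has to come before the call.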
5 changes: 2 additions & 3 deletions pandasai/helpers/viz_library_types/_viz_library_types.py
@@ -1,13 +1,12 @@
 from abc import abstractmethod, ABC
 from typing import Any, Iterable
+from pandasai.prompts.generate_python_code import VizLibraryPrompt
 
 
 class BaseVizLibraryType(ABC):
     @property
     def template_hint(self) -> str:
-        return f"""When a user requests to create a chart, utilize the Python
-        {self.name} library to generate high-quality graphics that will be saved
-        directly to a file."""
+        return VizLibraryPrompt(library=self.name)
 
     @property
     @abstractmethod
10 changes: 5 additions & 5 deletions pandasai/prompts/base.py
@@ -2,6 +2,7 @@
 In order to better handle the instructions, this prompt module is written.
 """
 from abc import ABC, abstractmethod
+import string
 
 
 class AbstractPrompt(ABC):
@@ -92,12 +93,11 @@ def to_string(self):
         prompt_args = {}
         for key, value in self._args.items():
             if isinstance(value, AbstractPrompt):
+                args = [
+                    arg[1] for arg in string.Formatter().parse(value.template) if arg[1]
+                ]
                 value.set_vars(
-                    {
-                        k: v
-                        for k, v in self._args.items()
-                        if k != key and not isinstance(v, AbstractPrompt)
-                    }
+                    {k: v for k, v in self._args.items() if k != key and k in args}
                 )
                 prompt_args[key] = value.to_string()
             else:
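
The `to_string` change uses `string.Formatter().parse` to find which placeholders a nested prompt's template contains, and forwards only those variables to it. A short sketch of that filtering step on plain strings; `template_field_names`, `child_template` and `parent_args` are illustrative names, not part of pandasai:

```python
import string


def template_field_names(template: str) -> list:
    """Collect the placeholder names that appear in a format template."""
    # parse() yields (literal_text, field_name, format_spec, conversion);
    # field_name is None for trailing literal text, so falsy values are skipped.
    return [field for _, field, _, _ in string.Formatter().parse(template) if field]


child_template = "Use the Python {library} library to draw charts."
parent_args = {"library": "seaborn", "conversation": "...", "tables": "..."}

wanted = template_field_names(child_template)
# Forward only the variables the child template actually references.
child_vars = {k: v for k, v in parent_args.items() if k in wanted}

print(wanted)      # ['library']
print(child_vars)  # {'library': 'seaborn'}
```

Filtering by placeholder name keeps unrelated parent variables out of the nested prompt.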
40 changes: 40 additions & 0 deletions pandasai/prompts/direct_sql_prompt.py
@@ -0,0 +1,40 @@
""" Prompt to explain code generation by the LLM
The previous conversation we had
<Conversation>
{conversation}
</Conversation>
Based on the last conversation you generated the following code:
<Code>
{code}
</Code>
Explain how you came up with code for non-technical people without
mentioning technical details or mentioning the libraries used?
"""
from .file_based_prompt import FileBasedPrompt


class DirectSQLPrompt(FileBasedPrompt):
    """Prompt to generate code that queries the database through direct SQL"""

    _path_to_template = "assets/prompt_templates/direct_sql_connector.tmpl"

    def _prepare_tables_data(self, tables):
        tables_join = []
        for table in tables:
            table_description_tag = (
                f' description="{table.table_description}"'
                if table.table_description is not None
                else ""
            )
            table_head_tag = f'<table name="{table.table_name}"{table_description_tag}>'
            table = f"{table_head_tag}\n{table.head_csv}\n</table>"
            tables_join.append(table)
        return "\n\n".join(tables_join)

    def setup(self, tables) -> None:
        self.set_var("tables", self._prepare_tables_data(tables))
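
To see what `_prepare_tables_data` feeds into the `{tables}` placeholder, the sketch below reproduces its logic as a free function. `DemoTable` is a hypothetical stand-in that only carries the three attributes the method reads (`table_name`, `table_description`, `head_csv`); the sample rows are made up for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DemoTable:
    """Hypothetical table object with the attributes the prompt reads."""

    table_name: str
    head_csv: str
    table_description: Optional[str] = None


def prepare_tables_data(tables) -> str:
    """Render each table as a <table> block; add description only if set."""
    blocks = []
    for table in tables:
        description_attr = (
            f' description="{table.table_description}"'
            if table.table_description is not None
            else ""
        )
        head_tag = f'<table name="{table.table_name}"{description_attr}>'
        blocks.append(f"{head_tag}\n{table.head_csv}\n</table>")
    return "\n\n".join(blocks)


rendered = prepare_tables_data(
    [
        DemoTable("orders", "id,total\n1,9.5", "Customer orders"),
        DemoTable("products", "id,name\n1,pen"),
    ]
)
print(rendered)
```

The second table has no description, so its opening tag carries only the `name` attribute, matching the "add description only when we have it" note in the commit message.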