Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(directSqlPrompt): use connector directly if flag is set #731

Merged
merged 16 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 23 additions & 27 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
repos:
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.220
hooks:
- id: ruff
name: ruff
# Respect `exclude` and `extend-exclude` settings.
args: [--force-exclude]
- repo: local
hooks:
- id: pytest-check
name: pytest-check
entry: poetry run pytest
language: system
pass_filenames: false
always_run: true
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.1.3
hooks:
- id: ruff
name: ruff
- id: ruff-format
name: ruff-format
- repo: local
hooks:
- id: pytest-check
name: pytest-check
entry: poetry run pytest
language: system
pass_filenames: false
always_run: true

- repo: https://github.com/sourcery-ai/sourcery
rev: v1.11.0
hooks:
- id: sourcery
# The best way to use Sourcery in a pre-commit hook:
# * review only changed lines:
# * omit the summary
args: [--diff=git diff HEAD, --no-summary]
- repo: https://github.com/sourcery-ai/sourcery
rev: v1.11.0
hooks:
- id: sourcery
# The best way to use Sourcery in a pre-commit hook:
# * review only changed lines:
# * omit the summary
args: [--diff=git diff HEAD, --no-summary]
9 changes: 3 additions & 6 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ To make a contribution, follow the following steps:

For more details about pull requests, please read [GitHub's guides](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request).


### 📦 Package manager

We use `poetry` as our package manager. You can install poetry by following the instructions [here](https://python-poetry.org/docs/#installation).
Expand Down Expand Up @@ -44,12 +43,12 @@ ruff pandasai examples

Make sure that the linter does not report any errors or warnings before submitting a pull request.

### Code Format with `black`
### Code Format with `ruff-format`

We use `black` to reformat the code by running the following command:
We use `ruff` to reformat the code by running the following command:

```bash
black pandasai
ruff format pandasai
```

### 🧪 Testing
Expand All @@ -62,8 +61,6 @@ poetry run pytest

Make sure that all tests pass before submitting a pull request.



## 🚀 Release Process

At the moment, the release process is manual. We try to make frequent releases. Usually, we release a new version when we have a new feature or bugfix. A developer with admin rights to the repository will create a new release on GitHub, and then publish the new version to PyPI.
49 changes: 49 additions & 0 deletions examples/sql_direct_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Example of using PandasAI with a CSV file."""

from pandasai import SmartDatalake
from pandasai.llm import OpenAI
from pandasai.connectors import PostgreSQLConnector


# With a PostgreSQL database
payment_connector = PostgreSQLConnector(
config={
"host": "localhost",
"port": 5432,
"database": "testdb",
"username": "postgres",
"password": "123456",
"table": "orders",
}
)

order_details = PostgreSQLConnector(
config={
"host": "localhost",
"port": 5432,
"database": "testdb",
"username": "postgres",
"password": "123456",
"table": "order_details",
}
)

products = PostgreSQLConnector(
config={
"host": "localhost",
"port": 5432,
"database": "testdb",
"username": "postgres",
"password": "123456",
"table": "products",
}
)


llm = OpenAI("YOUR_API_KEY")
df = SmartDatalake(
[order_details, payment_connector, products],
config={"llm": llm, "direct_sql": True},
)
response = df.chat("Return Orders with OrderDetails and counts of distinct Products")
print(response)
38 changes: 38 additions & 0 deletions examples/using_workspace_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
import pandas as pd
from pandasai import Agent

from pandasai.llm.openai import OpenAI
from pandasai.schemas.df_config import Config

employees_data = {
"EmployeeID": [1, 2, 3, 4, 5],
"Name": ["John", "Emma", "Liam", "Olivia", "William"],
"Department": ["HR", "Sales", "IT", "Marketing", "Finance"],
}

salaries_data = {
"EmployeeID": [1, 2, 3, 4, 5],
"Salary": [5000, 6000, 4500, 7000, 5500],
}

employees_df = pd.DataFrame(employees_data)
salaries_df = pd.DataFrame(salaries_data)


os.environ["PANDASAI_WORKSPACE"] = "workspace dir path"


llm = OpenAI("YOUR_API_KEY")
config__ = {"llm": llm, "save_charts": False}


agent = Agent(
[employees_df, salaries_df],
config=Config(**config__),
memory_size=10,
)

# Chat with the agent
response = agent.chat("plot salary against department?")
print(response)
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ nav:
- Documents Building: building_docs.md
- License: license.md
extra:
version: "1.4.2"
version: "1.4.4"
plugins:
- search
- mkdocstrings:
Expand Down
8 changes: 8 additions & 0 deletions pandasai/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ def chat(self, query: str, output_type: Optional[str] = None):
f"\n{exception}\n"
)

def add_message(self, message, is_user=False):
"""
Add message to the memory. This is useful when you want to add a message
to the memory without calling the chat function (for example, when you
need to add a message from the agent).
"""
self._lake._memory.add(message, is_user=is_user)

def check_if_related_to_conversation(self, query: str) -> bool:
"""
Check if the query is related to the previous conversation
Expand Down
5 changes: 5 additions & 0 deletions pandasai/assets/prompt_templates/default_instructions.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Analyze the data, using the provided dataframes (`dfs`).
1. Prepare: Preprocessing and cleaning data if necessary
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)
{viz_library_type}
39 changes: 39 additions & 0 deletions pandasai/assets/prompt_templates/direct_sql_connector.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
You are provided with the following samples of sql tables data:

<Tables>
{tables}
<Tables>

<conversation>
{conversation}
</conversation>

You are provided with following function that executes the sql query,
<Function>
def execute_sql_query(sql_query: str) -> pd.Dataframe
"""his method connect to the database executes the sql query and returns the dataframe"""
</Function>

This is the initial python function. Do not change the params.

```python
# TODO import all the dependencies required
import pandas as pd

def analyze_data() -> dict:
"""
Analyze the data, using the provided dataframes (`dfs`).
1. Prepare: generate sql query to get data for analysis (grouping, filtering, aggregating, etc.)
2. Process: execute the query using execute method available to you which returns dataframe
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)
{viz_library_type}
At the end, return a dictionary of:
{output_type_hint}
"""
```

Take a deep breath and reason step-by-step. Act as a senior data analyst.
In the answer, you must never write the "technical" names of the tables.
Based on the last message in the conversation:

- return the updated analyze_data function wrapped within `python `
2 changes: 0 additions & 2 deletions pandasai/assets/prompt_templates/generate_python_code.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ You are provided with the following pandas DataFrames:
{conversation}
</conversation>

{viz_library_type}

This is the initial python function. Do not change the params. Given the context, use the right dataframes.
```python
{current_code}
Expand Down
1 change: 1 addition & 0 deletions pandasai/assets/prompt_templates/viz_library.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
If the user requests to create a chart, utilize the Python {library} library to generate high-quality graphics that will be saved directly to a file.
4 changes: 1 addition & 3 deletions pandasai/connectors/airtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,7 @@ def execute(self):
Returns:
DataFrameType: The result of the connector.
"""
if cached := self._cached() or self._cached(
include_additional_filters=True
):
if cached := self._cached() or self._cached(include_additional_filters=True):
return pd.read_parquet(cached)

if isinstance(self._instance, pd.DataFrame):
Expand Down
17 changes: 17 additions & 0 deletions pandasai/connectors/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,20 @@ def __repr__(self):
f"host={self._config.host} port={self._config.port} "
f"database={self._config.database} httpPath={str(self._config.httpPath)}"
)

def equals(self, other):
if isinstance(other, self.__class__):
return (
self._config.dialect,
self._config.token,
self._config.host,
self._config.port,
self._config.httpPath,
) == (
other._config.dialect,
other._config.token,
other._config.host,
other._config.port,
other._config.httpPath,
)
return False
Comment on lines 63 to +82
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The equals method is a good addition for comparing two instances of the DatabricksConnector class. However, it's important to note that this method only checks for equality based on a subset of the instance's properties. If there are other properties that could affect the behavior of the instance, they should be included in this comparison. Also, consider renaming the method to __eq__ to follow Python's convention for equality comparison, which would allow you to use the == operator directly.

-    def equals(self, other):
+    def __eq__(self, other):

Commitable suggestion

[!IMPORTANT]
Carefully review the code before committing. Make sure it correctly replaces the highlighted code, has no missing lines and indentaion issues.

Suggested change
f"host={self._config.host} port={self._config.port} "
f"database={self._config.database} httpPath={str(self._config.httpPath)}"
)
def equals(self, other):
if isinstance(other, self.__class__):
return (
self._config.dialect,
self._config.token,
self._config.host,
self._config.port,
self._config.httpPath,
) == (
other._config.dialect,
other._config.token,
other._config.host,
other._config.port,
other._config.httpPath,
)
return False
def __eq__(self, other):
if isinstance(other, self.__class__):
return (
self._config.dialect,
self._config.token,
self._config.host,
self._config.port,
self._config.httpPath,
) == (
other._config.dialect,
other._config.token,
other._config.host,
other._config.port,
other._config.httpPath,
)
return False

15 changes: 15 additions & 0 deletions pandasai/connectors/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,18 @@ def __repr__(self):
f"database={self._config.database} schema={str(self._config.dbSchema)} "
f"table={self._config.table}>"
)

def equals(self, other):
if isinstance(other, self.__class__):
return (
self._config.dialect,
self._config.account,
self._config.username,
self._config.password,
) == (
other._config.dialect,
other._config.account,
other._config.username,
other._config.password,
)
return False
Comment on lines +94 to +107
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The equals method is comparing sensitive information like passwords. This could potentially lead to security issues if the method is used inappropriately. Consider comparing a hash of the password instead of the password itself. Also, consider adding a docstring to this method to explain its purpose and usage.

    def equals(self, other):
        if isinstance(other, self.__class__):
            return (
                self._config.dialect,
                self._config.account,
                self._config.username,
-               self._config.password,
+               hash(self._config.password),
            ) == (
                other._config.dialect,
                other._config.account,
                other._config.username,
-               other._config.password,
+               hash(other._config.password),
            )
        return False

Commitable suggestion

[!IMPORTANT]
Carefully review the code before committing. Make sure it correctly replaces the highlighted code, has no missing lines and indentaion issues.

Suggested change
def equals(self, other):
if isinstance(other, self.__class__):
return (
self._config.dialect,
self._config.account,
self._config.username,
self._config.password,
) == (
other._config.dialect,
other._config.account,
other._config.username,
other._config.password,
)
return False
def equals(self, other):
"""
Compare the current object with another object for equality.
Args:
other: The object to compare with.
Returns:
True if the objects are equal, False otherwise.
"""
if isinstance(other, self.__class__):
return (
self._config.dialect,
self._config.account,
self._config.username,
hash(self._config.password),
) == (
other._config.dialect,
other._config.account,
other._config.username,
hash(other._config.password),
)
return False

46 changes: 43 additions & 3 deletions pandasai/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import re
import os
import pandas as pd

from pandasai.exceptions import MaliciousQueryError
from .base import BaseConnector, SQLConnectorConfig, SqliteConnectorConfig
from .base import BaseConnectorConfig
from sqlalchemy import create_engine, text, select, asc
Expand Down Expand Up @@ -246,9 +248,7 @@ def execute(self):
DataFrame: The result of the SQL query.
"""

if cached := self._cached() or self._cached(
include_additional_filters=True
):
if cached := self._cached() or self._cached(include_additional_filters=True):
return pd.read_parquet(cached)

if self.logger:
Expand Down Expand Up @@ -362,6 +362,46 @@ def column_hash(self):
def fallback_name(self):
return self._config.table

def equals(self, other):
if isinstance(other, self.__class__):
return (
self._config.dialect,
self._config.driver,
self._config.host,
self._config.port,
self._config.username,
self._config.password,
) == (
other._config.dialect,
other._config.driver,
other._config.host,
other._config.port,
other._config.username,
other._config.password,
)
return False

def _is_sql_query_safe(self, query: str):
infected_keywords = [
r"\bINSERT\b",
r"\bUPDATE\b",
r"\bDELETE\b",
r"\bDROP\b",
r"\bEXEC\b",
r"\bALTER\b",
r"\bCREATE\b",
]

return not any(
re.search(keyword, query, re.IGNORECASE) for keyword in infected_keywords
)
Comment on lines +384 to +397
Copy link
Contributor

@coderabbitai coderabbitai bot Nov 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _is_sql_query_safe method checks for potentially harmful SQL keywords in a query. This is a good security practice to prevent SQL injection attacks. However, it's important to note that this method only checks for a limited set of keywords and may not catch all possible malicious queries. Consider using a more comprehensive SQL injection prevention technique or library.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@coderabbitai can you suggest more?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArslanSaleem Sure, here are a few more suggestions to improve SQL query safety:

  1. Parameterized Queries: This is one of the most effective ways to prevent SQL injection. Instead of constructing SQL queries with string concatenation, use parameterized queries. This ensures that user-supplied input is always treated as literal values and not part of the SQL command.

  2. ORM (Object-Relational Mapping) Libraries: Using an ORM library can abstract the SQL layer, reducing the risk of SQL injection attacks. These libraries generate SQL queries on your behalf and are generally designed to avoid SQL injection.

  3. Least Privilege Principle: Limit the permissions of the database account used by the application. Don't use a database root account whose compromise could lead to full database access.

  4. Regular Expressions: While the current method uses regular expressions to check for malicious keywords, it could be expanded to include more patterns. However, this method should not be relied upon solely as it can be bypassed with various techniques.

  5. Web Application Firewalls (WAFs): These can help to filter out malicious data and provide an additional layer of security.

  6. Input Validation: Validate input data thoroughly. Use both whitelist (only allow known good patterns) and blacklist (block known bad patterns) validation.

Remember, no single technique is foolproof. A combination of these methods will provide a more robust defense against SQL injection attacks.


def execute_direct_sql_query(self, sql_query):
if not self._is_sql_query_safe(sql_query):
raise MaliciousQueryError("Malicious query is generated in code")

return pd.read_sql(sql_query, self._connection)
Comment on lines +399 to +403
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The execute_direct_sql_query method executes a SQL query directly. It first checks if the query is safe using the _is_sql_query_safe method and raises a MaliciousQueryError if the query is determined to be malicious. This is a good practice for security. However, as mentioned above, the _is_sql_query_safe method may not catch all possible malicious queries.



class SqliteConnector(SQLConnector):
"""
Expand Down
Loading