From bfa70b218a733e492151680cf7f7abe4e7e8ddc4 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Sat, 16 Nov 2024 12:39:46 +0100 Subject: [PATCH 1/2] fix[output_format]: accept dataframe dict as output and secure sql query execution --- pandasai/connectors/sql.py | 2 +- pandasai/helpers/output_validator.py | 4 ++-- pandasai/responses/response_parser.py | 12 ++++++++++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py index e1494ba59..68638e8a2 100644 --- a/pandasai/connectors/sql.py +++ b/pandasai/connectors/sql.py @@ -441,7 +441,7 @@ def execute_direct_sql_query(self, sql_query): if not self._is_sql_query_safe(sql_query): raise MaliciousQueryError("Malicious query is generated in code") - return pd.read_sql(sql_query, self._connection) + return pd.read_sql(text(sql_query), self._connection) @property def cs_table_name(self): diff --git a/pandasai/helpers/output_validator.py b/pandasai/helpers/output_validator.py index e26bcf2ff..56a3a495d 100644 --- a/pandasai/helpers/output_validator.py +++ b/pandasai/helpers/output_validator.py @@ -56,7 +56,7 @@ def validate_value(self, expected_type: str) -> bool: elif expected_type == "string": return isinstance(self, str) elif expected_type == "dataframe": - return isinstance(self, (pd.DataFrame, pd.Series)) + return isinstance(self, (pd.DataFrame, pd.Series, dict)) elif expected_type == "plot": if not isinstance(self, (str, dict)): return False @@ -82,7 +82,7 @@ def validate_result(result: dict) -> bool: elif result["type"] == "string": return isinstance(result["value"], str) elif result["type"] == "dataframe": - return isinstance(result["value"], (pd.DataFrame, pd.Series)) + return isinstance(result["value"], (pd.DataFrame, pd.Series, dict)) elif result["type"] == "plot": if "plotly" in repr(type(result["value"])): return True diff --git a/pandasai/responses/response_parser.py b/pandasai/responses/response_parser.py index fd202784d..6b1378c4f 100644 --- a/pandasai/responses/response_parser.py +++ b/pandasai/responses/response_parser.py @@ -2,6 +2,7 @@ from typing import Any from PIL import Image +import pandas as pd from pandasai.exceptions import MethodNotImplementedError @@ -51,9 +52,20 @@ def parse(self, result: dict) -> Any: if result["type"] == "plot": return self.format_plot(result) + elif result["type"] == "dataframe": + return self.format_dataframe(result) else: return result["value"] + def format_dataframe(self, result: dict) -> Any: + if isinstance(result["value"], dict): + print("Df conversiont") + df = pd.Dataframe(result["value"]) + print("Df conversiont Done") + result["value"] = df + + return result["value"] + def format_plot(self, result: dict) -> Any: """ Display matplotlib plot against a user query. From 331a5704137ebbb5c60b382164c27b29afa0b58d Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Sat, 16 Nov 2024 12:45:09 +0100 Subject: [PATCH 2/2] fix: ruff errors --- pandasai/responses/response_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandasai/responses/response_parser.py b/pandasai/responses/response_parser.py index 6b1378c4f..4254c77ec 100644 --- a/pandasai/responses/response_parser.py +++ b/pandasai/responses/response_parser.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from typing import Any -from PIL import Image import pandas as pd +from PIL import Image from pandasai.exceptions import MethodNotImplementedError