diff --git a/pandasai/core/code_generation/code_cleaning.py b/pandasai/core/code_generation/code_cleaning.py index b2555b6cb..988c51810 100644 --- a/pandasai/core/code_generation/code_cleaning.py +++ b/pandasai/core/code_generation/code_cleaning.py @@ -310,7 +310,7 @@ def _handle_charts(self, code: str) -> str: return add_save_chart( code, logger=self.context.logger, - file_name=str(self.context.prompt_id), + file_name=str(self.context.last_prompt_id), save_charts_path_str=self.context.config.save_charts_path, ) return add_save_chart( diff --git a/pandasai/core/response/__init__.py b/pandasai/core/response/__init__.py index 00964e638..ac9296a47 100644 --- a/pandasai/core/response/__init__.py +++ b/pandasai/core/response/__init__.py @@ -1,9 +1,9 @@ -from .parser import ResponseParser from .base import BaseResponse -from .string import StringResponse -from .number import NumberResponse -from .dataframe import DataFrameResponse from .chart import ChartResponse +from .dataframe import DataFrameResponse +from .number import NumberResponse +from .parser import ResponseParser +from .string import StringResponse __all__ = [ "ResponseParser", diff --git a/pandasai/core/response/base.py b/pandasai/core/response/base.py index 380c970dc..6e5f6ab25 100644 --- a/pandasai/core/response/base.py +++ b/pandasai/core/response/base.py @@ -1,5 +1,5 @@ -from typing import Any import json +from typing import Any class BaseResponse: diff --git a/pandasai/core/response/chart.py b/pandasai/core/response/chart.py index d50e10261..4d23fb379 100644 --- a/pandasai/core/response/chart.py +++ b/pandasai/core/response/chart.py @@ -1,7 +1,8 @@ -from typing import Any -from PIL import Image import base64 import io +from typing import Any + +from PIL import Image from .base import BaseResponse diff --git a/pandasai/core/response/dataframe.py b/pandasai/core/response/dataframe.py index f0de2ebfa..b5e5f4f13 100644 --- a/pandasai/core/response/dataframe.py +++ b/pandasai/core/response/dataframe.py @@ -1,4 +1,5 @@ from typing import Any + import pandas as pd from .base import BaseResponse diff --git a/pandasai/core/response/parser.py b/pandasai/core/response/parser.py index 5f2bbfac1..f83fea313 100644 --- a/pandasai/core/response/parser.py +++ b/pandasai/core/response/parser.py @@ -6,10 +6,10 @@ from pandasai.exceptions import InvalidOutputValueMismatch from .base import BaseResponse +from .chart import ChartResponse +from .dataframe import DataFrameResponse from .number import NumberResponse from .string import StringResponse -from .dataframe import DataFrameResponse -from .chart import ChartResponse class ResponseParser: diff --git a/pandasai/dataframe/base.py b/pandasai/dataframe/base.py index e3d52aaa9..88ef31824 100644 --- a/pandasai/dataframe/base.py +++ b/pandasai/dataframe/base.py @@ -12,6 +12,7 @@ import pandasai as pai from pandasai.config import Config +from pandasai.core.response import BaseResponse from pandasai.exceptions import DatasetNotFound, PandasAIApiKeyError from pandasai.helpers.dataframe_serializer import ( DataframeSerializer, @@ -19,7 +20,6 @@ ) from pandasai.helpers.path import find_project_root from pandasai.helpers.request import get_pandaai_session -from pandasai.core.response import BaseResponse if TYPE_CHECKING: from pandasai.agent.base import Agent diff --git a/tests/unit_tests/core/code_generation/test_code_cleaning.py b/tests/unit_tests/core/code_generation/test_code_cleaning.py index d324ce777..835bca91c 100644 --- a/tests/unit_tests/core/code_generation/test_code_cleaning.py +++ b/tests/unit_tests/core/code_generation/test_code_cleaning.py @@ -1,6 +1,6 @@ import ast import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from pandasai.agent.state import AgentState from pandasai.core.code_generation.code_cleaning import CodeCleaner @@ -174,6 +174,77 @@ def test_extract_fix_dataframe_redeclarations(self): ) self.assertIsInstance(updated_node, ast.AST) + @patch( + "pandasai.core.code_generation.code_cleaning.add_save_chart" + ) # Replace with actual module name + def test_handle_charts_save_charts_true(self, mock_add_save_chart): + handler = self.cleaner + handler.context = MagicMock() + handler.context.config.save_charts = True + handler.context.logger = MagicMock() # Mock logger + handler.context.last_prompt_id = 123 + handler.context.config.save_charts_path = "/custom/path" + + code = 'some text "temp_chart.png" more text' + + handler._handle_charts(code) + + mock_add_save_chart.assert_called_once_with( + code, + logger=handler.context.logger, + file_name="123", + save_charts_path_str="/custom/path", + ) + + @patch("pandasai.core.code_generation.code_cleaning.add_save_chart") + @patch( + "pandasai.core.code_generation.code_cleaning.find_project_root", + return_value="/root/project", + ) # Mock project root + def test_handle_charts_save_charts_false( + self, mock_find_project_root, mock_add_save_chart + ): + handler = self.cleaner + handler.context = MagicMock() + handler.context.config.save_charts = False + handler.context.logger = MagicMock() + handler.context.last_prompt_id = 123 + + code = 'some text "temp_chart.png" more text' + + handler._handle_charts(code) + + mock_add_save_chart.assert_called_once_with( + code, + logger=handler.context.logger, + file_name="temp_chart", + save_charts_path_str="/root/project/exports/charts", + ) + + def test_handle_charts_empty_code(self): + handler = self.cleaner + + code = "" + expected_code = "" # It should remain empty, as no substitution is made + + result = handler._handle_charts(code) + + self.assertEqual( + result, expected_code, f"Expected '{expected_code}', but got '{result}'" + ) + + def test_handle_charts_no_png(self): + handler = self.cleaner + + code = "some text without png" + expected_code = "some text without png" # No change should occur + + result = handler._handle_charts(code) + + self.assertEqual( + result, expected_code, f"Expected '{expected_code}', but got '{result}'" + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/unit_tests/helpers/test_responses.py b/tests/unit_tests/helpers/test_responses.py index 7d0a71dc8..d811c3724 100644 --- a/tests/unit_tests/helpers/test_responses.py +++ b/tests/unit_tests/helpers/test_responses.py @@ -2,13 +2,13 @@ import pandas as pd -from pandasai.core.response.parser import ResponseParser from pandasai.core.response import ( ChartResponse, DataFrameResponse, NumberResponse, StringResponse, ) +from pandasai.core.response.parser import ResponseParser from pandasai.exceptions import InvalidOutputValueMismatch diff --git a/tests/unit_tests/response/test_chart_response.py b/tests/unit_tests/response/test_chart_response.py index 48b7974dc..465cd8538 100644 --- a/tests/unit_tests/response/test_chart_response.py +++ b/tests/unit_tests/response/test_chart_response.py @@ -1,7 +1,9 @@ -import pytest -from PIL import Image import base64 import io + +import pytest +from PIL import Image + from pandasai.core.response.chart import ChartResponse diff --git a/tests/unit_tests/response/test_dataframe_response.py b/tests/unit_tests/response/test_dataframe_response.py index 799e90cb3..373a0af06 100644 --- a/tests/unit_tests/response/test_dataframe_response.py +++ b/tests/unit_tests/response/test_dataframe_response.py @@ -1,5 +1,6 @@ -import pytest import pandas as pd +import pytest + from pandasai.core.response.dataframe import DataFrameResponse