From d20204dda9c077fa50e66c2f86466a8477e86165 Mon Sep 17 00:00:00 2001 From: Sab Severino Date: Tue, 31 Oct 2023 15:26:56 +0100 Subject: [PATCH 01/13] feat: config plot libraries (#705) * In this commit, I introduced a new configuration parameter in our application settings that allows users to define their preferred data visualization library (matplotlib, seaborn, or plotly). With this update, I've eliminated the need for the user to specify in every prompt which library to use, thereby simplifying their interaction with the application and increasing its versatility. * This commit adds a configuration parameter for users to set their preferred data visualization library (matplotlib, seaborn, or plotly), simplifying interactions and enhancing the application's versatility. * Updated 'viz_library_type' in test_generate_python_code_prompt.py and resolved the failing tests --------- Co-authored-by: sabatino.severino Co-authored-by: Gabriele Venturi From 742b1b676fc60d61b7542d7681fe56982a189426 Mon Sep 17 00:00:00 2001 From: Gabriele Venturi Date: Wed, 1 Nov 2023 00:54:29 +0100 Subject: [PATCH 02/13] build: use ruff for formatting --- .pre-commit-config.yaml | 50 +++--- CONTRIBUTING.md | 9 +- pandasai/connectors/airtable.py | 4 +- pandasai/connectors/sql.py | 4 +- pandasai/helpers/from_google_sheets.py | 12 +- pandasai/helpers/openai_info.py | 8 +- pandasai/helpers/skills_manager.py | 7 +- .../helpers/viz_library_types/__init__.py | 7 +- pandasai/helpers/viz_library_types/base.py | 1 - pandasai/responses/response_parser.py | 4 +- pandasai/responses/streamlit_response.py | 4 +- pandasai/smart_dataframe/__init__.py | 47 +++--- poetry.lock | 144 ++++-------------- pyproject.toml | 3 +- 14 files changed, 99 insertions(+), 205 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4fcf32f17..cb1e1de48 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,29 +1,25 @@ repos: -- repo: https://github.com/psf/black - rev: 23.3.0 - hooks: - - id: black -- repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.220 - hooks: - - id: ruff - name: ruff - # Respect `exclude` and `extend-exclude` settings. 
- args: [--force-exclude] -- repo: local - hooks: - - id: pytest-check - name: pytest-check - entry: poetry run pytest - language: system - pass_filenames: false - always_run: true + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.1.3 + hooks: + - id: ruff + name: ruff + - id: ruff-format + name: ruff-format + - repo: local + hooks: + - id: pytest-check + name: pytest-check + entry: poetry run pytest + language: system + pass_filenames: false + always_run: true -- repo: https://github.com/sourcery-ai/sourcery - rev: v1.11.0 - hooks: - - id: sourcery - # The best way to use Sourcery in a pre-commit hook: - # * review only changed lines: - # * omit the summary - args: [--diff=git diff HEAD, --no-summary] + - repo: https://github.com/sourcery-ai/sourcery + rev: v1.11.0 + hooks: + - id: sourcery + # The best way to use Sourcery in a pre-commit hook: + # * review only changed lines: + # * omit the summary + args: [--diff=git diff HEAD, --no-summary] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 839d4ccc1..d74c720bf 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -15,7 +15,6 @@ To make a contribution, follow the following steps: For more details about pull requests, please read [GitHub's guides](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request). - ### ๐Ÿ“ฆ Package manager We use `poetry` as our package manager. You can install poetry by following the instructions [here](https://python-poetry.org/docs/#installation). @@ -44,12 +43,12 @@ ruff pandasai examples Make sure that the linter does not report any errors or warnings before submitting a pull request. -### Code Format with `black` +### Code Format with `ruff-format` -We use `black` to reformat the code by running the following command: +We use `ruff` to reformat the code by running the following command: ```bash -black pandasai +ruff format pandasai ``` ### ๐Ÿงช Testing @@ -62,8 +61,6 @@ poetry run pytest Make sure that all tests pass before submitting a pull request. - - ## ๐Ÿš€ Release Process At the moment, the release process is manual. We try to make frequent releases. Usually, we release a new version when we have a new feature or bugfix. A developer with admin rights to the repository will create a new release on GitHub, and then publish the new version to PyPI. diff --git a/pandasai/connectors/airtable.py b/pandasai/connectors/airtable.py index 868a6c161..9dd51257f 100644 --- a/pandasai/connectors/airtable.py +++ b/pandasai/connectors/airtable.py @@ -143,9 +143,7 @@ def execute(self): Returns: DataFrameType: The result of the connector. """ - if cached := self._cached() or self._cached( - include_additional_filters=True - ): + if cached := self._cached() or self._cached(include_additional_filters=True): return pd.read_parquet(cached) if isinstance(self._instance, pd.DataFrame): diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py index 93af4ef9e..0e8c9715d 100644 --- a/pandasai/connectors/sql.py +++ b/pandasai/connectors/sql.py @@ -246,9 +246,7 @@ def execute(self): DataFrame: The result of the SQL query. 
""" - if cached := self._cached() or self._cached( - include_additional_filters=True - ): + if cached := self._cached() or self._cached(include_additional_filters=True): return pd.read_parquet(cached) if self.logger: diff --git a/pandasai/helpers/from_google_sheets.py b/pandasai/helpers/from_google_sheets.py index 2fdc08b49..513234beb 100644 --- a/pandasai/helpers/from_google_sheets.py +++ b/pandasai/helpers/from_google_sheets.py @@ -59,19 +59,14 @@ def sheet_to_df(sheet) -> list: # First pass: get all the headers for row in range(len(sheet)): # if every cell in the row is empty, skip row - if all( - sheet[row][col].strip() == "" for col in range(len(sheet[row])) - ): + if all(sheet[row][col].strip() == "" for col in range(len(sheet[row]))): headers += binding_headers binding_headers = [] continue for col in range(len(sheet[row])): # Check if the cell is bounded by a header - if any( - col >= header[2] and col <= header[3] - for header in binding_headers - ): + if any(col >= header[2] and col <= header[3] for header in binding_headers): continue # Check if the cell is commented out @@ -94,8 +89,7 @@ def sheet_to_df(sheet) -> list: df = [] for row in range(header[1], len(sheet)): if all( - sheet[row][col].strip() == "" - for col in range(header[2], header[3]) + sheet[row][col].strip() == "" for col in range(header[2], header[3]) ): break df_row = [sheet[row][col] for col in range(header[2], header[3])] diff --git a/pandasai/helpers/openai_info.py b/pandasai/helpers/openai_info.py index 50ae14f19..9de2f424b 100644 --- a/pandasai/helpers/openai_info.py +++ b/pandasai/helpers/openai_info.py @@ -74,10 +74,10 @@ def standardize_model_name( if "ft:" in model_name: model_name = model_name.split(":")[1] + "-finetuned" if is_completion and ( - model_name.startswith("gpt-4") - or model_name.startswith("gpt-3.5") - or model_name.startswith("gpt-35") - or "finetuned" in model_name + model_name.startswith("gpt-4") + or model_name.startswith("gpt-3.5") + or model_name.startswith("gpt-35") + or "finetuned" in model_name ): # The cost of completion token is different from # the cost of prompt tokens. diff --git a/pandasai/helpers/skills_manager.py b/pandasai/helpers/skills_manager.py index 8c37643bc..8223fa576 100644 --- a/pandasai/helpers/skills_manager.py +++ b/pandasai/helpers/skills_manager.py @@ -78,13 +78,10 @@ def prompt_display(self) -> str: if len(self._skills) == 0: return - return ( - """ + return """ You can also use the following functions, if relevant: -""" - + self.__str__() - ) +""" + self.__str__() @property def used_skills(self): diff --git a/pandasai/helpers/viz_library_types/__init__.py b/pandasai/helpers/viz_library_types/__init__.py index f02d9703a..b47eb8110 100644 --- a/pandasai/helpers/viz_library_types/__init__.py +++ b/pandasai/helpers/viz_library_types/__init__.py @@ -19,7 +19,11 @@ def viz_lib_type_factory( viz_lib_type: str = None, logger: Optional[Logger] = None -) -> Union[MatplotlibVizLibraryType, PlotlyVizLibraryType, SeabornVizLibraryType,]: +) -> Union[ + MatplotlibVizLibraryType, + PlotlyVizLibraryType, + SeabornVizLibraryType, +]: """ Factory function to get appropriate instance for viz library type. 
@@ -60,4 +64,3 @@ def viz_lib_type_factory( ) return viz_lib_type_helper - diff --git a/pandasai/helpers/viz_library_types/base.py b/pandasai/helpers/viz_library_types/base.py index f0406cb69..e53e4fa4c 100644 --- a/pandasai/helpers/viz_library_types/base.py +++ b/pandasai/helpers/viz_library_types/base.py @@ -17,4 +17,3 @@ class VisualizationLibrary(str, Enum): PLOTLY = "plotly" DEFAULT = "default" - diff --git a/pandasai/responses/response_parser.py b/pandasai/responses/response_parser.py index 1754b8c59..5866dfc26 100644 --- a/pandasai/responses/response_parser.py +++ b/pandasai/responses/response_parser.py @@ -93,7 +93,9 @@ def format_plot(self, result: dict) -> Any: try: image = mpimg.imread(result["value"]) except FileNotFoundError as e: - raise FileNotFoundError(f"The file {result['value']} does not exist.") from e # noqa: E501 + raise FileNotFoundError( + f"The file {result['value']} does not exist." + ) from e # noqa: E501 except OSError as e: raise ValueError( f"The file {result['value']} is not a valid image file." diff --git a/pandasai/responses/streamlit_response.py b/pandasai/responses/streamlit_response.py index e9d7a7eef..b566e0e40 100644 --- a/pandasai/responses/streamlit_response.py +++ b/pandasai/responses/streamlit_response.py @@ -18,7 +18,9 @@ def format_plot(self, result) -> None: try: image = mpimg.imread(result["value"]) except FileNotFoundError as e: - raise FileNotFoundError(f"The file {result['value']} does not exist.") from e # noqa: E501 + raise FileNotFoundError( + f"The file {result['value']} does not exist." + ) from e # noqa: E501 except OSError as e: raise ValueError( f"The file {result['value']} is not a valid image file." diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py index 9a3f54801..2d037b085 100644 --- a/pandasai/smart_dataframe/__init__.py +++ b/pandasai/smart_dataframe/__init__.py @@ -266,28 +266,7 @@ def __init__( ) if "://" in df_config["import_path"]: - connector_name = df_config["import_path"].split("://")[0] - connector_path = df_config["import_path"].split("://")[1] - connector_host = connector_path.split(":")[0] - connector_port = connector_path.split(":")[1].split("/")[0] - connector_database = connector_path.split(":")[1].split("/")[1] - connector_table = connector_path.split(":")[1].split("/")[2] - - connector_data = { - "host": connector_host, - "database": connector_database, - "table": connector_table, - } - if connector_port: - connector_data["port"] = connector_port - - # instantiate the connector - df = getattr( - __import__( - "pandasai.connectors", fromlist=[connector_name] - ), - connector_name, - )(config=connector_data) + df = self._instantiate_connector(df_config["import_path"]) else: df = df_config["import_path"] @@ -385,6 +364,28 @@ def load_connector(self, temporary: bool = False): """ self._core.load_connector(temporary) + def _instantiate_connector(self, import_path: str) -> BaseConnector: + connector_name = import_path.split("://")[0] + connector_path = import_path.split("://")[1] + connector_host = connector_path.split(":")[0] + connector_port = connector_path.split(":")[1].split("/")[0] + connector_database = connector_path.split(":")[1].split("/")[1] + connector_table = connector_path.split(":")[1].split("/")[2] + + connector_data = { + "host": connector_host, + "database": connector_database, + "table": connector_table, + } + if connector_port: + connector_data["port"] = connector_port + + # instantiate the connector + return getattr( + __import__("pandasai.connectors", 
fromlist=[connector_name]), + connector_name, + )(config=connector_data) + def _truncate_head_columns(self, df: DataFrameType, max_size=25) -> DataFrameType: """ Truncate the columns of the dataframe to a maximum of 20 characters. @@ -723,4 +724,4 @@ def __repr__(self): return self.dataframe.__repr__() def __len__(self): - return len(self.dataframe) \ No newline at end of file + return len(self.dataframe) diff --git a/poetry.lock b/poetry.lock index 88e74815a..d55182671 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -284,52 +284,6 @@ soupsieve = ">1.2" html5lib = ["html5lib"] lxml = ["lxml"] -[[package]] -name = "black" -version = "23.9.1" -description = "The uncompromising code formatter." -optional = false -python-versions = ">=3.8" -files = [ - {file = "black-23.9.1-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:d6bc09188020c9ac2555a498949401ab35bb6bf76d4e0f8ee251694664df6301"}, - {file = "black-23.9.1-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:13ef033794029b85dfea8032c9d3b92b42b526f1ff4bf13b2182ce4e917f5100"}, - {file = "black-23.9.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:75a2dc41b183d4872d3a500d2b9c9016e67ed95738a3624f4751a0cb4818fe71"}, - {file = "black-23.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13a2e4a93bb8ca74a749b6974925c27219bb3df4d42fc45e948a5d9feb5122b7"}, - {file = "black-23.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:adc3e4442eef57f99b5590b245a328aad19c99552e0bdc7f0b04db6656debd80"}, - {file = "black-23.9.1-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:8431445bf62d2a914b541da7ab3e2b4f3bc052d2ccbf157ebad18ea126efb91f"}, - {file = "black-23.9.1-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:8fc1ddcf83f996247505db6b715294eba56ea9372e107fd54963c7553f2b6dfe"}, - {file = "black-23.9.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:7d30ec46de88091e4316b17ae58bbbfc12b2de05e069030f6b747dfc649ad186"}, - {file = "black-23.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:031e8c69f3d3b09e1aa471a926a1eeb0b9071f80b17689a655f7885ac9325a6f"}, - {file = "black-23.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:538efb451cd50f43aba394e9ec7ad55a37598faae3348d723b59ea8e91616300"}, - {file = "black-23.9.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:638619a559280de0c2aa4d76f504891c9860bb8fa214267358f0a20f27c12948"}, - {file = "black-23.9.1-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:a732b82747235e0542c03bf352c126052c0fbc458d8a239a94701175b17d4855"}, - {file = "black-23.9.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:cf3a4d00e4cdb6734b64bf23cd4341421e8953615cba6b3670453737a72ec204"}, - {file = "black-23.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf99f3de8b3273a8317681d8194ea222f10e0133a24a7548c73ce44ea1679377"}, - {file = "black-23.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:14f04c990259576acd093871e7e9b14918eb28f1866f91968ff5524293f9c573"}, - {file = "black-23.9.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:c619f063c2d68f19b2d7270f4cf3192cb81c9ec5bc5ba02df91471d0b88c4c5c"}, - {file = "black-23.9.1-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:6a3b50e4b93f43b34a9d3ef00d9b6728b4a722c997c99ab09102fd5efdb88325"}, - {file = "black-23.9.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = 
"sha256:c46767e8df1b7beefb0899c4a95fb43058fa8500b6db144f4ff3ca38eb2f6393"}, - {file = "black-23.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50254ebfa56aa46a9fdd5d651f9637485068a1adf42270148cd101cdf56e0ad9"}, - {file = "black-23.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:403397c033adbc45c2bd41747da1f7fc7eaa44efbee256b53842470d4ac5a70f"}, - {file = "black-23.9.1-py3-none-any.whl", hash = "sha256:6ccd59584cc834b6d127628713e4b6b968e5f79572da66284532525a042549f9"}, - {file = "black-23.9.1.tar.gz", hash = "sha256:24b6b3ff5c6d9ea08a8888f6977eae858e1f340d7260cf56d70a49823236b62d"}, -] - -[package.dependencies] -click = ">=8.0.0" -mypy-extensions = ">=0.4.3" -packaging = ">=22.0" -pathspec = ">=0.9.0" -platformdirs = ">=2" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} - -[package.extras] -colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)"] -jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] -uvloop = ["uvloop (>=0.15.2)"] - [[package]] name = "blinker" version = "1.6.2" @@ -588,7 +542,6 @@ files = [ {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18a64814ae7bce73925131381603fff0116e2df25230dfc80d6d690aa6e20b37"}, {file = "contourpy-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90c81f22b4f572f8a2110b0b741bb64e5a6427e0a198b2cdc1fbaf85f352a3aa"}, {file = "contourpy-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:53cc3a40635abedbec7f1bde60f8c189c49e84ac180c665f2cd7c162cc454baa"}, - {file = "contourpy-1.1.0-cp310-cp310-win32.whl", hash = "sha256:9b2dd2ca3ac561aceef4c7c13ba654aaa404cf885b187427760d7f7d4c57cff8"}, {file = "contourpy-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:1f795597073b09d631782e7245016a4323cf1cf0b4e06eef7ea6627e06a37ff2"}, {file = "contourpy-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0b7b04ed0961647691cfe5d82115dd072af7ce8846d31a5fac6c142dcce8b882"}, {file = "contourpy-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:27bc79200c742f9746d7dd51a734ee326a292d77e7d94c8af6e08d1e6c15d545"}, @@ -597,7 +550,6 @@ files = [ {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5cec36c5090e75a9ac9dbd0ff4a8cf7cecd60f1b6dc23a374c7d980a1cd710e"}, {file = "contourpy-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f0cbd657e9bde94cd0e33aa7df94fb73c1ab7799378d3b3f902eb8eb2e04a3a"}, {file = "contourpy-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:181cbace49874f4358e2929aaf7ba84006acb76694102e88dd15af861996c16e"}, - {file = "contourpy-1.1.0-cp311-cp311-win32.whl", hash = "sha256:edb989d31065b1acef3828a3688f88b2abb799a7db891c9e282df5ec7e46221b"}, {file = "contourpy-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fb3b7d9e6243bfa1efb93ccfe64ec610d85cfe5aec2c25f97fbbd2e58b531256"}, {file = "contourpy-1.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bcb41692aa09aeb19c7c213411854402f29f6613845ad2453d30bf421fe68fed"}, {file = "contourpy-1.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5d123a5bc63cd34c27ff9c7ac1cd978909e9c71da12e05be0231c608048bb2ae"}, @@ -606,7 +558,6 @@ files = [ {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:317267d915490d1e84577924bd61ba71bf8681a30e0d6c545f577363157e5e94"}, {file = "contourpy-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", 
hash = "sha256:d551f3a442655f3dcc1285723f9acd646ca5858834efeab4598d706206b09c9f"}, {file = "contourpy-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e7a117ce7df5a938fe035cad481b0189049e8d92433b4b33aa7fc609344aafa1"}, - {file = "contourpy-1.1.0-cp38-cp38-win32.whl", hash = "sha256:108dfb5b3e731046a96c60bdc46a1a0ebee0760418951abecbe0fc07b5b93b27"}, {file = "contourpy-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:d4f26b25b4f86087e7d75e63212756c38546e70f2a92d2be44f80114826e1cd4"}, {file = "contourpy-1.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc00bb4225d57bff7ebb634646c0ee2a1298402ec10a5fe7af79df9a51c1bfd9"}, {file = "contourpy-1.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:189ceb1525eb0655ab8487a9a9c41f42a73ba52d6789754788d1883fb06b2d8a"}, @@ -615,7 +566,6 @@ files = [ {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:143dde50520a9f90e4a2703f367cf8ec96a73042b72e68fcd184e1279962eb6f"}, {file = "contourpy-1.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e94bef2580e25b5fdb183bf98a2faa2adc5b638736b2c0a4da98691da641316a"}, {file = "contourpy-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ed614aea8462735e7d70141374bd7650afd1c3f3cb0c2dbbcbe44e14331bf002"}, - {file = "contourpy-1.1.0-cp39-cp39-win32.whl", hash = "sha256:71551f9520f008b2950bef5f16b0e3587506ef4f23c734b71ffb7b89f8721999"}, {file = "contourpy-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:438ba416d02f82b692e371858143970ed2eb6337d9cdbbede0d8ad9f3d7dd17d"}, {file = "contourpy-1.1.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a698c6a7a432789e587168573a864a7ea374c6be8d4f31f9d87c001d5a843493"}, {file = "contourpy-1.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:397b0ac8a12880412da3551a8cb5a187d3298a72802b45a3bd1805e204ad8439"}, @@ -1348,11 +1298,11 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""}, {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""}, {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" @@ -2524,16 +2474,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -2871,7 +2811,7 @@ files = [ name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." 
-optional = false +optional = true python-versions = ">=3.5" files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, @@ -3096,8 +3036,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.20.3", markers = "python_version < \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.1" @@ -3859,7 +3799,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -3867,15 +3806,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = 
"sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -3892,7 +3824,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -3900,7 +3831,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -4096,27 +4026,28 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" -version = "0.0.220" +version = "0.1.3" description = "An extremely fast Python linter, written in Rust." 
optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.0.220-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:152e6697aca6aea991cdd37922c34a3e4db4828822c4663122326e6051e0f68a"}, - {file = "ruff-0.0.220-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:127887a00d53beb7c0c78a8b4bbdda2f14f07db7b3571feb6855cb32862cb88d"}, - {file = "ruff-0.0.220-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91235ff448786f8f3b856c104fd6c4fe11e835b0db75da5fdf337e1ed5d454da"}, - {file = "ruff-0.0.220-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2f0a104afc32012048627317ae8b0940e3f11a717905aed3fc26931a873e3b29"}, - {file = "ruff-0.0.220-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eee1deddf1671860e056a78938176600108857a527c078038627b284a554723c"}, - {file = "ruff-0.0.220-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:071082d09c953924eccfd88ffd0d71119ddd6fc7767f3c31549a1cd0651ba586"}, - {file = "ruff-0.0.220-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a5ddfc945a9076c9779b52c1f7296cf8d8e6919e619c4522617bc37b60eddd2e"}, - {file = "ruff-0.0.220-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:15f387fd156430353fb61d2b609f1c38d2e9096e2fce31149da5cf08b73f04a8"}, - {file = "ruff-0.0.220-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5688983f21ac64bbcca8d84f4107733cc2d62c1354ea1a6b85eb9ead32328cc"}, - {file = "ruff-0.0.220-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e2b0c9dbff13649ded5ee92d6a47d720e8471461e0a4eba01bf3474f851cb2f0"}, - {file = "ruff-0.0.220-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9b540fa9f90f46f656b34fb73b738613562974599903a1f0d40bdd1a8180bfab"}, - {file = "ruff-0.0.220-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a061c17c2b0f81193fca5e53829b6c0569c5c7d393cc4fc1c192ce0a64d3b9ca"}, - {file = "ruff-0.0.220-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:42677089abd7db6f8aefa3dbe7a82fea4e2a43a08bc7bf6d3f58ec3d76c63712"}, - {file = "ruff-0.0.220-py3-none-win32.whl", hash = "sha256:f8821cfc204b38140afe870bcd4cc6c836bbd2f820b92df66b8fe8b8d71a3772"}, - {file = "ruff-0.0.220-py3-none-win_amd64.whl", hash = "sha256:8a1d678a224afd7149afbe497c97c3ccdc6c42632ee84fb0e3f68d190c1ccec1"}, - {file = "ruff-0.0.220.tar.gz", hash = "sha256:621f7f063c0d13570b709fb9904a329ddb9a614fdafc786718afd43e97440c34"}, + {file = "ruff-0.1.3-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:b46d43d51f7061652eeadb426a9e3caa1e0002470229ab2fc19de8a7b0766901"}, + {file = "ruff-0.1.3-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:b8afeb9abd26b4029c72adc9921b8363374f4e7edb78385ffaa80278313a15f9"}, + {file = "ruff-0.1.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca3cf365bf32e9ba7e6db3f48a4d3e2c446cd19ebee04f05338bc3910114528b"}, + {file = "ruff-0.1.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4874c165f96c14a00590dcc727a04dca0cfd110334c24b039458c06cf78a672e"}, + {file = "ruff-0.1.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eec2dd31eed114e48ea42dbffc443e9b7221976554a504767ceaee3dd38edeb8"}, + {file = "ruff-0.1.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:dc3ec4edb3b73f21b4aa51337e16674c752f1d76a4a543af56d7d04e97769613"}, + {file = "ruff-0.1.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:2e3de9ed2e39160800281848ff4670e1698037ca039bda7b9274f849258d26ce"}, + {file = "ruff-0.1.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c595193881922cc0556a90f3af99b1c5681f0c552e7a2a189956141d8666fe8"}, + {file = "ruff-0.1.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f75e670d529aa2288cd00fc0e9b9287603d95e1536d7a7e0cafe00f75e0dd9d"}, + {file = "ruff-0.1.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:76dd49f6cd945d82d9d4a9a6622c54a994689d8d7b22fa1322983389b4892e20"}, + {file = "ruff-0.1.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:918b454bc4f8874a616f0d725590277c42949431ceb303950e87fef7a7d94cb3"}, + {file = "ruff-0.1.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:d8859605e729cd5e53aa38275568dbbdb4fe882d2ea2714c5453b678dca83784"}, + {file = "ruff-0.1.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:0b6c55f5ef8d9dd05b230bb6ab80bc4381ecb60ae56db0330f660ea240cb0d4a"}, + {file = "ruff-0.1.3-py3-none-win32.whl", hash = "sha256:3e7afcbdcfbe3399c34e0f6370c30f6e529193c731b885316c5a09c9e4317eef"}, + {file = "ruff-0.1.3-py3-none-win_amd64.whl", hash = "sha256:7a18df6638cec4a5bd75350639b2bb2a2366e01222825562c7346674bdceb7ea"}, + {file = "ruff-0.1.3-py3-none-win_arm64.whl", hash = "sha256:12fd53696c83a194a2db7f9a46337ce06445fb9aa7d25ea6f293cf75b21aca9f"}, + {file = "ruff-0.1.3.tar.gz", hash = "sha256:3ba6145369a151401d5db79f0a47d50e470384d0d89d0d6f7fab0b589ad07c34"}, ] [[package]] @@ -4137,11 +4068,6 @@ files = [ {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f66eddfda9d45dd6cadcd706b65669ce1df84b8549875691b1f403730bdef217"}, {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6448c37741145b241eeac617028ba6ec2119e1339b1385c9720dae31367f2be"}, {file = "scikit_learn-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:c413c2c850241998168bbb3bd1bb59ff03b1195a53864f0b80ab092071af6028"}, - {file = "scikit_learn-1.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ef540e09873e31569bc8b02c8a9f745ee04d8e1263255a15c9969f6f5caa627f"}, - {file = "scikit_learn-1.3.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:9147a3a4df4d401e618713880be023e36109c85d8569b3bf5377e6cd3fecdeac"}, - {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2cd3634695ad192bf71645702b3df498bd1e246fc2d529effdb45a06ab028b4"}, - {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c275a06c5190c5ce00af0acbb61c06374087949f643ef32d355ece12c4db043"}, - {file = "scikit_learn-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:0e1aa8f206d0de814b81b41d60c1ce31f7f2c7354597af38fae46d9c47c45122"}, {file = "scikit_learn-1.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:52b77cc08bd555969ec5150788ed50276f5ef83abb72e6f469c5b91a0009bbca"}, {file = "scikit_learn-1.3.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a683394bc3f80b7c312c27f9b14ebea7766b1f0a34faf1a2e9158d80e860ec26"}, {file = "scikit_learn-1.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15d964d9eb181c79c190d3dbc2fff7338786bf017e9039571418a1d53dab236"}, @@ -4461,7 +4387,6 @@ files = [ {file = "SQLAlchemy-1.4.49-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:03db81b89fe7ef3857b4a00b63dedd632d6183d4ea5a31c5d8a92e000a41fc71"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-macosx_11_0_x86_64.whl", hash = 
"sha256:95b9df9afd680b7a3b13b38adf6e3a38995da5e162cc7524ef08e3be4e5ed3e1"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a63e43bf3f668c11bb0444ce6e809c1227b8f067ca1068898f3008a273f52b09"}, - {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca46de16650d143a928d10842939dab208e8d8c3a9a8757600cae9b7c579c5cd"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f835c050ebaa4e48b18403bed2c0fda986525896efd76c245bdd4db995e51a4c"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c21b172dfb22e0db303ff6419451f0cac891d2e911bb9fbf8003d717f1bcf91"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-win32.whl", hash = "sha256:5fb1ebdfc8373b5a291485757bd6431de8d7ed42c27439f543c81f6c8febd729"}, @@ -4471,35 +4396,26 @@ files = [ {file = "SQLAlchemy-1.4.49-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5debe7d49b8acf1f3035317e63d9ec8d5e4d904c6e75a2a9246a119f5f2fdf3d"}, {file = "SQLAlchemy-1.4.49-cp311-cp311-win32.whl", hash = "sha256:82b08e82da3756765c2e75f327b9bf6b0f043c9c3925fb95fb51e1567fa4ee87"}, {file = "SQLAlchemy-1.4.49-cp311-cp311-win_amd64.whl", hash = "sha256:171e04eeb5d1c0d96a544caf982621a1711d078dbc5c96f11d6469169bd003f1"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f23755c384c2969ca2f7667a83f7c5648fcf8b62a3f2bbd883d805454964a800"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8396e896e08e37032e87e7fbf4a15f431aa878c286dc7f79e616c2feacdb366c"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66da9627cfcc43bbdebd47bfe0145bb662041472393c03b7802253993b6b7c90"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-win32.whl", hash = "sha256:9a06e046ffeb8a484279e54bda0a5abfd9675f594a2e38ef3133d7e4d75b6214"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-win_amd64.whl", hash = "sha256:7cf8b90ad84ad3a45098b1c9f56f2b161601e4670827d6b892ea0e884569bd1d"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:36e58f8c4fe43984384e3fbe6341ac99b6b4e083de2fe838f0fdb91cebe9e9cb"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b31e67ff419013f99ad6f8fc73ee19ea31585e1e9fe773744c0f3ce58c039c30"}, - {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebc22807a7e161c0d8f3da34018ab7c97ef6223578fcdd99b1d3e7ed1100a5db"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c14b29d9e1529f99efd550cd04dbb6db6ba5d690abb96d52de2bff4ed518bc95"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c40f3470e084d31247aea228aa1c39bbc0904c2b9ccbf5d3cfa2ea2dac06f26d"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-win32.whl", hash = "sha256:706bfa02157b97c136547c406f263e4c6274a7b061b3eb9742915dd774bbc264"}, {file = 
"SQLAlchemy-1.4.49-cp36-cp36m-win_amd64.whl", hash = "sha256:a7f7b5c07ae5c0cfd24c2db86071fb2a3d947da7bd487e359cc91e67ac1c6d2e"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:4afbbf5ef41ac18e02c8dc1f86c04b22b7a2125f2a030e25bbb4aff31abb224b"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24e300c0c2147484a002b175f4e1361f102e82c345bf263242f0449672a4bccf"}, - {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:393cd06c3b00b57f5421e2133e088df9cabcececcea180327e43b937b5a7caa5"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:201de072b818f8ad55c80d18d1a788729cccf9be6d9dc3b9d8613b053cd4836d"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7653ed6817c710d0c95558232aba799307d14ae084cc9b1f4c389157ec50df5c"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-win32.whl", hash = "sha256:647e0b309cb4512b1f1b78471fdaf72921b6fa6e750b9f891e09c6e2f0e5326f"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-win_amd64.whl", hash = "sha256:ab73ed1a05ff539afc4a7f8cf371764cdf79768ecb7d2ec691e3ff89abbc541e"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:37ce517c011560d68f1ffb28af65d7e06f873f191eb3a73af5671e9c3fada08a"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1878ce508edea4a879015ab5215546c444233881301e97ca16fe251e89f1c55"}, - {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95ab792ca493891d7a45a077e35b418f68435efb3e1706cb8155e20e86a9013c"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:0e8e608983e6f85d0852ca61f97e521b62e67969e6e640fe6c6b575d4db68557"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ccf956da45290df6e809ea12c54c02ace7f8ff4d765d6d3dfb3655ee876ce58d"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-win32.whl", hash = "sha256:f167c8175ab908ce48bd6550679cc6ea20ae169379e73c7720a28f89e53aa532"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-win_amd64.whl", hash = "sha256:45806315aae81a0c202752558f0df52b42d11dd7ba0097bf71e253b4215f34f4"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:b6d0c4b15d65087738a6e22e0ff461b407533ff65a73b818089efc8eb2b3e1de"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a843e34abfd4c797018fd8d00ffffa99fd5184c421f190b6ca99def4087689bd"}, - {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:738d7321212941ab19ba2acf02a68b8ee64987b248ffa2101630e8fccb549e0d"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:1c890421651b45a681181301b3497e4d57c0d01dc001e10438a40e9a9c25ee77"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:d26f280b8f0a8f497bc10573849ad6dc62e671d2468826e5c748d04ed9e670d5"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-win32.whl", hash = "sha256:ec2268de67f73b43320383947e74700e95c6770d0c68c4e615e9897e46296294"}, @@ -4508,7 +4424,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"} +greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\")"} [package.extras] aiomysql = ["aiomysql", "greenlet (!=0.4.17)"] @@ -4577,20 +4493,12 @@ files = [ {file = "statsmodels-0.14.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5a6a0a1a06ff79be8aa89c8494b33903442859add133f0dda1daf37c3c71682e"}, {file = "statsmodels-0.14.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77b3cd3a5268ef966a0a08582c591bd29c09c88b4566c892a7c087935234f285"}, {file = "statsmodels-0.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c64ebe9cf376cba0c31aed138e15ed179a1d128612dd241cdf299d159e5e882"}, - {file = "statsmodels-0.14.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:229b2f676b4a45cb62d132a105c9c06ca8a09ffba060abe34935391eb5d9ba87"}, {file = "statsmodels-0.14.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb471f757fc45102a87e5d86e87dc2c8c78b34ad4f203679a46520f1d863b9da"}, {file = "statsmodels-0.14.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:582f9e41092e342aaa04920d17cc3f97240e3ee198672f194719b5a3d08657d6"}, {file = "statsmodels-0.14.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7ebe885ccaa64b4bc5ad49ac781c246e7a594b491f08ab4cfd5aa456c363a6f6"}, {file = "statsmodels-0.14.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b587ee5d23369a0e881da6e37f78371dce4238cf7638a455db4b633a1a1c62d6"}, {file = "statsmodels-0.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ef7fa4813c7a73b0d8a0c830250f021c102c71c95e9fe0d6877bcfb56d38b8c"}, - {file = "statsmodels-0.14.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afe80544ef46730ea1b11cc655da27038bbaa7159dc5af4bc35bbc32982262f2"}, {file = "statsmodels-0.14.0-cp311-cp311-win_amd64.whl", hash = "sha256:a6ad7b8aadccd4e4dd7f315a07bef1bca41d194eeaf4ec600d20dea02d242fce"}, - {file = "statsmodels-0.14.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0eea4a0b761aebf0c355b726ac5616b9a8b618bd6e81a96b9f998a61f4fd7484"}, - {file = "statsmodels-0.14.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4c815ce7a699047727c65a7c179bff4031cff9ae90c78ca730cfd5200eb025dd"}, - {file = "statsmodels-0.14.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:575f61337c8e406ae5fa074d34bc6eb77b5a57c544b2d4ee9bc3da6a0a084cf1"}, - {file = "statsmodels-0.14.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8be53cdeb82f49c4cb0fda6d7eeeb2d67dbd50179b3e1033510e061863720d93"}, - {file = "statsmodels-0.14.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:6f7d762df4e04d1dde8127d07e91aff230eae643aa7078543e60e83e7d5b40db"}, - {file = 
"statsmodels-0.14.0-cp312-cp312-win_amd64.whl", hash = "sha256:fc2c7931008a911e3060c77ea8933f63f7367c0f3af04f82db3a04808ad2cd2c"}, {file = "statsmodels-0.14.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3757542c95247e4ab025291a740efa5da91dc11a05990c033d40fce31c450dc9"}, {file = "statsmodels-0.14.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:de489e3ed315bdba55c9d1554a2e89faa65d212e365ab81bc323fa52681fc60e"}, {file = "statsmodels-0.14.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76e290f4718177bffa8823a780f3b882d56dd64ad1c18cfb4bc8b5558f3f5757"}, @@ -4606,8 +4514,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.22.3", markers = "python_version == \"3.10\" and platform_system == \"Windows\" and platform_python_implementation != \"PyPy\""}, {version = ">=1.18", markers = "python_version != \"3.10\" or platform_system != \"Windows\" or platform_python_implementation == \"PyPy\""}, + {version = ">=1.22.3", markers = "python_version == \"3.10\" and platform_system == \"Windows\" and platform_python_implementation != \"PyPy\""}, ] packaging = ">=21.3" pandas = ">=1.0" @@ -5135,4 +5043,4 @@ yfinance = ["yfinance"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.9.7 || >3.9.7,<4.0" -content-hash = "c8db58630bea4943751b4031253897f35470c78052e67212a995478fa8eb3a02" +content-hash = "73f0448e3e2a2031b23c114b7a83cc59338825ef78dd3d8bf006c9710970f98a" diff --git a/pyproject.toml b/pyproject.toml index b6abe4387..13896ffe9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,9 +41,8 @@ sqlalchemy-databricks = { version = "^0.2.0", optional = true } snowflake-sqlalchemy = { version = "^1.5.0", optional = true } [tool.poetry.group.dev.dependencies] -black = "^23.3.0" pre-commit = "^3.2.2" -ruff = "^0.0.220" +ruff = "^0.1.0" pytest = "^7.3.1" pytest-mock = "^3.10.0" pytest-env = "^0.8.1" From d3e896c5d67bddd98b9c0d780d8672fe4d92d9a2 Mon Sep 17 00:00:00 2001 From: Gabriele Venturi Date: Wed, 1 Nov 2023 01:03:12 +0100 Subject: [PATCH 03/13] feat: add add_message method to the agent --- pandasai/agent/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index cc3ec9111..7d5ac818b 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -92,6 +92,14 @@ def chat(self, query: str, output_type: Optional[str] = None): f"\n{exception}\n" ) + def add_message(self, message, is_user=False): + """ + Add message to the memory. This is useful when you want to add a message + to the memory without calling the chat function (for example, when you + need to add a message from the agent). 
+ """ + self._lake._memory.add(message, is_user=is_user) + def check_if_related_to_conversation(self, query: str) -> bool: """ Check if the query is related to the previous conversation From 5537a7e65ebddf2eb5a718faadee0c59973e4a6c Mon Sep 17 00:00:00 2001 From: Gabriele Venturi Date: Wed, 1 Nov 2023 01:48:58 +0100 Subject: [PATCH 04/13] Release v1.4.3 --- mkdocs.yml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mkdocs.yml b/mkdocs.yml index d896d53fa..0e0c6df48 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -42,7 +42,7 @@ nav: - Documents Building: building_docs.md - License: license.md extra: - version: "1.4.2" + version: "1.4.3" plugins: - search - mkdocstrings: diff --git a/pyproject.toml b/pyproject.toml index 13896ffe9..c89c81eb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pandasai" -version = "1.4.2" +version = "1.4.3" description = "PandasAI is a Python library that integrates generative artificial intelligence capabilities into Pandas, making dataframes conversational." authors = ["Gabriele Venturi"] license = "MIT" From 451a843ed4d183dfc4619b2012d2d2fe218dd9d0 Mon Sep 17 00:00:00 2001 From: Arslan Saleem Date: Thu, 2 Nov 2023 04:31:20 +0500 Subject: [PATCH 05/13] feat: workspace env (#717) * fix(chart): charts to save to save_chart_path * refactor sourcery changes * 'Refactored by Sourcery' * refactor chart save code * fix: minor leftovers * feat(workspace_env): add workspace env to store cache, temp chart and config * add error handling and comments --------- Co-authored-by: Sourcery AI <> --- examples/using_workspace_env.py | 38 +++++++++++++++++++++++++++++ pandasai/exceptions.py | 9 +++++++ pandasai/helpers/code_manager.py | 9 +++++++ pandasai/helpers/path.py | 24 ++++++++++++++++-- pandasai/schemas/df_config.py | 1 - pandasai/smart_datalake/__init__.py | 2 ++ tests/test_smartdataframe.py | 3 ++- 7 files changed, 82 insertions(+), 4 deletions(-) create mode 100644 examples/using_workspace_env.py diff --git a/examples/using_workspace_env.py b/examples/using_workspace_env.py new file mode 100644 index 000000000..604d72a60 --- /dev/null +++ b/examples/using_workspace_env.py @@ -0,0 +1,38 @@ +import os +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI +from pandasai.schemas.df_config import Config + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + + +os.environ["PANDASAI_WORKSPACE"] = "workspace dir path" + + +llm = OpenAI("YOUR_API_KEY") +config__ = {"llm": llm, "save_charts": False} + + +agent = Agent( + [employees_df, salaries_df], + config=Config(**config__), + memory_size=10, +) + +# Chat with the agent +response = agent.chat("plot salary against department?") +print(response) diff --git a/pandasai/exceptions.py b/pandasai/exceptions.py index 56551ff34..56853c79d 100644 --- a/pandasai/exceptions.py +++ b/pandasai/exceptions.py @@ -146,3 +146,12 @@ class AdvancedReasoningDisabledError(Exception): Args: Exception (Exception): AdvancedReasoningDisabledError """ + + +class InvalidWorkspacePathError(Exception): + """ + Raised when the environment variable of workspace exist but path is invalid + + Args: + Exception (Exception): 
InvalidWorkspacePathError """ diff --git a/pandasai/helpers/code_manager.py b/pandasai/helpers/code_manager.py index c95264524..edf7aa0fd 100644 --- a/pandasai/helpers/code_manager.py +++ b/pandasai/helpers/code_manager.py @@ -5,6 +5,7 @@ import astor import pandas as pd +from pandasai.helpers.path import find_project_root from pandasai.helpers.skills_manager import SkillsManager @@ -235,6 +236,14 @@ def execute_code(self, code: str, context: CodeExecutionContext) -> Any: file_name=str(context.prompt_id), save_charts_path_str=self._config.save_charts_path, ) + else: + # Temporarily save generated chart to display + code = add_save_chart( + code, + logger=self._logger, + file_name="temp_chart", + save_charts_path_str=find_project_root(), + ) # Reset used skills context.skills_manager.used_skills = [] diff --git a/pandasai/helpers/path.py b/pandasai/helpers/path.py index 18d771a7b..5527a55e7 100644 --- a/pandasai/helpers/path.py +++ b/pandasai/helpers/path.py @@ -1,8 +1,27 @@ import os +from pandasai.exceptions import InvalidWorkspacePathError + def find_project_root(filename=None): - # Get the path of the file that is being executed + """ + Check if a custom workspace path is provided and use it; otherwise iterate to + find the project root + """ + if "PANDASAI_WORKSPACE" in os.environ: + workspace_path = os.environ["PANDASAI_WORKSPACE"] + if ( + workspace_path + and os.path.exists(workspace_path) + and os.path.isdir(workspace_path) + ): + return workspace_path + raise InvalidWorkspacePathError( + "PANDASAI_WORKSPACE does not point to a valid directory" + ) + + # Get the path of the file that is being executed current_file_path = os.path.abspath(os.getcwd()) # Navigate back until we either find a $filename file or there is no parent @@ -26,7 +45,8 @@ def find_project_root(filename=None): parent_folder = os.path.dirname(root_folder) if parent_folder == root_folder: - raise ValueError("Could not find the root folder of the project.") + # if project root is not found return cwd + return os.getcwd() root_folder = parent_folder diff --git a/pandasai/schemas/df_config.py b/pandasai/schemas/df_config.py index 7ceaf725e..2cd1f312d 100644 --- a/pandasai/schemas/df_config.py +++ b/pandasai/schemas/df_config.py @@ -1,7 +1,6 @@ from pydantic import BaseModel, validator, Field from typing import Optional, List, Any, Dict, Type, TypedDict from pandasai.constants import DEFAULT_CHART_DIRECTORY - from pandasai.responses import ResponseParser from ..middlewares.base import Middleware from ..callbacks.base import BaseCallback diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index f3dac336f..b480f1e4a 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -160,6 +160,7 @@ def initialize(self): charts_dir = os.path.join( (find_project_root()), self._config.save_charts_path ) + self._config.save_charts_path = charts_dir except ValueError: charts_dir = os.path.join( os.getcwd(), self._config.save_charts_path ) @@ -438,6 +439,7 @@ def chat(self, query: str, output_type: Optional[str] = None): ) break + except Exception as e: if ( not self._config.use_error_correction_framework diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index 938418394..b2f84c68e 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -473,6 +473,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: "llm": llm, "enable_cache": False, "save_charts": True, + "save_charts_path": "charts", }, ) @@ -482,7 +483,7 @@ def 
analyze_data(dfs: list[pd.DataFrame]) -> dict: assert plt_mock.savefig.called assert ( plt_mock.savefig.call_args.args[0] - == f"exports/charts/{smart_dataframe.last_prompt_id}.png" + == f"charts/{smart_dataframe.last_prompt_id}.png" ) def test_add_middlewares(self, smart_dataframe: SmartDataframe, custom_middleware): From f65cc228687d3c0f4049fc48431caf753d1d2e3e Mon Sep 17 00:00:00 2001 From: Gabriele Venturi Date: Thu, 2 Nov 2023 00:39:47 +0100 Subject: [PATCH 06/13] fix: hallucinations was plotting when not asked --- .../default_instructions.tmpl | 5 +++ .../generate_python_code.tmpl | 2 -- .../assets/prompt_templates/viz_library.tmpl | 1 + .../viz_library_types/_viz_library_types.py | 5 ++- pandasai/prompts/base.py | 10 +++--- pandasai/prompts/generate_python_code.py | 31 ++++++++++++------- .../test_generate_python_code_prompt.py | 6 ++-- tests/test_smartdataframe.py | 13 ++------ 8 files changed, 38 insertions(+), 35 deletions(-) create mode 100644 pandasai/assets/prompt_templates/default_instructions.tmpl create mode 100644 pandasai/assets/prompt_templates/viz_library.tmpl diff --git a/pandasai/assets/prompt_templates/default_instructions.tmpl b/pandasai/assets/prompt_templates/default_instructions.tmpl new file mode 100644 index 000000000..f72542e20 --- /dev/null +++ b/pandasai/assets/prompt_templates/default_instructions.tmpl @@ -0,0 +1,5 @@ +Analyze the data, using the provided dataframes (`dfs`). + 1. Prepare: Preprocessing and cleaning data if necessary + 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) + {viz_library_type} \ No newline at end of file diff --git a/pandasai/assets/prompt_templates/generate_python_code.tmpl b/pandasai/assets/prompt_templates/generate_python_code.tmpl index 28e850b9c..b0c337e44 100644 --- a/pandasai/assets/prompt_templates/generate_python_code.tmpl +++ b/pandasai/assets/prompt_templates/generate_python_code.tmpl @@ -6,8 +6,6 @@ You are provided with the following pandas DataFrames: {conversation} -{viz_library_type} - This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python {current_code} diff --git a/pandasai/assets/prompt_templates/viz_library.tmpl b/pandasai/assets/prompt_templates/viz_library.tmpl new file mode 100644 index 000000000..ee3b7ed25 --- /dev/null +++ b/pandasai/assets/prompt_templates/viz_library.tmpl @@ -0,0 +1 @@ +If the user requests to create a chart, utilize the Python {library} library to generate high-quality graphics that will be saved directly to a file. 
\ No newline at end of file diff --git a/pandasai/helpers/viz_library_types/_viz_library_types.py b/pandasai/helpers/viz_library_types/_viz_library_types.py index 87b290a8f..3c9ae66e3 100644 --- a/pandasai/helpers/viz_library_types/_viz_library_types.py +++ b/pandasai/helpers/viz_library_types/_viz_library_types.py @@ -1,13 +1,12 @@ from abc import abstractmethod, ABC from typing import Any, Iterable +from pandasai.prompts.generate_python_code import VizLibraryPrompt class BaseVizLibraryType(ABC): @property def template_hint(self) -> str: - return f"""When a user requests to create a chart, utilize the Python -{self.name} library to generate high-quality graphics that will be saved -directly to a file.""" + return VizLibraryPrompt(library=self.name) @property @abstractmethod diff --git a/pandasai/prompts/base.py b/pandasai/prompts/base.py index 7a692a18e..1015ea353 100644 --- a/pandasai/prompts/base.py +++ b/pandasai/prompts/base.py @@ -2,6 +2,7 @@ In order to better handle the instructions, this prompt module is written. """ from abc import ABC, abstractmethod +import string class AbstractPrompt(ABC): @@ -92,12 +93,11 @@ def to_string(self): prompt_args = {} for key, value in self._args.items(): if isinstance(value, AbstractPrompt): + args = [ + arg[1] for arg in string.Formatter().parse(value.template) if arg[1] + ] value.set_vars( - { - k: v - for k, v in self._args.items() - if k != key and not isinstance(v, AbstractPrompt) - } + {k: v for k, v in self._args.items() if k != key and k in args} ) prompt_args[key] = value.to_string() else: diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/prompts/generate_python_code.py index 7ee7aba0e..9bc0813ad 100644 --- a/pandasai/prompts/generate_python_code.py +++ b/pandasai/prompts/generate_python_code.py @@ -26,18 +26,30 @@ class CurrentCodePrompt(FileBasedPrompt): _path_to_template = "assets/prompt_templates/current_code.tmpl" +class DefaultInstructionsPrompt(FileBasedPrompt): + """The default instructions""" + + _path_to_template = "assets/prompt_templates/default_instructions.tmpl" + + class AdvancedReasoningPrompt(FileBasedPrompt): - """The current code""" + """The advanced reasoning instructions""" _path_to_template = "assets/prompt_templates/advanced_reasoning.tmpl" class SimpleReasoningPrompt(FileBasedPrompt): - """The current code""" + """The simple reasoning instructions""" _path_to_template = "assets/prompt_templates/simple_reasoning.tmpl" +class VizLibraryPrompt(FileBasedPrompt): + """Provide information about the visualization library""" + + _path_to_template = "assets/prompt_templates/viz_library.tmpl" + + class GeneratePythonCodePrompt(FileBasedPrompt): """Prompt to generate Python code""" @@ -45,14 +57,11 @@ class GeneratePythonCodePrompt(FileBasedPrompt): def setup(self, **kwargs) -> None: if "custom_instructions" in kwargs: - self._set_instructions(kwargs["custom_instructions"]) - else: - self._set_instructions( - """Analyze the data, using the provided dataframes (`dfs`). -1. Prepare: Preprocessing and cleaning data if necessary -2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) -3. 
Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)""" # noqa: E501 + self.set_var( + "instructions", self._format_instructions(kwargs["custom_instructions"]) ) + else: + self.set_var("instructions", DefaultInstructionsPrompt()) if "current_code" in kwargs: self.set_var("current_code", kwargs["current_code"]) @@ -70,8 +79,8 @@ def on_prompt_generation(self) -> None: else: self.set_var("reasoning", SimpleReasoningPrompt()) - def _set_instructions(self, instructions: str): + def _format_instructions(self, instructions: str): lines = instructions.split("\n") indented_lines = [f" {line}" for line in lines[1:]] result = "\n".join([lines[0]] + indented_lines) - self.set_var("instructions", result) + return result diff --git a/tests/prompts/test_generate_python_code_prompt.py b/tests/prompts/test_generate_python_code_prompt.py index 64d70fc21..5358e862a 100644 --- a/tests/prompts/test_generate_python_code_prompt.py +++ b/tests/prompts/test_generate_python_code_prompt.py @@ -90,8 +90,6 @@ def test_str_with_args( Question -{viz_library_type_hint} - This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required @@ -103,6 +101,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) + {viz_library_type_hint} At the end, return a dictionary of: {output_type_hint} """ @@ -151,8 +150,6 @@ def test_advanced_reasoning_prompt(self): Question -{viz_library_type_hint} - This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required @@ -164,6 +161,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) + {viz_library_type_hint} At the end, return a dictionary of: """ diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index b2f84c68e..9a6fa3a1a 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -219,10 +219,6 @@ def test_run_with_privacy_enforcement(self, llm): User: How many countries are in the dataframe? -When a user requests to create a chart, utilize the Python -matplotlib library to generate high-quality graphics that will be saved -directly to a file. - This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required @@ -234,6 +230,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) 
+ If the user requests to create a chart, utilize the Python matplotlib library to generate high-quality graphics that will be saved directly to a file. At the end, return a dictionary of: - type (possible values "string", "number", "dataframe", "plot") - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) @@ -284,10 +281,6 @@ def test_run_passing_output_type(self, llm, output_type, output_type_hint): User: How many countries are in the dataframe? -When a user requests to create a chart, utilize the Python -matplotlib library to generate high-quality graphics that will be saved -directly to a file. - This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required @@ -299,6 +292,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) + If the user requests to create a chart, utilize the Python matplotlib library to generate high-quality graphics that will be saved directly to a file. At the end, return a dictionary of: {output_type_hint} """ @@ -1084,8 +1078,6 @@ def test_run_passing_viz_library_type( User: Plot the histogram of countries showing for each the gdp with distinct bar colors -%s - This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required @@ -1097,6 +1089,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) + %s At the end, return a dictionary of: - type (possible values "string", "number", "dataframe", "plot") - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) From d7a7cc72fd301e9386797b6c5d7a0d80ff8e276c Mon Sep 17 00:00:00 2001 From: Gabriele Venturi Date: Thu, 2 Nov 2023 00:42:05 +0100 Subject: [PATCH 07/13] Release v1.4.4 --- mkdocs.yml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mkdocs.yml b/mkdocs.yml index 0e0c6df48..abed27537 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -42,7 +42,7 @@ nav: - Documents Building: building_docs.md - License: license.md extra: - version: "1.4.3" + version: "1.4.4" plugins: - search - mkdocstrings: diff --git a/pyproject.toml b/pyproject.toml index c89c81eb7..4947ff2cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pandasai" -version = "1.4.3" +version = "1.4.4" description = "PandasAI is a Python library that integrates generative artificial intelligence capabilities into Pandas, making dataframes conversational." 
authors = ["Gabriele Venturi"] license = "MIT" From 593d2833ccede36b57d94edf96983cc3975ca9be Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Mon, 6 Nov 2023 15:26:46 +0500 Subject: [PATCH 08/13] feat(sqlConnector): add direct config run sql at runtime --- .../direct_sql_connector.tmpl | 39 +++++++++++++ pandasai/connectors/sql.py | 22 +++++++ pandasai/exceptions.py | 8 +++ pandasai/helpers/code_manager.py | 4 ++ pandasai/prompts/direct_sql_prompt.py | 34 +++++++++++ pandasai/schemas/df_config.py | 2 +- pandasai/smart_dataframe/__init__.py | 12 ++++ pandasai/smart_datalake/__init__.py | 57 +++++++++++++++---- 8 files changed, 167 insertions(+), 11 deletions(-) create mode 100644 pandasai/assets/prompt_templates/direct_sql_connector.tmpl create mode 100644 pandasai/prompts/direct_sql_prompt.py diff --git a/pandasai/assets/prompt_templates/direct_sql_connector.tmpl b/pandasai/assets/prompt_templates/direct_sql_connector.tmpl new file mode 100644 index 000000000..71227310e --- /dev/null +++ b/pandasai/assets/prompt_templates/direct_sql_connector.tmpl @@ -0,0 +1,39 @@ +You are provided with the following samples of sql tables data: + + +{tables} + + + +{conversation} + + +You are provided with following function that executes the sql query, + +def execute_sql_query(sql_query: str) -> pd.Dataframe +"""his method connect to the database executes the sql query and returns the dataframe""" + + +This is the initial python function. Do not change the params. + +```python +# TODO import all the dependencies required +import pandas as pd + +def analyze_data() -> dict: + """ + Analyze the data, using the provided dataframes (`dfs`). + 1. Prepare: generate sql query to get data for analysis (grouping, filtering, aggregating, etc.) + 2. Process: execute the query using execute method available to you which returns dataframe + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) + {viz_library_type} + At the end, return a dictionary of: + {output_type_hint} + """ +``` + +Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. 
+Based on the last message in the conversation: + +- return the updated analyze_data function wrapped within `python ` \ No newline at end of file diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py index 0e8c9715d..21e05adab 100644 --- a/pandasai/connectors/sql.py +++ b/pandasai/connectors/sql.py @@ -360,6 +360,28 @@ def column_hash(self): def fallback_name(self): return self._config.table + def equals(self, other): + if isinstance(other, self.__class__): + return ( + self._config.dialect, + self._config.driver, + self._config.host, + self._config.port, + self._config.username, + self._config.password, + ) == ( + other._config.dialect, + other._config.driver, + other._config.host, + other._config.port, + other._config.username, + other._config.password, + ) + return False + + def execute_direct_sql_query(self, sql_query): + return pd.read_sql(sql_query, self._connection) + class SqliteConnector(SQLConnector): """ diff --git a/pandasai/exceptions.py b/pandasai/exceptions.py index 56853c79d..819724c63 100644 --- a/pandasai/exceptions.py +++ b/pandasai/exceptions.py @@ -155,3 +155,11 @@ class InvalidWorkspacePathError(Exception): Args: Exception (Exception): InvalidWorkspacePathError """ + + +class InvalidConfigError(Exception): + """ + Raised when config value is not appliable + Args: + Exception (Exception): InvalidConfigError + """ diff --git a/pandasai/helpers/code_manager.py b/pandasai/helpers/code_manager.py index edf7aa0fd..e8b40eaec 100644 --- a/pandasai/helpers/code_manager.py +++ b/pandasai/helpers/code_manager.py @@ -283,6 +283,10 @@ def execute_code(self, code: str, context: CodeExecutionContext) -> Any: analyze_data = environment.get("analyze_data") + if self._config.direct_sql: + environment["execute_sql_query"] = self._dfs[0].get_query_exec_func() + return analyze_data() + return analyze_data(self._get_originals(dfs)) def _get_samples(self, dfs): diff --git a/pandasai/prompts/direct_sql_prompt.py b/pandasai/prompts/direct_sql_prompt.py new file mode 100644 index 000000000..5b27098ec --- /dev/null +++ b/pandasai/prompts/direct_sql_prompt.py @@ -0,0 +1,34 @@ +""" Prompt to explain code generation by the LLM +The previous conversation we had + + +{conversation} + + +Based on the last conversation you generated the following code: + + +{code} + + +Explain how you came up with code for non-technical people without +mentioning technical details or mentioning the libraries used? + +""" +from .file_based_prompt import FileBasedPrompt + + +class DirectSQLPrompt(FileBasedPrompt): + """Prompt to explain code generation by the LLM""" + + _path_to_template = "assets/prompt_templates/direct_sql_connector.tmpl" + + def _prepare_tables_data(self, tables): + tables_join = [] + for table in tables: + table = f"{table.head_csv}
" + tables_join.append(table) + return "\n".join(tables_join) + + def setup(self, tables) -> None: + self.set_var("tables", self._prepare_tables_data(tables)) diff --git a/pandasai/schemas/df_config.py b/pandasai/schemas/df_config.py index 2cd1f312d..fa100a90e 100644 --- a/pandasai/schemas/df_config.py +++ b/pandasai/schemas/df_config.py @@ -30,11 +30,11 @@ class Config(BaseModel): max_retries: int = 3 middlewares: List[Middleware] = Field(default_factory=list) callback: Optional[BaseCallback] = None - lazy_load_connector: bool = True response_parser: Type[ResponseParser] = None llm: Any = None data_viz_library: Optional[VisualizationLibrary] = None log_server: LogServerConfig = None + direct_sql: bool = False class Config: arbitrary_types_allowed = True diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py index 2d037b085..64b16be8b 100644 --- a/pandasai/smart_dataframe/__init__.py +++ b/pandasai/smart_dataframe/__init__.py @@ -725,3 +725,15 @@ def __repr__(self): def __len__(self): return len(self.dataframe) + + def __eq__(self, other): + if isinstance(other, self.__class__): + if self._core.has_connector and other._core.has_connector: + return self._core.connector.equals(other._core.connector) + else: + return self.dataframe == other.dataframe + + return False + + def get_query_exec_func(self): + return self._core.connector.execute_direct_sql_query diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index b480f1e4a..983dd5cec 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -23,11 +23,9 @@ import traceback from pandasai.constants import DEFAULT_CHART_DIRECTORY from pandasai.helpers.skills_manager import SkillsManager - +from pandasai.prompts.direct_sql_prompt import DirectSQLPrompt from pandasai.skills import skill - from pandasai.helpers.query_exec_tracker import QueryExecTracker - from ..helpers.output_types import output_type_factory from ..helpers.viz_library_types import viz_lib_type_factory from pandasai.responses.context import Context @@ -42,13 +40,13 @@ from ..prompts.base import AbstractPrompt from ..prompts.correct_error_prompt import CorrectErrorPrompt from ..prompts.generate_python_code import GeneratePythonCodePrompt -from typing import Union, List, Any, Type, Optional +from typing import Union, List, Any, Optional from ..helpers.code_manager import CodeExecutionContext, CodeManager from ..middlewares.base import Middleware from ..helpers.df_info import DataFrameType from ..helpers.path import find_project_root from ..helpers.viz_library_types.base import VisualizationLibrary -from ..exceptions import AdvancedReasoningDisabledError +from ..exceptions import AdvancedReasoningDisabledError, InvalidConfigError class SmartDatalake: @@ -133,6 +131,9 @@ def __init__( server_config=self._config.log_server, ) + # Checks if direct sql config set they all belong to same sql connector type + self._validate_direct_sql(self._dfs) + def set_instance_type(self, type: str): self._instance = type @@ -247,6 +248,35 @@ def _load_data_viz_library(self, data_viz_library: str): if data_viz_library in (item.value for item in VisualizationLibrary): self._data_viz_library = data_viz_library + def _validate_direct_sql(self, dfs: List) -> None: + """ + Raises error if they don't belong sqlconnector or have different credentials + Args: + dfs (List[SmartDataframe]): list of SmartDataframes + + Raises: + InvalidConfigError: Raise Error in case of config is set but criteria is not met + 
""" + if self._config.direct_sql: + if dfs and all(df == dfs[0] for df in dfs): + return True + else: + raise InvalidConfigError( + "Direct requires all connector belong to same datasource " + "and have same credentials" + ) + + def _get_chat_prompt(self): + key = "direct_sql_prompt" if self._config.direct_sql else "generate_python_code" + return ( + key, + ( + DirectSQLPrompt(tables=self._dfs) + if self._config.direct_sql + else GeneratePythonCodePrompt() + ), + ) + def add_middlewares(self, *middlewares: Optional[Middleware]): """ Add middlewares to PandasAI instance. @@ -273,7 +303,7 @@ def _assign_prompt_id(self): def _get_prompt( self, key: str, - default_prompt: Type[AbstractPrompt], + default_prompt: AbstractPrompt, default_values: Optional[dict] = None, ) -> AbstractPrompt: """ @@ -292,7 +322,9 @@ def _get_prompt( default_values = {} custom_prompt = self._config.custom_prompts.get(key) - prompt = custom_prompt or default_prompt() + print(key) + print(custom_prompt.__class__) + prompt = custom_prompt or default_prompt # set default values for the prompt prompt.set_config(self._config) @@ -325,6 +357,10 @@ def _get_cache_key(self) -> str: hash = df.column_hash() cache_key += str(hash) + # direct flag to separate out caching for different codegen + if self._config.direct_sql: + cache_key += "direct_sql" + return cache_key def chat(self, query: str, output_type: Optional[str] = None): @@ -395,11 +431,12 @@ def chat(self, query: str, output_type: Optional[str] = None): ): default_values["current_code"] = self._last_code_generated + prompt_key, prompt = self._get_chat_prompt() generate_python_code_instruction = ( self._query_exec_tracker.execute_func( self._get_prompt, - "generate_python_code", - default_prompt=GeneratePythonCodePrompt, + key=prompt_key, + default_prompt=prompt, default_values=default_values, ) ) @@ -553,7 +590,7 @@ def _retry_run_code(self, code: str, e: Exception) -> List: } error_correcting_instruction = self._get_prompt( "correct_error", - default_prompt=CorrectErrorPrompt, + default_prompt=CorrectErrorPrompt(), default_values=default_values, ) From b36f5fb31266c7c7e8e58de158cf2ef67f296dbc Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Mon, 6 Nov 2023 22:40:47 +0500 Subject: [PATCH 09/13] feat(DirectSqlConnector): add sql test cases --- examples/sql_direct_config.py | 49 ++++++++++++ pandasai/connectors/databricks.py | 17 ++++ pandasai/connectors/snowflake.py | 15 ++++ pandasai/connectors/sql.py | 20 +++++ pandasai/exceptions.py | 8 ++ pandasai/helpers/code_manager.py | 15 +++- pandasai/smart_dataframe/__init__.py | 5 +- pandasai/smart_datalake/__init__.py | 15 ++-- tests/connectors/test_sql.py | 92 +++++++++++++++++++++- tests/prompts/test_sql_prompt.py | 113 +++++++++++++++++++++++++++ tests/skills/test_skills.py | 5 +- tests/test_codemanager.py | 7 +- tests/test_smartdataframe.py | 1 + tests/test_smartdatalake.py | 100 ++++++++++++++++++++++++ 14 files changed, 446 insertions(+), 16 deletions(-) create mode 100644 examples/sql_direct_config.py create mode 100644 tests/prompts/test_sql_prompt.py diff --git a/examples/sql_direct_config.py b/examples/sql_direct_config.py new file mode 100644 index 000000000..3a198bb2e --- /dev/null +++ b/examples/sql_direct_config.py @@ -0,0 +1,49 @@ +"""Example of using PandasAI with a CSV file.""" + +from pandasai import SmartDatalake +from pandasai.llm import OpenAI +from pandasai.connectors import PostgreSQLConnector + + +# With a PostgreSQL database +payment_connector = PostgreSQLConnector( + config={ + "host": "localhost", + 
"port": 5432, + "database": "testdb", + "username": "postgres", + "password": "123456", + "table": "orders", + } +) + +order_details = PostgreSQLConnector( + config={ + "host": "localhost", + "port": 5432, + "database": "testdb", + "username": "postgres", + "password": "123456", + "table": "order_details", + } +) + +products = PostgreSQLConnector( + config={ + "host": "localhost", + "port": 5432, + "database": "testdb", + "username": "postgres", + "password": "123456", + "table": "products", + } +) + + +llm = OpenAI("YOUR_API_KEY") +df = SmartDatalake( + [order_details, payment_connector, products], + config={"llm": llm, "direct_sql": True}, +) +response = df.chat("Return Orders with OrderDetails and counts of distinct Products") +print(response) diff --git a/pandasai/connectors/databricks.py b/pandasai/connectors/databricks.py index 0f70b980b..8ab852cb8 100644 --- a/pandasai/connectors/databricks.py +++ b/pandasai/connectors/databricks.py @@ -63,3 +63,20 @@ def __repr__(self): f"host={self._config.host} port={self._config.port} " f"database={self._config.database} httpPath={str(self._config.httpPath)}" ) + + def equals(self, other): + if isinstance(other, self.__class__): + return ( + self._config.dialect, + self._config.token, + self._config.host, + self._config.port, + self._config.httpPath, + ) == ( + other._config.dialect, + other._config.token, + other._config.host, + other._config.port, + other._config.httpPath, + ) + return False diff --git a/pandasai/connectors/snowflake.py b/pandasai/connectors/snowflake.py index 7120ec45c..abd53a316 100644 --- a/pandasai/connectors/snowflake.py +++ b/pandasai/connectors/snowflake.py @@ -90,3 +90,18 @@ def __repr__(self): f"database={self._config.database} schema={str(self._config.dbSchema)} " f"table={self._config.table}>" ) + + def equals(self, other): + if isinstance(other, self.__class__): + return ( + self._config.dialect, + self._config.account, + self._config.username, + self._config.password, + ) == ( + other._config.dialect, + other._config.account, + other._config.username, + other._config.password, + ) + return False diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py index 21e05adab..b0d304369 100644 --- a/pandasai/connectors/sql.py +++ b/pandasai/connectors/sql.py @@ -5,6 +5,8 @@ import re import os import pandas as pd + +from pandasai.exceptions import MaliciousQueryError from .base import BaseConnector, SQLConnectorConfig, SqliteConnectorConfig from .base import BaseConnectorConfig from sqlalchemy import create_engine, text, select, asc @@ -379,7 +381,25 @@ def equals(self, other): ) return False + def _is_sql_query_safe(self, query: str): + infected_keywords = [ + r"\bINSERT\b", + r"\bUPDATE\b", + r"\bDELETE\b", + r"\bDROP\b", + r"\bEXEC\b", + r"\bALTER\b", + r"\bCREATE\b", + ] + + return not any( + re.search(keyword, query, re.IGNORECASE) for keyword in infected_keywords + ) + def execute_direct_sql_query(self, sql_query): + if not self._is_sql_query_safe(sql_query): + raise MaliciousQueryError("Malicious query in generated code") + return pd.read_sql(sql_query, self._connection) diff --git a/pandasai/exceptions.py b/pandasai/exceptions.py index 819724c63..9f924519f 100644 --- a/pandasai/exceptions.py +++ b/pandasai/exceptions.py @@ -163,3 +163,11 @@ class InvalidConfigError(Exception): Args: Exception (Exception): InvalidConfigError """ + + +class MaliciousQueryError(Exception): + """ + Raise error if malicious query is generated + Args: + Exception (Excpetion): MaliciousQueryError + """ diff --git 
a/pandasai/helpers/code_manager.py b/pandasai/helpers/code_manager.py index e8b40eaec..f15f2ecc8 100644 --- a/pandasai/helpers/code_manager.py +++ b/pandasai/helpers/code_manager.py @@ -28,9 +28,15 @@ class CodeExecutionContext: _prompt_id: uuid.UUID = None + _can_direct_sql: bool = False _skills_manager: SkillsManager = None - def __init__(self, prompt_id: uuid.UUID, skills_manager: SkillsManager): + def __init__( + self, + prompt_id: uuid.UUID, + skills_manager: SkillsManager, + _can_direct_sql: bool = False, + ): """ Additional Context for code execution Args: @@ -39,6 +45,7 @@ def __init__(self, prompt_id: uuid.UUID, skills_manager: SkillsManager): """ self._skills_manager = skills_manager self._prompt_id = prompt_id + self._can_direct_sql = _can_direct_sql @property def prompt_id(self): @@ -48,6 +55,10 @@ def prompt_id(self): def skills_manager(self): return self._skills_manager + @property + def can_direct_sql(self): + return self._can_direct_sql + class CodeManager: _dfs: List @@ -283,7 +294,7 @@ def execute_code(self, code: str, context: CodeExecutionContext) -> Any: analyze_data = environment.get("analyze_data") - if self._config.direct_sql: + if context.can_direct_sql: environment["execute_sql_query"] = self._dfs[0].get_query_exec_func() return analyze_data() diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py index 64b16be8b..e9cd55d8e 100644 --- a/pandasai/smart_dataframe/__init__.py +++ b/pandasai/smart_dataframe/__init__.py @@ -730,10 +730,11 @@ def __eq__(self, other): if isinstance(other, self.__class__): if self._core.has_connector and other._core.has_connector: return self._core.connector.equals(other._core.connector) - else: - return self.dataframe == other.dataframe return False + def is_connector(self): + return self._core.has_connector + def get_query_exec_func(self): return self._core.connector.execute_direct_sql_query diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 983dd5cec..460716cc1 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -62,6 +62,7 @@ class SmartDatalake: _skills: SkillsManager _instance: str _query_exec_tracker: QueryExecTracker + _can_direct_sql: bool _last_code_generated: str = None _last_reasoning: str = None @@ -132,7 +133,7 @@ def __init__( ) # Checks if direct sql config set they all belong to same sql connector type - self._validate_direct_sql(self._dfs) + self._can_direct_sql = self._validate_direct_sql(self._dfs) def set_instance_type(self, type: str): self._instance = type @@ -257,14 +258,16 @@ def _validate_direct_sql(self, dfs: List) -> None: Raises: InvalidConfigError: Raise Error in case of config is set but criteria is not met """ - if self._config.direct_sql: + + if self._config.direct_sql and all(df.is_connector() for df in dfs): if dfs and all(df == dfs[0] for df in dfs): return True else: raise InvalidConfigError( - "Direct requires all connector belong to same datasource " + "Direct requires all SQLConnector and they belong to same datasource " "and have same credentials" ) + return False def _get_chat_prompt(self): key = "direct_sql_prompt" if self._config.direct_sql else "generate_python_code" @@ -322,8 +325,6 @@ def _get_prompt( default_values = {} custom_prompt = self._config.custom_prompts.get(key) - print(key) - print(custom_prompt.__class__) prompt = custom_prompt or default_prompt # set default values for the prompt @@ -469,7 +470,9 @@ def chat(self, query: str, output_type: Optional[str] = None): 
while retry_count < self._config.max_retries: try: # Execute the code - context = CodeExecutionContext(self._last_prompt_id, self._skills) + context = CodeExecutionContext( + self._last_prompt_id, self._skills, self._can_direct_sql + ) result = self._code_manager.execute_code( code=code_to_run, context=context, diff --git a/tests/connectors/test_sql.py b/tests/connectors/test_sql.py index e9d55fbb7..83f2e6723 100644 --- a/tests/connectors/test_sql.py +++ b/tests/connectors/test_sql.py @@ -2,7 +2,8 @@ import pandas as pd from unittest.mock import Mock, patch from pandasai.connectors.base import SQLConnectorConfig -from pandasai.connectors.sql import SQLConnector +from pandasai.connectors.sql import PostgreSQLConnector, SQLConnector +from pandasai.exceptions import MaliciousQueryError class TestSQLConnector(unittest.TestCase): @@ -104,3 +105,92 @@ def test_fallback_name_property(self): # Test fallback_name property fallback_name = self.connector.fallback_name self.assertEqual(fallback_name, "your_table") + + def test_is_sql_query_safe_safe_query(self): + safe_query = "SELECT * FROM users WHERE username = 'John'" + result = self.connector._is_sql_query_safe(safe_query) + assert result is True + + def test_is_sql_query_safe_malicious_query(self): + malicious_query = "DROP TABLE users" + result = self.connector._is_sql_query_safe(malicious_query) + assert result is False + + @patch("pandasai.connectors.sql.pd.read_sql", autospec=True) + def test_execute_direct_sql_query_safe_query(self, mock_sql): + safe_query = "SELECT * FROM users WHERE username = 'John'" + expected_data = pd.DataFrame({"Column1": [1, 2, 3], "Column2": [4, 5, 6]}) + mock_sql.return_value = expected_data + result = self.connector.execute_direct_sql_query(safe_query) + assert isinstance(result, pd.DataFrame) + + def test_execute_direct_sql_query_malicious_query(self): + malicious_query = "DROP TABLE users" + try: + self.connector.execute_direct_sql_query(malicious_query) + assert False, "MaliciousQueryError not raised" + except MaliciousQueryError: + pass + + @patch("pandasai.connectors.SQLConnector._init_connection") + def test_equals_identical_configs(self, mock_init_connection): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + connector_2 = SQLConnector(self.config) + + assert self.connector.equals(connector_2) + + @patch("pandasai.connectors.SQLConnector._load_connector_config") + @patch("pandasai.connectors.SQLConnector._init_connection") + def test_equals_different_configs( + self, mock_load_connector_config, mock_init_connection + ): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username_differ", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + connector_2 = SQLConnector(self.config) + + assert not self.connector.equals(connector_2) + + @patch("pandasai.connectors.SQLConnector._init_connection") + def test_equals_different_connector(self, mock_init_connection): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + 
driver="pymysql", + username="your_username_differ", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + connector_2 = PostgreSQLConnector(self.config) + + assert not self.connector.equals(connector_2) diff --git a/tests/prompts/test_sql_prompt.py b/tests/prompts/test_sql_prompt.py new file mode 100644 index 000000000..92df71481 --- /dev/null +++ b/tests/prompts/test_sql_prompt.py @@ -0,0 +1,113 @@ +"""Unit tests for the correct error prompt class""" +import sys + +import pandas as pd +import pytest +from pandasai import SmartDataframe +from pandasai.llm.fake import FakeLLM +from pandasai.prompts.direct_sql_prompt import DirectSQLPrompt +from pandasai.helpers.viz_library_types import ( + MatplotlibVizLibraryType, + viz_lib_map, + viz_lib_type_factory, +) +from pandasai.helpers.output_types import ( + output_type_factory, + DefaultOutputType, + output_types_map, +) + + +class TestDirectSqlPrompt: + """Unit tests for the correct error prompt class""" + + @pytest.mark.parametrize( + "save_charts_path,output_type_hint,viz_library_type_hint", + [ + ( + "exports/charts", + DefaultOutputType().template_hint, + MatplotlibVizLibraryType().template_hint, + ), + ( + "custom/dir/for/charts", + DefaultOutputType().template_hint, + MatplotlibVizLibraryType().template_hint, + ), + *[ + ( + "exports/charts", + output_type_factory(type_).template_hint, + viz_lib_type_factory(viz_type_).template_hint, + ) + for type_ in output_types_map + for viz_type_ in viz_lib_map + ], + ], + ) + def test_direct_sql_prompt_with_params( + self, save_charts_path, output_type_hint, viz_library_type_hint + ): + """Test that the __str__ method is implemented""" + + llm = FakeLLM("plt.show()") + dfs = [ + SmartDataframe( + pd.DataFrame({}), + config={"llm": llm}, + ) + ] + + prompt = DirectSQLPrompt(tables=dfs) + prompt.set_var("dfs", dfs) + prompt.set_var("conversation", "What is the correct code?") + prompt.set_var("output_type_hint", output_type_hint) + prompt.set_var("save_charts_path", save_charts_path) + prompt.set_var("viz_library_type", viz_library_type_hint) + prompt_content = prompt.to_string() + if sys.platform.startswith("win"): + prompt_content = prompt_content.replace("\r\n", "\n") + + assert ( + prompt_content + == f'''You are provided with the following samples of sql tables data: + + + +
+ + + +What is the correct code? + + +You are provided with following function that executes the sql query, + +def execute_sql_query(sql_query: str) -> pd.Dataframe +"""his method connect to the database executes the sql query and returns the dataframe""" + + +This is the initial python function. Do not change the params. + +```python +# TODO import all the dependencies required +import pandas as pd + +def analyze_data() -> dict: + """ + Analyze the data, using the provided dataframes (`dfs`). + 1. Prepare: generate sql query to get data for analysis (grouping, filtering, aggregating, etc.) + 2. Process: execute the query using execute method available to you which returns dataframe + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) + {viz_library_type_hint} + At the end, return a dictionary of: + {output_type_hint} + """ +``` + +Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. +Based on the last message in the conversation: + +- return the updated analyze_data function wrapped within `python `''' # noqa: E501 + ) diff --git a/tests/skills/test_skills.py b/tests/skills/test_skills.py index ed979951c..1e8ddcc50 100644 --- a/tests/skills/test_skills.py +++ b/tests/skills/test_skills.py @@ -71,8 +71,7 @@ def code_manager(self, smart_dataframe: SmartDataframe): @pytest.fixture def exec_context(self) -> MagicMock: - context = MagicMock(spec=CodeExecutionContext) - return context + return CodeExecutionContext(uuid.uuid4(), SkillsManager()) @pytest.fixture def agent(self, llm, sample_df): @@ -317,7 +316,7 @@ def test_run_prompt_without_skills(self, agent): ) def test_code_exec_with_skills_no_use( - self, code_manager: CodeManager, exec_context: MagicMock + self, code_manager: CodeManager, exec_context ): code = """def analyze_data(dfs): return {'type': 'number', 'value': 1 + 1}""" diff --git a/tests/test_codemanager.py b/tests/test_codemanager.py index 0df494429..f8af9de63 100644 --- a/tests/test_codemanager.py +++ b/tests/test_codemanager.py @@ -1,11 +1,13 @@ """Unit tests for the CodeManager class""" from typing import Optional from unittest.mock import MagicMock, Mock, patch +import uuid import pandas as pd import pytest from pandasai.exceptions import BadImportError, NoCodeFoundError +from pandasai.helpers.skills_manager import SkillsManager from pandasai.llm.fake import FakeLLM from pandasai.smart_dataframe import SmartDataframe @@ -73,8 +75,7 @@ def code_manager(self, smart_dataframe: SmartDataframe): @pytest.fixture def exec_context(self) -> MagicMock: - context = MagicMock(spec=CodeExecutionContext) - return context + return CodeExecutionContext(uuid.uuid4(), SkillsManager()) def test_run_code_for_calculations( self, code_manager: CodeManager, exec_context: MagicMock @@ -97,6 +98,8 @@ def test_clean_code_remove_builtins( builtins_code = """import set def analyze_data(dfs): return {'type': 'number', 'value': set([1, 2, 3])}""" + + exec_context._can_direct_sql = False assert code_manager.execute_code(builtins_code, exec_context)["value"] == { 1, 2, diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index 9a6fa3a1a..2552cc611 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -199,6 +199,7 @@ def analyze_data(dfs): ) output_df = smart_dataframe.chat("Set column b to column a + 1") + print("output", output_df) assert output_df["a"].tolist() == [1, 
2, 3] assert output_df["b"].tolist() == [2, 3, 4] diff --git a/tests/test_smartdatalake.py b/tests/test_smartdatalake.py index bb33936bc..3195a53de 100644 --- a/tests/test_smartdatalake.py +++ b/tests/test_smartdatalake.py @@ -9,12 +9,18 @@ import pytest from pandasai import SmartDataframe, SmartDatalake +from pandasai.connectors.base import SQLConnectorConfig +from pandasai.connectors.sql import PostgreSQLConnector, SQLConnector +from pandasai.exceptions import InvalidConfigError from pandasai.helpers.code_manager import CodeManager from pandasai.llm.fake import FakeLLM from pandasai.middlewares import Middleware from langchain import OpenAI +from pandasai.prompts.direct_sql_prompt import DirectSQLPrompt +from pandasai.prompts.generate_python_code import GeneratePythonCodePrompt + class TestSmartDatalake: """Unit tests for the SmartDatlake class""" @@ -66,6 +72,44 @@ def sample_df(self): } ) + @pytest.fixture + @patch("pandasai.connectors.sql.create_engine", autospec=True) + def sql_connector(self, create_engine): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + return SQLConnector(self.config) + + @pytest.fixture + @patch("pandasai.connectors.sql.create_engine", autospec=True) + def pgsql_connector(self, create_engine): + # Define your ConnectorConfig instance here + self.config = SQLConnectorConfig( + dialect="mysql", + driver="pymysql", + username="your_username", + password="your_password", + host="your_host", + port=443, + database="your_database", + table="your_table", + where=[["column_name", "=", "value"]], + ).dict() + + # Create an instance of SQLConnector + return PostgreSQLConnector(self.config) + @pytest.fixture def smart_dataframe(self, llm, sample_df): return SmartDataframe(sample_df, config={"llm": llm, "enable_cache": False}) @@ -229,3 +273,59 @@ def analyze_data(dfs): smart_datalake.chat("How many countries are in the dataframe?") assert smart_datalake.last_answer == "Custom answer" assert smart_datalake.last_reasoning == "Custom reasoning" + + def test_get_chat_prompt(self, smart_datalake: SmartDatalake): + # Test case 1: direct_sql is True + smart_datalake._config.direct_sql = True + gen_key, gen_prompt = smart_datalake._get_chat_prompt() + expected_key = "direct_sql_prompt" + assert gen_key == expected_key + assert isinstance(gen_prompt, DirectSQLPrompt) + + # Test case 2: direct_sql is False + smart_datalake._config.direct_sql = False + gen_key, gen_prompt = smart_datalake._get_chat_prompt() + expected_key = "generate_python_code" + assert gen_key == expected_key + assert isinstance(gen_prompt, GeneratePythonCodePrompt) + + def test_validate_true_direct_sql_with_non_connector(self, llm, sample_df): + # raise exception with non connector + SmartDatalake( + [sample_df], + config={"llm": llm, "enable_cache": False, "direct_sql": True}, + ) + + def test_validate_direct_sql_with_connector(self, llm, sql_connector): + # not exception is raised using single connector + SmartDatalake( + [sql_connector], + config={"llm": llm, "enable_cache": False, "direct_sql": True}, + ) + + def test_validate_false_direct_sql_with_connector(self, llm, sql_connector): + # not exception is raised using single connector + SmartDatalake( + [sql_connector], + config={"llm": llm, "enable_cache": 
False, "direct_sql": False}, + ) + + def test_validate_false_direct_sql_with_two_different_connector( + self, llm, sql_connector, pgsql_connector + ): + # not exception is raised using single connector + SmartDatalake( + [sql_connector, pgsql_connector], + config={"llm": llm, "enable_cache": False, "direct_sql": False}, + ) + + def test_validate_true_direct_sql_with_two_different_connector( + self, llm, sql_connector, pgsql_connector + ): + # not exception is raised using single connector + # raise exception when two different connector + with pytest.raises(InvalidConfigError): + SmartDatalake( + [sql_connector, pgsql_connector], + config={"llm": llm, "enable_cache": False, "direct_sql": True}, + ) From 36da80c4bc492bcc6062c5aa70cba7887ca8b843 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Mon, 6 Nov 2023 22:50:07 +0500 Subject: [PATCH 10/13] fix: minor leftovers --- pandasai/connectors/sql.py | 2 +- pandasai/smart_datalake/__init__.py | 2 +- tests/test_smartdataframe.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py index b0d304369..5d4443563 100644 --- a/pandasai/connectors/sql.py +++ b/pandasai/connectors/sql.py @@ -398,7 +398,7 @@ def _is_sql_query_safe(self, query: str): def execute_direct_sql_query(self, sql_query): if not self._is_sql_query_safe(sql_query): - raise MaliciousQueryError("Malicious query in generated code") + raise MaliciousQueryError("Malicious query is generated in code") return pd.read_sql(sql_query, self._connection) diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 460716cc1..111eaf155 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -260,7 +260,7 @@ def _validate_direct_sql(self, dfs: List) -> None: """ if self._config.direct_sql and all(df.is_connector() for df in dfs): - if dfs and all(df == dfs[0] for df in dfs): + if all(df == dfs[0] for df in dfs): return True else: raise InvalidConfigError( diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index 2552cc611..9a6fa3a1a 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -199,7 +199,6 @@ def analyze_data(dfs): ) output_df = smart_dataframe.chat("Set column b to column a + 1") - print("output", output_df) assert output_df["a"].tolist() == [1, 2, 3] assert output_df["b"].tolist() == [2, 3, 4] From 976bf4e4a9cacfb3c9282da5da2492c6593b564e Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 7 Nov 2023 13:10:56 +0500 Subject: [PATCH 11/13] fix(orders): check examples of different tables --- examples/sql_direct_config.py | 4 ++-- pandasai/prompts/direct_sql_prompt.py | 4 ++-- tests/prompts/test_sql_prompt.py | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/sql_direct_config.py b/examples/sql_direct_config.py index 3a198bb2e..b8c52dfb9 100644 --- a/examples/sql_direct_config.py +++ b/examples/sql_direct_config.py @@ -40,10 +40,10 @@ ) -llm = OpenAI("YOUR_API_KEY") +llm = OpenAI("OPEN_API_KEY") df = SmartDatalake( [order_details, payment_connector, products], config={"llm": llm, "direct_sql": True}, ) -response = df.chat("Return Orders with OrderDetails and counts of distinct Products") +response = df.chat("return orders with count of distinct products") print(response) diff --git a/pandasai/prompts/direct_sql_prompt.py b/pandasai/prompts/direct_sql_prompt.py index 5b27098ec..78fa4e406 100644 --- a/pandasai/prompts/direct_sql_prompt.py +++ 
b/pandasai/prompts/direct_sql_prompt.py
@@ -26,9 +26,9 @@ class DirectSQLPrompt(FileBasedPrompt):
     def _prepare_tables_data(self, tables):
         tables_join = []
         for table in tables:
-            table = f"<table name='{table.table_name}'>{table.head_csv}</table>"
+            table = f'<table name="{table.table_name}">\n{table.head_csv}\n</table>'
             tables_join.append(table)
-        return "\n".join(tables_join)
+        return "\n\n".join(tables_join)
 
     def setup(self, tables) -> None:
         self.set_var("tables", self._prepare_tables_data(tables))
diff --git a/tests/prompts/test_sql_prompt.py b/tests/prompts/test_sql_prompt.py
index 92df71481..c6bd65362 100644
--- a/tests/prompts/test_sql_prompt.py
+++ b/tests/prompts/test_sql_prompt.py
@@ -73,7 +73,9 @@ def test_direct_sql_prompt_with_params(
         == f'''You are provided with the following samples of sql tables data:
 
-
+
+ +
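
The direct-SQL pieces introduced in patches 08-11 fit together as follows: the connector screens the generated SQL and runs it, the code manager injects an `execute_sql_query` function into the environment of the generated `analyze_data`, and the prompt instructs the model to call that function instead of working on local dataframes. The sketch below reproduces that path in standalone form; the keyword blocklist and the error message mirror the patches above, while the SQLAlchemy engine URL and the `make_execute_sql_query` helper are illustrative stand-ins for connector internals the patches do not fully show.

```python
# Standalone sketch of the direct-SQL path from patches 08-11.
# Assumptions: a reachable PostgreSQL instance matching the example config,
# and a plain SQLAlchemy connection standing in for the connector.
import re

import pandas as pd
from sqlalchemy import create_engine

BLOCKED_KEYWORDS = [
    r"\bINSERT\b", r"\bUPDATE\b", r"\bDELETE\b",
    r"\bDROP\b", r"\bEXEC\b", r"\bALTER\b", r"\bCREATE\b",
]


def is_sql_query_safe(query: str) -> bool:
    # Read-only guard mirroring SQLConnector._is_sql_query_safe:
    # reject any write/DDL keyword, case-insensitively.
    return not any(
        re.search(keyword, query, re.IGNORECASE) for keyword in BLOCKED_KEYWORDS
    )


def make_execute_sql_query(connection):
    # The code manager injects a function like this into the environment
    # of the generated analyze_data() when direct_sql is enabled.
    def execute_sql_query(sql_query: str) -> pd.DataFrame:
        if not is_sql_query_safe(sql_query):
            raise ValueError("Malicious query is generated in code")
        return pd.read_sql(sql_query, connection)

    return execute_sql_query


# Usage sketch (credentials taken from the example in patch 09):
engine = create_engine("postgresql+psycopg2://postgres:123456@localhost:5432/testdb")
with engine.connect() as conn:
    execute_sql_query = make_execute_sql_query(conn)
    df = execute_sql_query("SELECT * FROM orders LIMIT 5")
```
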
From c62bcbbc70ac91ee965d602b462874a599359821 Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Tue, 7 Nov 2023 08:22:13 +0000 Subject: [PATCH 12/13] 'Refactored by Sourcery' --- pandasai/prompts/generate_python_code.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/prompts/generate_python_code.py index 9bc0813ad..a54d4a794 100644 --- a/pandasai/prompts/generate_python_code.py +++ b/pandasai/prompts/generate_python_code.py @@ -82,5 +82,4 @@ def on_prompt_generation(self) -> None: def _format_instructions(self, instructions: str): lines = instructions.split("\n") indented_lines = [f" {line}" for line in lines[1:]] - result = "\n".join([lines[0]] + indented_lines) - return result + return "\n".join([lines[0]] + indented_lines) From a02f332430ead36c759c374f13d3f40d91ebaaf4 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 7 Nov 2023 15:26:47 +0500 Subject: [PATCH 13/13] chore(sqlprompt): add description only when we have it --- examples/sql_direct_config.py | 14 ++++++++++++-- pandasai/prompts/direct_sql_prompt.py | 8 +++++++- tests/prompts/test_sql_prompt.py | 2 +- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/examples/sql_direct_config.py b/examples/sql_direct_config.py index b8c52dfb9..97b2abef8 100644 --- a/examples/sql_direct_config.py +++ b/examples/sql_direct_config.py @@ -3,10 +3,11 @@ from pandasai import SmartDatalake from pandasai.llm import OpenAI from pandasai.connectors import PostgreSQLConnector +from pandasai.smart_dataframe import SmartDataframe # With a PostgreSQL database -payment_connector = PostgreSQLConnector( +order = PostgreSQLConnector( config={ "host": "localhost", "port": 5432, @@ -41,8 +42,17 @@ llm = OpenAI("OPEN_API_KEY") + + +order_details_smart_df = SmartDataframe( + order_details, + config={"llm": llm, "direct_sql": True}, + description="Contain user order details", +) + + df = SmartDatalake( - [order_details, payment_connector, products], + [order_details_smart_df, order, products], config={"llm": llm, "direct_sql": True}, ) response = df.chat("return orders with count of distinct products") diff --git a/pandasai/prompts/direct_sql_prompt.py b/pandasai/prompts/direct_sql_prompt.py index 78fa4e406..f05e40bb0 100644 --- a/pandasai/prompts/direct_sql_prompt.py +++ b/pandasai/prompts/direct_sql_prompt.py @@ -26,7 +26,13 @@ class DirectSQLPrompt(FileBasedPrompt): def _prepare_tables_data(self, tables): tables_join = [] for table in tables: - table = f'\n{table.head_csv}\n
'
+            table_description_tag = (
+                f' description="{table.table_description}"'
+                if table.table_description is not None
+                else ""
+            )
+            table_head_tag = f'<table name="{table.table_name}"{table_description_tag}>'
+            table = f"{table_head_tag}\n{table.head_csv}\n</table>"
             tables_join.append(table)
         return "\n\n".join(tables_join)
 
diff --git a/tests/prompts/test_sql_prompt.py b/tests/prompts/test_sql_prompt.py
index c6bd65362..533a50371 100644
--- a/tests/prompts/test_sql_prompt.py
+++ b/tests/prompts/test_sql_prompt.py
@@ -73,7 +73,7 @@ def test_direct_sql_prompt_with_params(
         == f'''You are provided with the following samples of sql tables data:
 
-
+