Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: config plot libraries #706

Merged
merged 5 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandasai/assets/prompt_templates/generate_python_code.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ You are provided with the following pandas DataFrames:
{conversation}
</conversation>

{viz_library_type}

This is the initial python function. Do not change the params. Given the context, use the right dataframes.
```python
{current_code}
Expand Down
63 changes: 63 additions & 0 deletions pandasai/helpers/viz_library_types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import logging
from typing import Union, Optional
from .base import VisualizationLibrary

from ._viz_library_types import (
MatplotlibVizLibraryType,
PlotlyVizLibraryType,
SeabornVizLibraryType,
)
from .. import Logger


viz_lib_map = {
VisualizationLibrary.MATPLOTLIB.value: MatplotlibVizLibraryType,
VisualizationLibrary.PLOTLY.value: PlotlyVizLibraryType,
VisualizationLibrary.SEABORN.value: SeabornVizLibraryType,
}


def viz_lib_type_factory(
viz_lib_type: str = None, logger: Optional[Logger] = None
) -> Union[MatplotlibVizLibraryType, PlotlyVizLibraryType, SeabornVizLibraryType,]:
"""
Factory function to get appropriate instance for viz library type.

Uses `viz_library_types_map` to determine the viz library type class.

Args:
viz_lib_type (Optional[str]): A name of the viz library type.
Defaults to None, an instance of `DefaultVizLibraryType` will be
returned.
logger (Optional[str]): If passed, collects logs about correctness
of the `viz_library_type` argument and what kind of VizLibraryType
is created.

Returns:
(Union[
MatplotlibVizLibraryType,
PlotlyVizLibraryType,
SeabornVizLibraryType,
DefaultVizLibraryType
]): An instance of the output type.
"""

if viz_lib_type is not None and viz_lib_type not in viz_lib_map and logger:
possible_types_msg = ", ".join(f"'{type_}'" for type_ in viz_lib_map)
logger.log(
f"Unknown value for the parameter `viz_library_type`: '{viz_lib_type}'."
f"Possible values are: {possible_types_msg} and None for default "
f"viz library type (miscellaneous).",
level=logging.WARNING,
)

viz_lib_default = MatplotlibVizLibraryType
viz_lib_type_helper = viz_lib_map.get(viz_lib_type, viz_lib_default)()

if logger:
logger.log(
f"{viz_lib_type_helper.__class__} is going to be used.", level=logging.DEBUG
)

return viz_lib_type_helper

65 changes: 65 additions & 0 deletions pandasai/helpers/viz_library_types/_viz_library_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from abc import abstractmethod, ABC
from typing import Any, Iterable


class BaseVizLibraryType(ABC):
@property
def template_hint(self) -> str:
return f"""When a user requests to create a chart, utilize the Python
{self.name} library to generate high-quality graphics that will be saved
directly to a file."""

@property
@abstractmethod
def name(self) -> str:
...

def _validate_type(self, actual_type: str) -> bool:
return actual_type == self.name

def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable[str]]:
"""
Validate 'type' and 'constraint' from the result dict.

Args:
result (dict[str, Any]): The result of code execution in
dict representation. Should have the following schema:
{
"viz_library_type": <viz_library_name>
}

Returns:
(tuple(bool, Iterable(str)):
Boolean value whether the result matches output type
and collection of logs containing messages about
'type' mismatches.
"""
validation_logs = []
actual_type = result.get("type")

type_ok = self._validate_type(actual_type)
if not type_ok:
validation_logs.append(
f"The result dict contains inappropriate 'type'. "
f"Expected '{self.name}', actual '{actual_type}'."
)

return type_ok, validation_logs


class MatplotlibVizLibraryType(BaseVizLibraryType):
@property
def name(self):
return "matplotlib"


class PlotlyVizLibraryType(BaseVizLibraryType):
@property
def name(self):
return "plotly"


class SeabornVizLibraryType(BaseVizLibraryType):
@property
def name(self):
return "seaborn"
20 changes: 20 additions & 0 deletions pandasai/helpers/viz_library_types/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from enum import Enum


class VisualizationLibrary(str, Enum):
"""
VisualizationLibrary is an enumeration that represents the available
data visualization libraries.

Attributes:
MATPLOTLIB (str): Represents the Matplotlib library.
SEABORN (str): Represents the Seaborn library.
PLOTLY (str): Represents the Plotly library.
"""

MATPLOTLIB = "matplotlib"
SEABORN = "seaborn"
PLOTLY = "plotly"

DEFAULT = "default"

2 changes: 2 additions & 0 deletions pandasai/schemas/df_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ..callbacks.base import BaseCallback
from ..llm import LLM, LangchainLLM
from ..exceptions import LLMNotFoundError
from ..helpers.viz_library_types.base import VisualizationLibrary


class LogServerConfig(TypedDict):
Expand All @@ -32,6 +33,7 @@ class Config(BaseModel):
lazy_load_connector: bool = True
response_parser: Type[ResponseParser] = None
llm: Any = None
data_viz_library: Optional[VisualizationLibrary] = None
log_server: LogServerConfig = None

class Config:
Expand Down
2 changes: 1 addition & 1 deletion pandasai/smart_dataframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,4 +723,4 @@ def __repr__(self):
return self.dataframe.__repr__()

def __len__(self):
return len(self.dataframe)
return len(self.dataframe)
29 changes: 28 additions & 1 deletion pandasai/smart_datalake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pandasai.helpers.query_exec_tracker import QueryExecTracker

from ..helpers.output_types import output_type_factory
from ..helpers.viz_library_types import viz_lib_type_factory
from pandasai.responses.context import Context
from pandasai.responses.response_parser import ResponseParser
from ..llm.base import LLM
Expand All @@ -45,6 +46,7 @@
from ..middlewares.base import Middleware
from ..helpers.df_info import DataFrameType
from ..helpers.path import find_project_root
from ..helpers.viz_library_types.base import VisualizationLibrary
from ..exceptions import AdvancedReasoningDisabledError


Expand All @@ -68,6 +70,8 @@ class SmartDatalake:
_last_result: str = None
_last_error: str = None

_viz_lib: str = None

def __init__(
self,
dfs: List[Union[DataFrameType, Any]],
Expand Down Expand Up @@ -117,6 +121,9 @@ def __init__(
else:
self._response_parser = ResponseParser(context)

if self._config.data_viz_library:
self._viz_lib = self._config.data_viz_library.value

self._conversation_id = uuid.uuid4()

self._instance = self.__class__.__name__
Expand Down Expand Up @@ -191,6 +198,10 @@ def _load_config(self, config: Union[Config, dict]):
self._load_llm(config["llm"])
config["llm"] = self._llm

if config.get("data_viz_library"):
self._load_data_viz_library(config["data_viz_library"])
config["data_viz_library"] = self._data_viz_library

self._config = Config(**config)

def _load_llm(self, llm: LLM):
Expand All @@ -211,6 +222,21 @@ def _load_llm(self, llm: LLM):

self._llm = llm

def _load_data_viz_library(self, data_viz_library: str):
"""
Load the appropriate instance for viz library type to use.

Args:
data_viz_library (enum): TODO

Raises:
TODO
"""

self._data_viz_library = VisualizationLibrary.DEFAULT.value
if data_viz_library in (item.value for item in VisualizationLibrary):
self._data_viz_library = data_viz_library

def add_middlewares(self, *middlewares: Optional[Middleware]):
"""
Add middlewares to PandasAI instance.
Expand Down Expand Up @@ -273,7 +299,6 @@ def _get_prompt(
prompt.set_var(key, value)

self.logger.log(f"Using prompt: {prompt}")

return prompt

def _get_cache_key(self) -> str:
Expand Down Expand Up @@ -333,6 +358,7 @@ def chat(self, query: str, output_type: Optional[str] = None):

try:
output_type_helper = output_type_factory(output_type, logger=self.logger)
viz_lib_helper = viz_lib_type_factory(self._viz_lib, logger=self.logger)

if (
self._config.enable_cache
Expand All @@ -349,6 +375,7 @@ def chat(self, query: str, output_type: Optional[str] = None):
# TODO: find a better way to determine the engine,
"engine": self._dfs[0].engine,
"output_type_hint": output_type_helper.template_hint,
"viz_library_type": viz_lib_helper.template_hint,
}

if (
Expand Down
46 changes: 38 additions & 8 deletions tests/prompts/test_generate_python_code_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,54 @@
)
from pandasai.prompts import GeneratePythonCodePrompt
from pandasai.llm.fake import FakeLLM
from pandasai.helpers.viz_library_types import (
MatplotlibVizLibraryType,
viz_lib_map,
viz_lib_type_factory,
)


class TestGeneratePythonCodePrompt:
"""Unit tests for the generate python code prompt class"""

@pytest.mark.parametrize(
"save_charts_path,output_type_hint",
"save_charts_path,output_type_hint,viz_library_type_hint",
[
("exports/charts", DefaultOutputType().template_hint),
("custom/dir/for/charts", DefaultOutputType().template_hint),
(
"exports/charts",
DefaultOutputType().template_hint,
MatplotlibVizLibraryType().template_hint,
),
(
"custom/dir/for/charts",
DefaultOutputType().template_hint,
MatplotlibVizLibraryType().template_hint,
),
*[
("exports/charts", output_type_factory(type_).template_hint)
(
"exports/charts",
output_type_factory(type_).template_hint,
viz_lib_type_factory(viz_type_).template_hint,
)
for type_ in output_types_map
for viz_type_ in viz_lib_map
],
],
)
def test_str_with_args(self, save_charts_path, output_type_hint):
def test_str_with_args(
self, save_charts_path, output_type_hint, viz_library_type_hint
):
"""Test casting of prompt to string and interpolation of context.

Parameterized for the following cases:
* `save_charts_path` is "exports/charts", `output_type_hint` is default
* `save_charts_path` is "exports/charts", `output_type_hint` is default,
`viz_library_type_hint` is default
* `save_charts_path` is "custom/dir/for/charts", `output_type_hint`
is default
is default, `viz_library_type_hint` is default
* `save_charts_path` is "exports/charts", `output_type_hint` any of
possible types in `pandasai.helpers.output_types.output_types_map`
possible types in `pandasai.helpers.output_types.output_types_map`,
`viz_library_type_hint` any of
possible types in `pandasai.helpers.viz_library_types.viz_library_types_map`
"""

llm = FakeLLM("plt.show()")
Expand All @@ -51,6 +74,7 @@ def test_str_with_args(self, save_charts_path, output_type_hint):
prompt.set_var("conversation", "Question")
prompt.set_var("save_charts_path", save_charts_path)
prompt.set_var("output_type_hint", output_type_hint)
prompt.set_var("viz_library_type", viz_library_type_hint)
prompt.set_var("skills", "")

expected_prompt_content = f'''You are provided with the following pandas DataFrames:
Expand All @@ -66,6 +90,8 @@ def test_str_with_args(self, save_charts_path, output_type_hint):
Question
</conversation>

{viz_library_type_hint}

This is the initial python function. Do not change the params. Given the context, use the right dataframes.
```python
# TODO import all the dependencies required
Expand Down Expand Up @@ -97,6 +123,7 @@ def test_advanced_reasoning_prompt(self):
"""

llm = FakeLLM("plt.show()")
viz_library_type_hint = ""
dfs = [
SmartDataframe(
pd.DataFrame({"a": [1], "b": [4]}),
Expand All @@ -110,6 +137,7 @@ def test_advanced_reasoning_prompt(self):
prompt.set_var("save_charts_path", "")
prompt.set_var("output_type_hint", "")
prompt.set_var("skills", "")
prompt.set_var("viz_library_type", "")

expected_prompt_content = f'''You are provided with the following pandas DataFrames:

Expand All @@ -124,6 +152,8 @@ def test_advanced_reasoning_prompt(self):
Question
</conversation>

{viz_library_type_hint}

This is the initial python function. Do not change the params. Given the context, use the right dataframes.
```python
# TODO import all the dependencies required
Expand Down
Loading