Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Config plot libraries #705

Merged
merged 4 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandasai/assets/prompt_templates/generate_python_code.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ You are provided with the following pandas DataFrames:
{conversation}
</conversation>

{viz_library_type}

This is the initial python function. Do not change the params. Given the context, use the right dataframes.
```python
{current_code}
Expand Down
63 changes: 63 additions & 0 deletions pandasai/helpers/viz_library_types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import logging
from typing import Union, Optional
from .base import VisualizationLibrary

from ._viz_library_types import (
MatplotlibVizLibraryType,
PlotlyVizLibraryType,
SeabornVizLibraryType,
)
from .. import Logger


# Registry consulted by `viz_lib_type_factory`: maps the string value of each
# supported `VisualizationLibrary` member to the helper class describing it.
viz_lib_map = {
    library.value: helper_cls
    for library, helper_cls in (
        (VisualizationLibrary.MATPLOTLIB, MatplotlibVizLibraryType),
        (VisualizationLibrary.PLOTLY, PlotlyVizLibraryType),
        (VisualizationLibrary.SEABORN, SeabornVizLibraryType),
    )
}


def viz_lib_type_factory(
    viz_lib_type: Optional[str] = None, logger: Optional[Logger] = None
) -> Union[MatplotlibVizLibraryType, PlotlyVizLibraryType, SeabornVizLibraryType]:
    """
    Factory function to get the appropriate helper for a viz library name.

    Uses `viz_lib_map` to determine the helper class. Any name that is not
    registered there (including None and `VisualizationLibrary.DEFAULT`)
    falls back to `MatplotlibVizLibraryType`.

    Args:
        viz_lib_type (Optional[str]): Name of the viz library, i.e. one of
            the keys of `viz_lib_map`. None or
            `VisualizationLibrary.DEFAULT.value` selects the default helper.
        logger (Optional[Logger]): If passed, receives a WARNING when
            `viz_lib_type` is not a recognised name, and a DEBUG message
            naming the helper class that was created.

    Returns:
        (Union[
            MatplotlibVizLibraryType,
            PlotlyVizLibraryType,
            SeabornVizLibraryType
        ]): A new instance of the selected helper class.
    """
    viz_lib_default = MatplotlibVizLibraryType

    # None and the DEFAULT sentinel are explicit requests for the default
    # helper, so they must not be reported as unknown values.
    wants_default = (
        viz_lib_type is None or viz_lib_type == VisualizationLibrary.DEFAULT.value
    )

    if logger and not wants_default and viz_lib_type not in viz_lib_map:
        possible_types_msg = ", ".join(f"'{type_}'" for type_ in viz_lib_map)
        logger.log(
            f"Unknown value for the parameter `viz_library_type`: '{viz_lib_type}'. "
            f"Possible values are: {possible_types_msg} and None for default "
            f"viz library type ({viz_lib_default.__name__}).",
            level=logging.WARNING,
        )

    # Unregistered names (and the default request) resolve to the default class.
    viz_lib_type_helper = viz_lib_map.get(viz_lib_type, viz_lib_default)()

    if logger:
        logger.log(
            f"{viz_lib_type_helper.__class__} is going to be used.", level=logging.DEBUG
        )

    return viz_lib_type_helper

88 changes: 88 additions & 0 deletions pandasai/helpers/viz_library_types/_viz_library_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from abc import abstractmethod, ABC
from typing import Any, Iterable


class BaseVizLibraryType(ABC):
    """
    Abstract description of a visualization library.

    Concrete subclasses provide the library's `name` and the `template_hint`
    text associated with it.
    """

    @property
    @abstractmethod
    def template_hint(self) -> str:
        """Instruction text telling which plotting library to use."""
        ...

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier of the visualization library (e.g. "matplotlib")."""
        ...

    def _validate_type(self, actual_type: str) -> bool:
        """Return True when `actual_type` equals this library's `name`."""
        return actual_type == self.name

    def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable[str]]:
        """
        Validate the 'type' entry of the result dict against this library.

        Args:
            result (dict[str, Any]): The result of code execution in
                dict representation. Only the "type" key is read:
                {
                    "type": <viz_library_name>
                }
                A missing key is treated as None and therefore fails.

        Returns:
            (tuple(bool, Iterable(str)):
                Boolean value whether the result's 'type' matches `name`,
                and a collection of log messages describing the mismatch
                (empty when the type matches).
        """
        validation_logs = []
        actual_type = result.get("type")

        type_ok = self._validate_type(actual_type)
        if not type_ok:
            validation_logs.append(
                f"The result dict contains inappropriate 'type'. "
                f"Expected '{self.name}', actual '{actual_type}'."
            )

        return type_ok, validation_logs


class MatplotlibVizLibraryType(BaseVizLibraryType):
    """Visualization library helper for matplotlib."""

    @property
    def template_hint(self) -> str:
        """Return the instruction text asking for charts to be made with matplotlib."""
        # The literal is emitted verbatim; its line breaks are part of the text.
        return """When a user requests to create a chart, utilize the Python matplotlib
library to generate high-quality graphics that will be saved
directly to a file.
If you import matplotlib use the 'agg' backend for rendering plots."""

    @property
    def name(self) -> str:
        """Return the library identifier, "matplotlib"."""
        return "matplotlib"


class PlotlyVizLibraryType(BaseVizLibraryType):
    """Visualization library helper for plotly."""

    @property
    def template_hint(self) -> str:
        """Return the instruction text asking for charts to be made with plotly."""
        # The literal is emitted verbatim; its line breaks are part of the text.
        # NOTE(review): the final matplotlib/'agg' sentence is identical to the
        # matplotlib hint — confirm it is intended for the plotly variant.
        return """When a user requests to create a chart, utilize the Python plotly
library to generate high-quality graphics that will be saved
directly to a file.
If you import matplotlib use the 'agg' backend for rendering plots."""

    @property
    def name(self) -> str:
        """Return the library identifier, "plotly"."""
        return "plotly"


class SeabornVizLibraryType(BaseVizLibraryType):
    """Visualization library helper for seaborn."""

    @property
    def template_hint(self) -> str:
        """Return the instruction text asking for charts to be made with Seaborn."""
        # The literal is emitted verbatim; its line breaks are part of the text.
        return """When a user requests to create a chart, utilize the Python Seaborn
library to generate high-quality graphics that will be saved
directly to a file.
If you import matplotlib use the 'agg' backend for rendering plots."""

    @property
    def name(self) -> str:
        """Return the library identifier, "seaborn"."""
        return "seaborn"

20 changes: 20 additions & 0 deletions pandasai/helpers/viz_library_types/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from enum import Enum


class VisualizationLibrary(str, Enum):
    """
    VisualizationLibrary is an enumeration that represents the available
    data visualization libraries.

    Members also subclass `str`, so each member compares equal to its
    string value.

    Attributes:
        MATPLOTLIB (str): Represents the Matplotlib library.
        SEABORN (str): Represents the Seaborn library.
        PLOTLY (str): Represents the Plotly library.
        DEFAULT (str): Placeholder meaning no specific library was selected;
            it does not name a concrete library.
    """

    MATPLOTLIB = "matplotlib"
    SEABORN = "seaborn"
    PLOTLY = "plotly"

    # Not a real library: sentinel value for "use whatever the default is".
    DEFAULT = "default"

2 changes: 2 additions & 0 deletions pandasai/schemas/df_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ..callbacks.base import BaseCallback
from ..llm import LLM, LangchainLLM
from ..exceptions import LLMNotFoundError
from ..helpers.viz_library_types.base import VisualizationLibrary


class LogServerConfig(TypedDict):
Expand All @@ -32,6 +33,7 @@ class Config(BaseModel):
lazy_load_connector: bool = True
response_parser: Type[ResponseParser] = None
llm: Any = None
data_viz_library: Optional[VisualizationLibrary] = None
log_server: LogServerConfig = None

class Config:
Expand Down
2 changes: 1 addition & 1 deletion pandasai/smart_dataframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,4 +723,4 @@ def __repr__(self):
return self.dataframe.__repr__()

def __len__(self):
return len(self.dataframe)
return len(self.dataframe)
29 changes: 28 additions & 1 deletion pandasai/smart_datalake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pandasai.helpers.query_exec_tracker import QueryExecTracker

from ..helpers.output_types import output_type_factory
from ..helpers.viz_library_types import viz_lib_type_factory
from pandasai.responses.context import Context
from pandasai.responses.response_parser import ResponseParser
from ..llm.base import LLM
Expand All @@ -45,6 +46,7 @@
from ..middlewares.base import Middleware
from ..helpers.df_info import DataFrameType
from ..helpers.path import find_project_root
from ..helpers.viz_library_types.base import VisualizationLibrary
from ..exceptions import AdvancedReasoningDisabledError


Expand All @@ -68,6 +70,8 @@ class SmartDatalake:
_last_result: str = None
_last_error: str = None

_viz_lib: str = None

def __init__(
self,
dfs: List[Union[DataFrameType, Any]],
Expand Down Expand Up @@ -117,6 +121,9 @@ def __init__(
else:
self._response_parser = ResponseParser(context)

if self._config.data_viz_library:
self._viz_lib = self._config.data_viz_library.value

self._conversation_id = uuid.uuid4()

self._instance = self.__class__.__name__
Expand Down Expand Up @@ -191,6 +198,10 @@ def _load_config(self, config: Union[Config, dict]):
self._load_llm(config["llm"])
config["llm"] = self._llm

if config.get("data_viz_library"):
self._load_data_viz_library(config["data_viz_library"])
config["data_viz_library"] = self._data_viz_library

self._config = Config(**config)

def _load_llm(self, llm: LLM):
Expand All @@ -211,6 +222,21 @@ def _load_llm(self, llm: LLM):

self._llm = llm

def _load_data_viz_library(self, data_viz_library: str):
    """
    Record which visualization library name this instance should use.

    The name is stored on `self._data_viz_library`. A name that does not
    match the value of any `VisualizationLibrary` member is replaced by
    `VisualizationLibrary.DEFAULT.value`; no exception is raised for it.

    Args:
        data_viz_library (str): Requested visualization library name.
    """
    supported_names = [member.value for member in VisualizationLibrary]

    if data_viz_library in supported_names:
        self._data_viz_library = data_viz_library
    else:
        # Unrecognised names silently fall back to the default sentinel.
        self._data_viz_library = VisualizationLibrary.DEFAULT.value

def add_middlewares(self, *middlewares: Optional[Middleware]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _load_data_viz_library method correctly sets the _data_viz_library attribute based on the data_viz_library argument. However, the method's docstring needs to be updated to reflect its functionality and the possible exceptions it might raise.

def _load_data_viz_library(self, data_viz_library: str):
    """
    Store the name of the visualization library to use.

    Args:
        data_viz_library (str): The name of the visualization library to use.
            A name that is not a `VisualizationLibrary` value is replaced by
            `VisualizationLibrary.DEFAULT.value`; no exception is raised.
    """
    self._data_viz_library = VisualizationLibrary.DEFAULT.value
    if data_viz_library in (item.value for item in VisualizationLibrary):
        self._data_viz_library = data_viz_library

"""
Add middlewares to PandasAI instance.
Expand Down Expand Up @@ -273,7 +299,6 @@ def _get_prompt(
prompt.set_var(key, value)

self.logger.log(f"Using prompt: {prompt}")

return prompt

def _get_cache_key(self) -> str:
Expand Down Expand Up @@ -333,6 +358,7 @@ def chat(self, query: str, output_type: Optional[str] = None):

try:
output_type_helper = output_type_factory(output_type, logger=self.logger)
viz_lib_helper = viz_lib_type_factory(self._viz_lib, logger=self.logger)

if (
self._config.enable_cache
Expand All @@ -349,6 +375,7 @@ def chat(self, query: str, output_type: Optional[str] = None):
# TODO: find a better way to determine the engine,
"engine": self._dfs[0].engine,
"output_type_hint": output_type_helper.template_hint,
"viz_library_type": viz_lib_helper.template_hint,
}

if (
Expand Down
Loading