Skip to content

Commit

Permalink
Merge pull request Giskard-AI#2039 from Giskard-AI/feature/gsk-3827-load-scan-test-suite-doesnt-work
Browse files Browse the repository at this point in the history

[GSK-3827] Fix load/save giskard Dataset
  • Loading branch information
henchaves authored Oct 11, 2024
2 parents dcf36fb + b5561dd commit 6cdde2f
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 13 deletions.
32 changes: 19 additions & 13 deletions giskard/datasets/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
if TYPE_CHECKING:
from mlflow import MlflowClient

SAMPLE_SIZE = 1000

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -526,10 +524,22 @@ def cast_column_to_dtypes(df, column_dtypes):

@classmethod
def load(cls, local_path: str):
    """Load a Dataset previously persisted with ``save``.

    Reads the ``giskard-dataset-meta.yaml`` metadata file and the
    zstd-compressed ``data.csv.zst`` data file from *local_path*, then
    rebuilds a Dataset with the saved column configuration.

    :param local_path: path to the directory containing the saved dataset files.
    :return: a new Dataset; the saved dataset's id is carried over as
        ``original_id`` (the new object gets a fresh id of its own).
    """
    # Load metadata first: it holds the name, target, column types and
    # category columns needed to reconstruct the Dataset correctly.
    with open(Path(local_path) / "giskard-dataset-meta.yaml", "r") as meta_f:
        meta = yaml.safe_load(meta_f)

    # Load the compressed CSV. "_GSK_NA_" is the sentinel written at save
    # time for missing values, so it is the only string treated as NA here
    # (keep_default_na=False prevents pandas from guessing others).
    with open(Path(local_path) / "data.csv.zst", "rb") as ds_stream:
        df = pd.read_csv(
            ZstdDecompressor().stream_reader(ds_stream),
            keep_default_na=False,
            na_values=["_GSK_NA_"],
        )

    return cls(
        df,
        name=meta.get("name"),
        target=meta.get("target"),
        cat_columns=list(meta["category_features"].keys()),
        column_types=meta.get("column_types"),
        original_id=meta.get("id"),
    )

@staticmethod
def _cat_columns(meta):
Expand All @@ -543,21 +553,17 @@ def _cat_columns(meta):
def cat_columns(self):
    # Names of the categorical columns, resolved from this dataset's
    # metadata via the `_cat_columns` helper defined just above.
    # NOTE(review): presumably exposed as a read-only `@property`; the
    # decorator is outside the visible diff — confirm in the full file.
    return self._cat_columns(self.meta)

def save(self, local_path: Path, dataset_id):
with open(local_path / "data.csv.zst", "wb") as f, open(local_path / "data.sample.csv.zst", "wb") as f_sample:
def save(self, local_path: str):
with (open(Path(local_path) / "data.csv.zst", "wb") as f,):
uncompressed_bytes = save_df(self.df)
compressed_bytes = compress(uncompressed_bytes)
f.write(compressed_bytes)
original_size_bytes, compressed_size_bytes = len(uncompressed_bytes), len(compressed_bytes)

uncompressed_bytes = save_df(self.df.sample(min(SAMPLE_SIZE, len(self.df.index))))
compressed_bytes = compress(uncompressed_bytes)
f_sample.write(compressed_bytes)

with open(Path(local_path) / "giskard-dataset-meta.yaml", "w") as meta_f:
yaml.dump(
{
"id": dataset_id,
"id": str(self.id),
"name": self.meta.name,
"target": self.meta.target,
"column_types": self.meta.column_types,
Expand Down
57 changes: 57 additions & 0 deletions tests/datasets/test_dataset_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import tempfile

import pandas as pd
import pytest

from giskard.datasets import Dataset

# Round-trip fixtures keyed by their pytest id: one dataset per column kind
# (text, category, numeric) so save/load is exercised for each dtype path.
_CASES = {
    "text": Dataset(
        pd.DataFrame(
            {
                "question": [
                    "What is the capital of France?",
                    "What is the capital of Germany?",
                ]
            }
        ),
        column_types={"question": "text"},
        target=None,
    ),
    "category": Dataset(
        pd.DataFrame(
            {
                "country": ["France", "Germany", "France", "Germany", "France"],
                "capital": ["Paris", "Berlin", "Paris", "Berlin", "Paris"],
            }
        ),
        column_types={"country": "category", "capital": "category"},
        cat_columns=["country", "capital"],
        target=None,
    ),
    "numeric": Dataset(
        pd.DataFrame(
            {
                "x": [1, 2, 3, 4, 5],
                "y": [2, 4, 6, 8, 10],
            }
        ),
        column_types={"x": "numeric", "y": "numeric"},
        target="y",
    ),
}


@pytest.mark.parametrize("dataset", list(_CASES.values()), ids=list(_CASES.keys()))
def test_save_and_load_dataset(dataset: Dataset):
    """Saving then loading a dataset must preserve data and metadata."""
    with tempfile.TemporaryDirectory() as tmp_test_folder:
        dataset.save(tmp_test_folder)
        restored = Dataset.load(tmp_test_folder)

        # The loaded object is a new dataset pointing back at the original.
        assert restored.id != dataset.id
        assert restored.original_id == dataset.id
        # Contents and column configuration survive the round trip intact.
        assert restored.df.equals(dataset.df)
        assert restored.meta == dataset.meta

0 comments on commit 6cdde2f

Please sign in to comment.