Merge pull request #40 from mathysgrapotte/yaml-refactor-auto-class-build

Yaml refactor auto class build
mathysgrapotte authored Jan 17, 2025
2 parents d01b012 + 8cfab78 commit bdcb655
Showing 16 changed files with 572 additions and 342 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -41,6 +41,7 @@ dependencies = [
"safetensors>=0.4.5",
"scikit-learn>=1.5.0",
"scipy==1.14.1",
"syrupy>=4.8.0",
"torch>=2.2.2",
"torch==2.2.2; sys_platform == 'darwin' and platform_machine == 'x86_64'"
]
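The new syrupy dependency is pytest's snapshot-testing plugin, presumably supporting the tests touched in this PR. A minimal sketch of its canonical usage (test name and data invented for illustration):

    # test_example.py -- run under pytest; syrupy injects the `snapshot` fixture.
    def test_transform_logic_snapshot(snapshot):
        result = {"transformation_name": "noise", "transformations": []}
        # The first run records a snapshot file; later runs compare against it.
        assert result == snapshot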
123 changes: 73 additions & 50 deletions src/stimulus/data/csv.py
@@ -12,18 +12,19 @@

from functools import partial
from typing import Any, Tuple, Union
from abc import ABC

import numpy as np
import polars as pl
import yaml
import stimulus.data.experiments as experiments
import stimulus.utils.yaml_data as yaml_data
import torch
import yaml

from stimulus.data import experiments
from stimulus.utils import yaml_data


class DatasetManager:
"""Class for managing the dataset.
This class handles loading and organizing dataset configuration from YAML files.
It manages column categorization into input, label and meta types based on the config.
@@ -36,9 +37,10 @@ class DatasetManager:
categorize_columns_by_type() -> dict: Organizes the columns into input, label, meta based on the config.
"""

def __init__(self,
config_path: str,
) -> None:
def __init__(
self,
config_path: str,
) -> None:
self.config = self._load_config(config_path)
self.column_categories = self.categorize_columns_by_type()

@@ -52,7 +54,7 @@ def categorize_columns_by_type(self) -> dict:
dict: Dictionary containing lists of column names for each category:
{
"input": ["col1", "col2"], # Input columns
"label": ["target"], # Label/output columns
"label": ["target"], # Label/output columns
"meta": ["id"] # Metadata columns
}
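For context, a self-contained sketch of the categorization described here, assuming each column entry in the YAML carries a column_type field (the field name is an assumption; the real schema is defined in stimulus.utils.yaml_data):

    # Mirrors the parsed `columns` section of a config file (hypothetical values).
    columns = [
        {"column_name": "col1", "column_type": "input"},
        {"column_name": "col2", "column_type": "input"},
        {"column_name": "target", "column_type": "label"},
        {"column_name": "id", "column_type": "meta"},
    ]
    categories = {"input": [], "label": [], "meta": []}
    for col in columns:
        categories[col["column_type"]].append(col["column_name"])
    # categories == {"input": ["col1", "col2"], "label": ["target"], "meta": ["id"]}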
@@ -94,13 +96,33 @@ def _load_config(self, config_path: str) -> dict:
>>> print(config["columns"][0]["column_name"])
'hello'
"""
with open(config_path, "r") as file:
return yaml_data.YamlConfigDict(**yaml.safe_load(file))
with open(config_path) as file:
return yaml_data.YamlSubConfigDict(**yaml.safe_load(file))

def get_split_columns(self) -> list:
"""Get the columns that are used for splitting."""
return self.config.split.split_input_columns


def get_transform_logic(self) -> dict:
"""Get the transformation logic.
Returns a dictionary in the following structure:
{
"transformation_name": str,
"transformations": list[Tuple[str, str, dict]]
}
"""
transformation_logic = {
"transformation_name": self.config.transforms.transformation_name,
"transformations": [],
}
for column in self.config.transforms.columns:
for transformation in column.transformations:
transformation_logic["transformations"].append(
(column.column_name, transformation.name, transformation.params)
)
return transformation_logic
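A hypothetical example of the flattened structure this method produces (column names, transformation names, and params invented):

    transform_logic = {
        "transformation_name": "gaussian_noise_pipeline",
        "transformations": [
            # (column_name, transformation.name, transformation.params)
            ("age", "GaussianNoise", {"std": 0.1}),
            ("fare", "GaussianNoise", {"std": 0.5}),
        ],
    }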


class EncodeManager:
"""Manages the encoding of data columns using configured encoders.
@@ -115,15 +137,16 @@ class EncodeManager:
Example:
>>> encoder_loader = EncoderLoader(config)
>>> encode_manager = EncodeManager(encoder_loader)
>>> data = ["ACGT", "TGCA", "GCTA"]
>>> data = ["ACGT", "TGCA", "GCTA"]
>>> encoded = encode_manager.encode_column("dna_seq", data)
>>> print(encoded.shape)
torch.Size([3, 4, 4]) # 3 sequences, length 4, one-hot encoded
"""

def __init__(self,
encoder_loader: experiments.EncoderLoader,
) -> None:
def __init__(
self,
encoder_loader: experiments.EncoderLoader,
) -> None:
"""Initializes the EncodeManager.
Args:
@@ -167,39 +190,41 @@ def encode_columns(self, column_data: dict) -> dict:
tensor depends on the encoder used for that column.
Example:
>>> data = {
... "dna_seq": ["ACGT", "TGCA"],
... "labels": ["1", "2"]
... }
>>> data = {"dna_seq": ["ACGT", "TGCA"], "labels": ["1", "2"]}
>>> encoded = encode_manager.encode_columns(data)
>>> print(encoded["dna_seq"].shape)
torch.Size([2, 4, 4]) # 2 sequences, length 4, one-hot encoded
"""
return {col: self.encode_column(col, values) for col, values in column_data.items()}


class TransformManager:
"""Class for managing the transformations."""

def __init__(self,
transform_loader: experiments.TransformLoader,
) -> None:
def __init__(
self,
transform_loader: experiments.TransformLoader,
) -> None:
self.transform_loader = transform_loader


class SplitManager:
"""Class for managing the splitting."""

def __init__(self,
split_loader: experiments.SplitLoader,
) -> None:
def __init__(
self,
split_loader: experiments.SplitLoader,
) -> None:
self.split_loader = split_loader

def get_split_indices(self, data: dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Get the indices for train, validation, and test splits."""
return self.split_loader.get_function_split()(data)
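The split function obtained from the loader is expected to return three NumPy index arrays. A self-contained sketch of one such function, a random 70/15/15 split invented purely for illustration:

    import numpy as np

    def random_split(data: dict, seed: int = 42) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return train/validation/test row indices for the columns in `data`."""
        n_rows = len(next(iter(data.values())))
        shuffled = np.random.default_rng(seed).permutation(n_rows)
        n_train, n_val = int(0.7 * n_rows), int(0.15 * n_rows)
        return shuffled[:n_train], shuffled[n_train:n_train + n_val], shuffled[n_train + n_val:]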


class DatasetHandler:
"""Main class for handling dataset loading, encoding, transformation and splitting.
This class coordinates the interaction between different managers to process
CSV datasets according to the provided configuration.
@@ -210,13 +235,14 @@ class DatasetHandler:
dataset_manager (DatasetManager): Manager for organizing dataset columns and config.
"""

def __init__(self,
encoder_loader: experiments.EncoderLoader,
transform_loader: experiments.TransformLoader,
split_loader: experiments.SplitLoader,
config_path: str,
csv_path: str,
) -> None:
def __init__(
self,
encoder_loader: experiments.EncoderLoader,
transform_loader: experiments.TransformLoader,
split_loader: experiments.SplitLoader,
config_path: str,
csv_path: str,
) -> None:
"""Initialize the DatasetHandler with required loaders and config.
Args:
@@ -235,7 +261,7 @@ def __init__(self,

def read_csv_header(self, csv_path: str) -> list:
"""Get the column names from the header of the CSV file.
Args:
csv_path (str): Path to the CSV file to read headers from.
@@ -245,35 +271,35 @@ def read_csv_header(self, csv_path: str) -> list:
with open(csv_path) as f:
header = f.readline().strip().split(",")
return header
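A quick, self-contained illustration of the header parsing above (file contents invented). Note that the plain comma split assumes header names contain no quoted commas:

    from io import StringIO

    fake_csv = StringIO("age,fare,survived\n22,7.25,0\n")
    header = fake_csv.readline().strip().split(",")
    assert header == ["age", "fare", "survived"]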

def load_csv(self, csv_path: str) -> pl.DataFrame:
"""Load the CSV file into a polars DataFrame.
Args:
csv_path (str): Path to the CSV file to load.
Returns:
pl.DataFrame: Polars DataFrame containing the loaded CSV data.
"""
return pl.read_csv(csv_path)

def select_columns(self, columns: list) -> dict:
"""Select specific columns from the DataFrame and return as a dictionary.
Args:
columns (list): List of column names to select.
Returns:
dict: A dictionary where keys are column names and values are lists containing the column data.
Example:
>>> handler = DatasetHandler(...)
>>> data_dict = handler.select_columns(["col1", "col2"])
>>> # Returns {'col1': [1, 2, 3], 'col2': [4, 5, 6]}
"""
df = self.data.select(columns)
return {col: df[col].to_list() for col in columns}

def add_split(self, force=False) -> None:
"""Add a column specifying the train, validation, test splits of the data.
An exception is raised if the split column is already present in the CSV file. This behaviour can be overridden by setting force=True.
@@ -290,12 +316,7 @@ def add_split(self, force=False) -> None:
)
# get relevant split columns from the dataset_manager
split_columns = self.dataset_manager.get_split_columns()

# if split_columns is none, build an empty dictionary
if split_columns is None:
split_input_data = {}
else:
split_input_data = self.select_columns(split_columns)
split_input_data = self.select_columns(split_columns)

# get the split indices
train, validation, test = self.split_manager.get_split_indices(split_input_data)
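The remainder of add_split falls outside this hunk. Purely as an illustration of the step that typically follows, and not this project's actual code, index arrays can be materialized as a categorical split column in polars:

    import numpy as np
    import polars as pl

    df = pl.DataFrame({"age": [22, 38, 26, 35, 28, 2]})
    train, validation, test = np.array([0, 1, 2]), np.array([3, 4]), np.array([5])

    split = np.zeros(len(df), dtype=np.int64)  # 0 = train
    split[validation] = 1                      # 1 = validation
    split[test] = 2                            # 2 = test
    df = df.with_columns(pl.Series("split", split))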
@@ -324,14 +345,14 @@ def get_all_items(self) -> tuple[dict, dict, dict]:
>>> input_dict, label_dict, meta_dict = handler.get_dataset()
>>> print(input_dict.keys())
dict_keys(['age', 'fare'])
>>> print(label_dict.keys())
>>> print(label_dict.keys())
dict_keys(['survived'])
>>> print(meta_dict.keys())
dict_keys(['passenger_id'])
"""
# Get columns for each category from dataset manager
input_cols = self.dataset_manager.column_categories["input"]
label_cols = self.dataset_manager.column_categories["label"]
label_cols = self.dataset_manager.column_categories["label"]
meta_cols = self.dataset_manager.column_categories["meta"]

# Select and organize data by category
@@ -345,13 +366,15 @@ def get_all_items(self) -> tuple[dict, dict, dict]:

return encoded_input, encoded_label, meta_data
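Taken together, a hedged end-to-end sketch of driving the handler. The loader construction follows the EncoderLoader(config) pattern from the EncodeManager docstring above, but the real signatures live in stimulus.data.experiments and are assumptions here, as are the file paths:

    from stimulus.data import csv, experiments

    config = {...}  # a loaded experiment config (shape assumed)
    encoder_loader = experiments.EncoderLoader(config)      # assumed signature
    transform_loader = experiments.TransformLoader(config)  # assumed signature
    split_loader = experiments.SplitLoader(config)          # assumed signature

    handler = csv.DatasetHandler(
        encoder_loader=encoder_loader,
        transform_loader=transform_loader,
        split_loader=split_loader,
        config_path="experiment_config.yaml",  # hypothetical paths
        csv_path="titanic.csv",
    )
    input_dict, label_dict, meta_dict = handler.get_all_items()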


class CsvHandler:
"""Meta class for handling CSV files."""

def __init__(self, experiment: Any, csv_path: str) -> None:
self.experiment = experiment
self.csv_path = csv_path


class CsvProcessing(CsvHandler):
"""Class to load the input csv data and add noise accordingly."""
