Skip to content

Commit

Permalink
updated encoder tests to class
Browse files Browse the repository at this point in the history
  • Loading branch information
suzannejin committed Jan 16, 2025
1 parent c9fab83 commit 76301b7
Show file tree
Hide file tree
Showing 3 changed files with 390 additions and 511 deletions.
19 changes: 10 additions & 9 deletions src/stimulus/data/encoding/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def encode_all(self, data: List[str]) -> torch.tensor:

def decode(self, data: Any) -> Any:
"""Returns an error since decoding does not make sense without encoder information, which is not yet supported."""
raise NotImplementedError("Decoding is not yet supported for StrClassificationInt.")
raise NotImplementedError("Decoding is not yet supported.")

def _check_dtype(self, data: List[str]) -> None:
"""Check if the input data is string data.
Expand All @@ -454,21 +454,22 @@ def _check_dtype(self, data: List[str]) -> None:


class StrClassificationScaledEncoder(StrClassificationIntEncoder):
"""Considering a ensemble of strings, this encoder encodes them into floats from 0 to 1 (essentially scaling the integer encoding)."""

def encode_all(self, data: list) -> np.array:
"""Considering a ensemble of strings, this encoder encodes them into floats from 0 to 1 (essentially scaling the integer encoding).
"""
def encode_all(self, data: List[str]) -> torch.Tensor:
"""Encodes the data.
This method takes as input a list of data points, should be mappable to a single output, using LabelEncoder from scikit learn and returning a numpy array.
For more info visit : https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html
Args:
data (List[str]): a list of strings
Returns:
encoded_data (torch.Tensor): the encoded data
"""
encoded_data = super().encode_all(data)
return encoded_data / (len(np.unique(encoded_data)) - 1)

def decode(self, data: Any) -> Any:
"""Returns an error since decoding does not make sense without encoder information, which is not yet supported."""
raise NotImplementedError("Decoding is not yet supported for StrClassificationScaled.")


class FloatRankEncoder(AbstractEncoder):
"""Considering an ensemble of float values, this encoder encodes them into floats from 0 to 1, where 1 is the maximum value and 0 is the minimum value."""

Expand Down
Loading

0 comments on commit 76301b7

Please sign in to comment.