Skip to content

Commit

Permalink
Moving utility methods of TabularDataset into the superclass.
Browse files Browse the repository at this point in the history
Uses `self.descriptor` to implement the feature-access methods. Addresses issue #8
  • Loading branch information
rinziv committed Nov 30, 2023
1 parent b24451e commit 2fb4501
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 69 deletions.
74 changes: 73 additions & 1 deletion lore_sa/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,92 @@
__all__ = ["Dataset"]

from abc import abstractmethod
from lore_sa.logger import logger



class Dataset():
    """
    Generic class to handle datasets.

    A dataset is described by ``self.descriptor``, a dictionary keyed by
    feature category (``'numeric'``, ``'categorical'`` and, once a class name
    is set, ``'target'``). Each category maps feature names to a dictionary of
    per-feature information; the accessors below read its ``'index'`` and
    ``'distinct_values'`` entries.
    """

    def __init__(self, descriptor: dict = None, class_name: str = None):
        """
        :param descriptor: the dataset descriptor dictionary, or None
        :param class_name: the name of the target column, or None
        """
        self.descriptor = descriptor
        self.class_name = class_name
        # Update the descriptor adding an additional key referring to the
        # target label
        self.descriptor = self.set_target_label(self.descriptor)

    @abstractmethod
    def update_descriptor(self):
        """
        it creates the dataset descriptor dictionary
        """

    def set_target_label(self, descriptor):
        """
        Set the target column into the dataset descriptor.

        The feature named ``self.class_name`` is moved out of its category
        and stored under the key 'target'. The input dictionary is modified
        in place. A target that was set previously is replaced by the new
        one; it is not moved back into its original category.

        :param descriptor: the descriptor dictionary to update (may be None)
        :return: a modified version of the input descriptor with a new key 'target'
        """
        if self.class_name is None:
            logger.warning("No target class is defined")
            return descriptor

        # Nothing to search when no descriptor has been provided yet
        # (e.g. a subclass that builds the descriptor after super().__init__()).
        if descriptor is None:
            return descriptor

        for category, features in descriptor.items():
            # Never search inside 'target' itself: moving an entry from
            # 'target' to 'target' would end up popping it from the freshly
            # assigned dictionary, leaving the target empty.
            if category == 'target':
                continue
            if self.class_name in features:
                descriptor['target'] = {
                    self.class_name: features.pop(self.class_name)
                }
                return descriptor

        return descriptor

    def set_descriptor(self, descriptor):
        """
        Replace the descriptor and set its target label.

        :param descriptor: the new descriptor dictionary
        :return:
        """
        self.descriptor = self.set_target_label(descriptor)

    def set_class_name(self, class_name: str):
        """
        Set the class name. Only the column name string
        :param [str] class_name:
        :return:
        """
        self.class_name = class_name
        self.descriptor = self.set_target_label(self.descriptor)

    def get_class_values(self):
        """
        return the list of values of the target column
        :return: the 'distinct_values' entry of the target feature
        :raises Exception: if no class name has been set
        """
        if self.class_name is None:
            raise Exception("ERR: class_name is None. Set class_name with set_class_name('<column name>')")
        return self.descriptor['target'][self.class_name]['distinct_values']

    def get_numeric_columns(self):
        """
        :return: the list of names of the numeric features
        """
        return list(self.descriptor['numeric'])

    def get_categorical_columns(self):
        """
        :return: the list of names of the categorical features
        """
        return list(self.descriptor['categorical'])

    def get_feature_names(self):
        """
        :return: the names of the numeric features followed by the names of
            the categorical ones
        """
        return self.get_numeric_columns() + self.get_categorical_columns()

    def get_feature_name(self, index):
        """
        Get the feature name by index
        :param index:
        :return: the name of the corresponding feature, or None when no
            feature in any category has that index
        """
        for features in self.descriptor.values():
            for name, info in features.items():
                if info['index'] == index:
                    return name
        return None
88 changes: 21 additions & 67 deletions lore_sa/dataset/tabular_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class TabularDataset(Dataset):
}
"""
def __init__(self,data: DataFrame, class_name:str = None):

super().__init__()
self.class_name = class_name
self.df = data

Expand All @@ -56,55 +56,39 @@ def __init__(self,data: DataFrame, class_name:str = None):
self.descriptor = {'numeric':{}, 'categorical':{}}

#creation of a default version of descriptor
self.update_descriptor()
self.descriptor = self.update_descriptor(self.df)
print(self.descriptor)

def update_descriptor(self):
def update_descriptor(self, df: DataFrame):
"""
it creates the dataset descriptor dictionary
"""
self.descriptor = {'numeric':{}, 'categorical':{}}
for feature in self.df.columns:
index = self.df.columns.get_loc(feature)
if feature in self.df.select_dtypes(include=np.number).columns.tolist():
descriptor = {'numeric':{}, 'categorical':{}}
for feature in df.columns:
index = df.columns.get_loc(feature)
if feature in df.select_dtypes(include=np.number).columns.tolist():
#numerical
desc = {'index': index,
'min' : self.df[feature].min(),
'max' : self.df[feature].max(),
'mean':self.df[feature].mean(),
'std':self.df[feature].std(),
'median':self.df[feature].median(),
'q1':self.df[feature].quantile(0.25),
'q3':self.df[feature].quantile(0.75),
'min' : df[feature].min(),
'max' : df[feature].max(),
'mean':df[feature].mean(),
'std':df[feature].std(),
'median':df[feature].median(),
'q1':df[feature].quantile(0.25),
'q3':df[feature].quantile(0.75),
}
self.descriptor['numeric'][feature] = desc
descriptor['numeric'][feature] = desc
else:
#categorical feature
desc = {'index': index,
'distinct_values' : list(self.df[feature].unique()),
'count' : {x : len(self.df[self.df[feature] == x]) for x in list(self.df[feature].unique())}}
self.descriptor['categorical'][feature] = desc

self.descriptor = self.set_target_label(self.descriptor)
'distinct_values' : list(df[feature].unique()),
'count' : {x : len(df[df[feature] == x]) for x in list(df[feature].unique())}}
descriptor['categorical'][feature] = desc

def set_target_label(self, descriptor):
"""
Set the target column into the dataset descriptor
descriptor = self.set_target_label(descriptor)
return descriptor

:param descriptor:
:return:
"""
if self.class_name is None:
logger.warning("No target class is defined")
return descriptor

for type in descriptor:
for k in descriptor[type]:
if k == self.class_name:
descriptor['target'] = {k:descriptor[type][k]}
descriptor[type].pop(k)
return descriptor

return descriptor


@classmethod
Expand All @@ -131,33 +115,3 @@ def from_dict(cls, data: dict, class_name: str=None):
"""
return cls(pd.DataFrame(data), class_name=class_name)

def set_class_name(self,class_name: str):
"""
Set the class name. Only the column name string
:param [str] class_name:
:return:
"""
self.class_name = class_name

def get_class_values(self):
"""
Provides the class_name
:return:
"""
if self.class_name is None:
raise Exception("ERR: class_name is None. Set class_name with set_class_name('<column name>')")
return self.df[self.class_name].values


def get_numeric_columns(self):
numeric_columns = list(self.df._get_numeric_data().columns)
return numeric_columns

def get_features_names(self):
return list(self.df.columns)

def get_feature_name(self, index):
for category in self.descriptor.keys():
for name in self.descriptor[category].keys():
if self.descriptor[category][name]['index'] == index:
return name
2 changes: 1 addition & 1 deletion test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_get_numeric_columns(self):

def test_get_class_value_as_exptected(self):
dataset = TabularDataset.from_dict({'col1': [1, 2], 'col2': [3, 4], 'col3': ['America', 'Europe']},class_name='col3')
self.assertEqual(dataset.get_class_values().tolist(),['America', 'Europe'])
self.assertEqual(dataset.get_class_values(),['America', 'Europe'])

def test_get_class_value_raise_error(self):
dataset = TabularDataset.from_dict({'col1': [1, 2], 'col2': [3, 4], 'col3': ['America', 'Europe']})
Expand Down

0 comments on commit 2fb4501

Please sign in to comment.