forked from francescanaretto/LORE_sa
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move the utility methods of TabularDataset into the superclass.
Exploit `self.descriptor` to compute the feature access methods. Addresses issue #8
- Loading branch information
Showing
3 changed files
with
95 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,92 @@ | ||
__all__ = ["Dataset"] | ||
|
||
from abc import abstractmethod | ||
from lore_sa.logger import logger | ||
|
||
|
||
|
||
class Dataset():
    """
    Generic class to handle datasets.

    The dataset is described by ``self.descriptor``, a dictionary that maps a
    category name (this class reads ``'numeric'``, ``'categorical'`` and
    ``'target'``) to a dictionary of ``feature name -> feature info``.
    The accessors below read the ``'index'`` and ``'distinct_values'`` entries
    of the feature info; any other content is left untouched.
    """

    def __init__(self, descriptor: dict = None, class_name: str = None):
        """
        :param descriptor: the dataset descriptor dictionary (may be None and
            provided later with ``set_descriptor``)
        :param class_name: name of the target column (may be None and
            provided later with ``set_class_name``)
        """
        self.descriptor = descriptor
        self.class_name = class_name
        # Update the descriptor adding an additional key referring to the
        # target label
        self.descriptor = self.set_target_label(self.descriptor)

    # NOTE: this class does not use ABCMeta, so ``abstractmethod`` only marks
    # the method; it does not prevent instantiation of ``Dataset``.
    @abstractmethod
    def update_descriptor(self):
        """
        It creates the dataset descriptor dictionary. Meant to be provided by
        subclasses; this base implementation does nothing.
        """

    def set_target_label(self, descriptor):
        """
        Set the target column into the dataset descriptor.

        The entry named ``self.class_name`` is moved from the category where it
        is found to a new ``'target'`` category. The input dictionary is
        modified in place. The call is idempotent: a target already stored
        under ``'target'`` is left as it is.

        :param descriptor: the descriptor to update (may be None)
        :return: a modified version of the input descriptor with a new key
            'target'; the input is returned unchanged when no class name is
            defined, when the descriptor is None, or when the class name is
            not found among the features
        """
        if self.class_name is None:
            logger.warning("No target class is defined")
            return descriptor

        # Nothing to update until a descriptor is provided.
        if descriptor is None:
            return descriptor

        # Locate the category holding the target column. The 'target' category
        # itself is skipped: moving an entry from 'target' to 'target' would
        # overwrite it and then remove it, leaving the target empty.
        source = next(
            (
                features
                for category, features in descriptor.items()
                if category != 'target' and self.class_name in features
            ),
            None,
        )
        # The descriptor is modified only after the search is over, so it is
        # never changed while being iterated.
        if source is not None:
            descriptor['target'] = {self.class_name: source.pop(self.class_name)}

        return descriptor

    def set_descriptor(self, descriptor):
        """
        Replace the dataset descriptor and set the target column into it.

        :param descriptor: the new descriptor dictionary
        :return:
        """
        self.descriptor = descriptor
        self.descriptor = self.set_target_label(self.descriptor)

    def set_class_name(self, class_name: str):
        """
        Set the class name. Only the column name string

        :param [str] class_name:
        :return:
        """
        self.class_name = class_name
        self.descriptor = self.set_target_label(self.descriptor)

    def get_class_values(self):
        """
        Return the list of values of the target column.

        :return: the 'distinct_values' entry of the target column
        :raises Exception: if no class name has been set
        """
        if self.class_name is None:
            raise Exception("ERR: class_name is None. Set class_name with set_class_name('<column name>')")
        return self.descriptor['target'][self.class_name]['distinct_values']

    def get_numeric_columns(self):
        """
        :return: the list of names of the features under the 'numeric' key
        """
        return list(self.descriptor['numeric'])

    def get_categorical_columns(self):
        """
        :return: the list of names of the features under the 'categorical' key
        """
        return list(self.descriptor['categorical'])

    def get_feature_names(self):
        """
        :return: the names of the numeric features followed by the names of
            the categorical features
        """
        return self.get_numeric_columns() + self.get_categorical_columns()

    def get_feature_name(self, index):
        """
        Get the feature name by index.

        Every category of the descriptor is searched, 'target' included.

        :param index: the value to match against the 'index' entry of a feature
        :return: the name of the corresponding feature, or None when no
            feature has the given index
        """
        for features in self.descriptor.values():
            for name, info in features.items():
                if info['index'] == index:
                    return name
        return None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters