-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
141 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
""" | ||
Flat classifier approach, used for comparison purposes. | ||
Implementation by @lpfgarcia | ||
""" | ||
|
||
import numpy as np | ||
from sklearn.base import BaseEstimator | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.utils.validation import check_is_fitted | ||
|
||
|
||
class FlatClassifier(BaseEstimator): | ||
""" | ||
A flat classifier utility that accepts as input a hierarchy and flattens it internally. | ||
Examples | ||
-------- | ||
>>> from hiclass import FlatClassifier | ||
>>> y = [['1', '1.1'], ['2', '2.1']] | ||
>>> X = [[1, 2], [3, 4]] | ||
>>> flat = FlatClassifier() | ||
>>> flat.fit(X, y) | ||
>>> flat.predict(X) | ||
array([['1', '1.1'], | ||
['2', '2.1']]) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
local_classifier: BaseEstimator = LogisticRegression(), | ||
): | ||
""" | ||
Initialize a flat classifier. | ||
Parameters | ||
---------- | ||
local_classifier : BaseEstimator, default=LogisticRegression | ||
The scikit-learn model used for the flat classification. Needs to have fit, predict and clone methods. | ||
""" | ||
self.local_classifier = local_classifier | ||
|
||
def fit(self, X, y, sample_weight=None): | ||
""" | ||
Fit a flat classifier. | ||
Parameters | ||
---------- | ||
X : {array-like, sparse matrix} of shape (n_samples, n_features) | ||
The training input samples. Internally, its dtype will be converted | ||
to ``dtype=np.float32``. If a sparse matrix is provided, it will be | ||
converted into a sparse ``csc_matrix``. | ||
y : array-like of shape (n_samples, n_levels) | ||
The target values, i.e., hierarchical class labels for classification. | ||
sample_weight : array-like of shape (n_samples,), default=None | ||
Array of weights that are assigned to individual samples. | ||
If not provided, then each sample is given unit weight. | ||
Returns | ||
------- | ||
self : object | ||
Fitted estimator. | ||
""" | ||
# Convert from hierarchical labels to flat labels | ||
self.separator_ = "::HiClass::Separator::" | ||
y = [self.separator_.join(i) for i in y] | ||
|
||
# Fit flat classifier | ||
self.local_classifier.fit(X, y, sample_weight=sample_weight) | ||
|
||
# Return the classifier | ||
return self | ||
|
||
def predict(self, X): | ||
""" | ||
Predict classes for the given data. | ||
Hierarchical labels are returned. | ||
Parameters | ||
---------- | ||
X : {array-like, sparse matrix} of shape (n_samples, n_features) | ||
The input samples. Internally, its dtype will be converted | ||
to ``dtype=np.float32``. If a sparse matrix is provided, it will be | ||
converted into a sparse ``csr_matrix``. | ||
Returns | ||
------- | ||
y : ndarray of shape (n_samples,) or (n_samples, n_outputs) | ||
The predicted classes. | ||
""" | ||
# Check if fit has been called | ||
check_is_fitted(self) | ||
|
||
# Predict and remove separator | ||
predictions = [ | ||
i.split(self.separator_) for i in self.local_classifier.predict(X) | ||
] | ||
|
||
return np.array(predictions) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import numpy as np | ||
from numpy.testing import assert_array_equal | ||
|
||
from hiclass import FlatClassifier | ||
|
||
|
||
def test_fit_predict(): | ||
flat = FlatClassifier() | ||
x = np.array([[1, 2], [3, 4]]) | ||
y = np.array([["a", "b"], ["b", "c"]]) | ||
flat.fit(x, y) | ||
predictions = flat.predict(x) | ||
assert_array_equal(y, predictions) |