This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Gandalf + Refactoring (manujosephv#257)
* included entmax lib in req
moved entmax and sparsemax functions out of the code base

* added exp settings and benchmarks for exp

* working t-softmax

* added init and controllable sparsity

* added learnable sparsity config

* pre-commit changes

* enabling GATE with no trees

* removing an accidentally committed checkpoint

* req change for tabnet

* downgrading tabnet

* downgrading tabnet

* again

* reverting tabnet req

* bug fix for regression

* Add Layer to GATE for regression without trees

* bug fix in sequential

* refactored common layers

* refactored some more common components
added rich logger

* created new gandalf model

* refactored gate also working

* re-factored feature importance
fixed bugs

* pre-commit changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added try block for lightning_lite import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* documentation updates

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
manujosephv and pre-commit-ci[bot] authored Jul 19, 2023
1 parent 15d8d25 commit 20499a4
Showing 35 changed files with 1,868 additions and 1,041 deletions.
66 changes: 44 additions & 22 deletions docs/apidocs_common.md
@@ -1,71 +1,93 @@
## Layers
## Embeddings

::: pytorch_tabular.models.common.layers.AddNorm
options:
heading_level: 3
::: pytorch_tabular.models.common.layers.Embedding1dLayer
options:
heading_level: 3
::: pytorch_tabular.models.common.layers.Embedding2dLayer
options:
heading_level: 3
::: pytorch_tabular.models.common.layers.Lambda
::: pytorch_tabular.models.common.layers.PreEncoded1dLayer
options:
heading_level: 3
::: pytorch_tabular.models.common.layers.ModuleWithInit
::: pytorch_tabular.models.common.layers.SharedEmbeddings
options:
heading_level: 3
::: pytorch_tabular.models.common.layers.MultiHeadedAttention

## Gated Units
::: pytorch_tabular.models.common.layers.GatedFeatureLearningUnit
options:
heading_level: 3
::: pytorch_tabular.models.common.layers.PositionWiseFeedForward
::: pytorch_tabular.models.common.layers.GEGLU
options:
heading_level: 3
::: pytorch_tabular.models.common.layers.PreEncoded1dLayer
::: pytorch_tabular.models.common.layers.ReGLU
options:
heading_level: 3
::: pytorch_tabular.models.common.layers.Residual
::: pytorch_tabular.models.common.layers.SwiGLU
options:
heading_level: 3
::: pytorch_tabular.models.common.layers.SharedEmbeddings
::: pytorch_tabular.models.common.layers.PositionWiseFeedForward
options:
heading_level: 3
::: pytorch_tabular.models.common.layers.TransformerEncoderBlock

## Soft Trees
::: pytorch_tabular.models.common.layers.NeuralDecisionTree
options:
heading_level: 3
::: pytorch_tabular.models.common.layers.ODST
options:
heading_level: 3

## Activations

::: pytorch_tabular.models.common.activations.Entmax15Function
## Transformers
::: pytorch_tabular.models.common.layers.AddNorm
options:
heading_level: 3
::: pytorch_tabular.models.common.activations.Entmoid15

::: pytorch_tabular.models.common.layers.AppendCLSToken
options:
heading_level: 3
::: pytorch_tabular.models.common.activations.GEGLU
::: pytorch_tabular.models.common.layers.MultiHeadedAttention
options:
heading_level: 3
::: pytorch_tabular.models.common.activations.PositionWiseFeedForward
::: pytorch_tabular.models.common.layers.TransformerEncoderBlock
options:
heading_level: 3
::: pytorch_tabular.models.common.activations.ReGLU

## Miscellaneous
::: pytorch_tabular.models.common.layers.Lambda
options:
heading_level: 3
::: pytorch_tabular.models.common.activations.SparsemaxFunction
::: pytorch_tabular.models.common.layers.ModuleWithInit
options:
heading_level: 3
::: pytorch_tabular.models.common.activations.SwiGLU
::: pytorch_tabular.models.common.layers.Residual
options:
heading_level: 3
::: pytorch_tabular.models.common.activations.entmax15


## Activations
::: pytorch_tabular.models.common.activations.Entmoid15
options:
heading_level: 3
::: pytorch_tabular.models.common.activations.entmoid15
options:
heading_level: 3
::: pytorch_tabular.models.common.activations.entmax15
options:
heading_level: 3
::: pytorch_tabular.models.common.activations.sparsemax
options:
heading_level: 3
::: pytorch_tabular.models.common.activations.sparsemoid
options:
heading_level: 3
::: pytorch_tabular.models.common.activations.t_softmax
options:
heading_level: 3
::: pytorch_tabular.models.common.activations.TSoftmax
options:
heading_level: 3
::: pytorch_tabular.models.common.activations.RSoftmax
options:
heading_level: 3

6 changes: 6 additions & 0 deletions docs/apidocs_model.md
@@ -9,6 +9,9 @@
::: pytorch_tabular.models.FTTransformerConfig
options:
heading_level: 3
::: pytorch_tabular.models.GANDALFConfig
options:
heading_level: 3
::: pytorch_tabular.models.GatedAdditiveTreeEnsembleConfig
options:
heading_level: 3
@@ -39,6 +42,9 @@
::: pytorch_tabular.models.FTTransformerModel
options:
heading_level: 3
::: pytorch_tabular.models.GANDALFModel
options:
heading_level: 3
::: pytorch_tabular.models.GatedAdditiveTreeEnsembleModel
options:
heading_level: 3
34 changes: 28 additions & 6 deletions docs/apidocs_utils.md
@@ -8,28 +8,50 @@
options:
heading_level: 3

## Miscllaneous Utilities
## Data Utilities
::: pytorch_tabular.utils.get_balanced_sampler
options:
heading_level: 3
::: pytorch_tabular.utils.get_class_weighted_cross_entropy
options:
heading_level: 3
::: pytorch_tabular.utils.get_gaussian_centers
options:
heading_level: 3

## NN Utilities
::: pytorch_tabular.utils._initialize_layers
options:
heading_level: 3
::: pytorch_tabular.utils._linear_dropout_bn
::: pytorch_tabular.utils._initialize_kaiming
options:
heading_level: 3
::: pytorch_tabular.utils._make_smooth_weights_for_balanced_classes
::: pytorch_tabular.utils._linear_dropout_bn
options:
heading_level: 3
::: pytorch_tabular.utils.get_balanced_sampler
::: pytorch_tabular.utils._make_ix_like
options:
heading_level: 3
::: pytorch_tabular.utils.get_class_weighted_cross_entropy
::: pytorch_tabular.utils.reset_all_weights
options:
heading_level: 3
::: pytorch_tabular.utils.get_gaussian_centers
::: pytorch_tabular.utils.to_one_hot
options:
heading_level: 3

## Python Utilities
::: pytorch_tabular.utils.getattr_nested
options:
heading_level: 3
::: pytorch_tabular.utils.ifnone
options:
heading_level: 3
::: pytorch_tabular.utils.check_numpy
options:
heading_level: 3
::: pytorch_tabular.utils.pl_load
options:
heading_level: 3
::: pytorch_tabular.utils.generate_doc_dataclass
options:
heading_level: 3
57 changes: 49 additions & 8 deletions examples/__only_for_dev__/adhoc_scaffold.py
@@ -64,14 +64,21 @@ def print_metrics(y_true, y_pred, tag):
trainer_config = TrainerConfig(
auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate
batch_size=32,
max_epochs=10,
fast_dev_run=True,
max_epochs=5,
# fast_dev_run=True,
# profiler="simple",
early_stopping=None,
checkpoints=None,
trainer_kwargs={"limit_train_batches": 10},
)
optimizer_config = OptimizerConfig()
model_config = GatedAdditiveTreeEnsembleConfig(
task="classification",
gflu_stages=3,
num_trees=0,
tree_depth=2,
binning_activation="sigmoid",
feature_mask_function="t-softmax",
# layers="4096-4096-512", # Number of nodes in each layer
# activation="LeakyReLU", # Activation between each layers
learning_rate=1e-3,
@@ -86,13 +93,47 @@ def print_metrics(y_true, y_pred, tag):
)

tabular_model.fit(train=train, validation=val)
test.drop(columns=["target"], inplace=True)
pred_df = tabular_model.predict(test)
pred_df = tabular_model.predict(test, device="cpu")
pred_df = tabular_model.predict(test, device="cuda")
import torch # noqa: E402
# test.drop(columns=["target"], inplace=True)
# pred_df = tabular_model.predict(test)
# pred_df = tabular_model.predict(test, device="cpu")
# pred_df = tabular_model.predict(test, device="cuda")
# import torch

pred_df = tabular_model.predict(test, device=torch.device("cuda"))
# pred_df = tabular_model.predict(test, device=torch.device("cuda"))
# tabular_model.fit(train=train, validation=val)
# tabular_model.fit(train=train, validation=val, max_epochs=5)
# tabular_model.fit(train=train, validation=val, max_epochs=5, reset=True)


# t = torch.rand(128,200)
# a = t.numpy()

# start = time.time()
# t.median(dim=-1)
# end = time.time()
# print("torch median", end - start)

# start = time.time()
# t.quantile(torch.rand(128), dim=-1)
# end = time.time()
# print("torch quant ", end - start)

# start = time.time()
# np.median(t.numpy(), axis=-1)
# end = time.time()
# print("numpy median", end - start)

# start = time.time()
# np.quantile(t.numpy(), np.random.rand(128), axis=-1)
# end = time.time()
# print("numpy quant ", end - start)

# start = time.time()
# st = torch.sort(t, dim=-1)
# end = time.time()
# print("torch sort", end - start)

# start = time.time()
# st = np.sort(t.numpy(), axis=-1)
# end = time.time()
# print("numpy sort", end - start)
86 changes: 86 additions & 0 deletions examples/__only_for_dev__/runtime_benchmarks.txt
@@ -0,0 +1,86 @@
sigmoid - softmax

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Total | - | 12591 | 12.938 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

entmoid-entmax

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Total | - | 12591 | 25.504 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

sigmoid-entmax
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Total | - | 12591 | 23.011 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

entmoid - softmax
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Total | - | 12591 | 16.014 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

entmoid - sparsemax
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Total | - | 12591 | 24.583 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

sparsemoid - sparsemax
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Total | - | 12591 | 22.899 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

relu15-relu15
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Total | - | 12591 | 15.541 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

entmoid - t-softmax
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Total | - | 12591 | 21.79 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

sigmoid - t-softmax
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Total | - | 12591 | 19.542 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

sigmoid - t-softmax(modified)

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Total | - | 12591 | 19.415 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

entmoid - weighted softmax

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Total | - | 12591 | 18.445 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
| Total | - | 12591 | 16.943 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
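
The totals above are easier to compare when ranked against the fastest run. What follows is a minimal Python sketch and is not part of the commit. The numbers are copied from the "Total" rows above. Reading each label as "binning activation - feature mask function" is an assumption, based on the `binning_activation` and `feature_mask_function` arguments in adhoc_scaffold.py. `rank_by_total` is a hypothetical helper name. The final table has no label in the file, so it is left out.

```python
# Total time (s) per run, copied from the "Total" rows of the profiler output.
# ASSUMPTION: labels read as "binning activation - feature mask function".
totals = {
    "sigmoid - softmax": 12.938,
    "entmoid - entmax": 25.504,
    "sigmoid - entmax": 23.011,
    "entmoid - softmax": 16.014,
    "entmoid - sparsemax": 24.583,
    "sparsemoid - sparsemax": 22.899,
    "relu15 - relu15": 15.541,
    "entmoid - t-softmax": 21.79,
    "sigmoid - t-softmax": 19.542,
    "sigmoid - t-softmax(modified)": 19.415,
    "entmoid - weighted softmax": 18.445,
}


def rank_by_total(totals, baseline_key):
    """Return (label, seconds, ratio-to-baseline) tuples, fastest first.

    Hypothetical helper for illustration only.
    """
    baseline = totals[baseline_key]
    ranked = sorted(totals.items(), key=lambda item: item[1])
    return [(label, seconds, seconds / baseline) for label, seconds in ranked]


ranked = rank_by_total(totals, baseline_key="sigmoid - softmax")

for label, seconds, ratio in ranked:
    # Left-align the label, then print the total and its ratio to the baseline.
    print(f"{label:<32} {seconds:>7.3f} s  {ratio:>5.2f}x")
```

Each figure comes from a single profiled run, so small differences between configurations (for example the two sigmoid - t-softmax variants) may fall within run-to-run noise.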