diff --git a/.gitignore b/.gitignore
index e933d05..1d7e3ca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -115,7 +115,7 @@ dmypy.json
.pyre/
# pycharm
-*/.DS_Store
+*.DS_Store
**/__pycache__/
.idea/
FETCH_HEAD
@@ -123,3 +123,4 @@ FETCH_HEAD
# vscode
.vscode
*.DS_Store
+PaddleFSL/raw_data/
\ No newline at end of file
diff --git a/PaddleFSL/examples/optim/README.md b/PaddleFSL/examples/optim/README.md
new file mode 100644
index 0000000..0d643d7
--- /dev/null
+++ b/PaddleFSL/examples/optim/README.md
@@ -0,0 +1,48 @@
+# Image Classification Tasks
+
+Here, we provide examples of applying PaddleFSL to few-shot image classification tasks, analogous to the examples in [model_zoo](../image_classification/README.md).
+
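+The snippet below is a minimal sketch of how these examples wire a learner, optimizer and trainer together. It mirrors `maml_example.py` in this directory; the dataset choice and hyper-parameters are illustrative, and it assumes the script is launched from the repository root so that `examples.optim.meta_trainer` is importable. See the example scripts for the per-dataset settings used to produce the tables below.
+
+```python
+# Minimal sketch mirroring maml_example.py; hyper-parameters are illustrative only.
+import paddle
+from paddle import nn
+from paddle.optimizer import Adam
+
+import paddlefsl
+from paddlefsl.metaopt.maml import MAMLLearner
+from examples.optim.meta_trainer import Config, Trainer, load_datasets
+
+config = Config().parse_args(known_only=True)
+config.dataset = 'omniglot'  # any name accepted by load_datasets()
+train_ds, valid_ds, test_ds = load_datasets(config.dataset)
+
+# ConvNet backbone; MAML adapts the whole model in the inner loop.
+model = paddlefsl.backbones.Conv(input_size=(1, 28, 28), output_size=config.n_way, pooling=False)
+learner = MAMLLearner(module=model, learning_rate=config.inner_lr)
+
+scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=config.meta_lr, T_max=config.epochs)
+optimizer = Adam(parameters=learner.parameters(), learning_rate=scheduler)
+
+trainer = Trainer(config=config, train_dataset=train_ds, dev_dataset=valid_ds,
+                  test_dataset=test_ds, learner=learner, optimizer=optimizer,
+                  scheduler=scheduler, criterion=nn.CrossEntropyLoss())
+trainer.train()
+```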
+
+## Datasets
+
+We evaluate the performance on several benchmark datasets, including Omniglot, *mini*ImageNet, CIFAR-FS, FC100, CUB and Tiered-ImageNet, which can be obtained as described in [raw_data/README.md](../../raw_data/README.md).
+
+
+## Results
+
+We provide results for MAML [1] and ANIL [2] below. The exact model configurations and pretrained models can be downloaded from [here](https://drive.google.com/file/d/1pmCI-8cwLsadG6JOcubufrQ2d4zpK9B-/view?usp=sharing) and can be used to reproduce these results.
+
+### [MAML](http://proceedings.mlr.press/v70/finn17a/finn17a.pdf?source=post_page---------------------------)
+
+
+| Dataset | Backbone | Way | Shot | Original paper | Other reports | model zoo (first order) | Optim (first order) |
+| :-------------: | :------: | :--: | :--: | :------------: | :-----------: | :---------------------: | :-----------------: |
+| Omniglot | MLP | 5 | 1 | 89.7 ± 1.1 | 88.9 ([learn2learn](http://learn2learn.net/)) | 88.88 ± 2.99 | -- |
+| Omniglot | MLP | 5 | 5 | 97.5 ± 0.6 | -- | 97.50 ± 0.47 | -- |
+| Omniglot | CNN | 5 | 1 | 98.7 ± 0.4 | 99.1 ([learn2learn](http://learn2learn.net/)) | 97.13 ± 1.25 | 92.7 |
+| Omniglot | CNN | 5 | 5 | 99.9 ± 0.1 | 99.9 ± 0.1 ([R2D2](https://arxiv.org/pdf/1805.08136.pdf)) | 99.23 ± 0.40 | ***93.1*** |
+| *mini*ImageNet | CNN | 5 | 1 | 48.70 ± 1.84 | 48.3 ([learn2learn](http://learn2learn.net/)) | 49.81 ± 1.78 | |
+| *mini*ImageNet | CNN | 5 | 5 | 63.11 ± 0.92 | 65.4 ([learn2learn](http://learn2learn.net/)) | 64.21 ± 1.33 | -- |
+| CIFAR-FS | CNN | 5 | 1 | -- | 58.9 ± 1.9 ([R2D2](https://arxiv.org/pdf/1805.08136.pdf)) | 57.06 ± 3.83 | 49.1 |
+| CIFAR-FS | CNN | 5 | 5 | -- | 76.6 ([learn2learn](http://learn2learn.net/)) | 72.24 ± 1.71 | -- |
+| FC100 | CNN | 5 | 1 | -- | -- | 37.63 ± 2.23 | 30.2 |
+| FC100 | CNN | 5 | 5 | -- | 49.0 ([learn2learn](http://learn2learn.net/)) | 49.14 ± 1.58 | -- |
+| CUB | CNN | 5 | 1 | -- | 54.73 ± 0.97 ([CloseLookFS](https://arxiv.org/pdf/1904.04232.pdf)) | 53.31 ± 1.77 | 20.7 |
+| CUB | CNN | 5 | 5 | -- | 75.75 ± 0.76 ([CloseLookFS](https://arxiv.org/pdf/1904.04232.pdf)) | 69.88 ± 1.47 | -- |
+| Tiered-ImageNet | CNN | 5 | 5 | -- | -- | 67.56 ± 1.80 | -- |
+
+### [ANIL](https://openreview.net/pdf?id=rkgMkCEtPB)
+
+| Dataset | Backbone | Way | Shot | Author Report | Other Report | model zoo (first order) | Optimizer (first order) |
+| :------------: | :------: | :--: | :--: | :-----------: | :----------: | :---------------------: | :---------------------: |
+| Omniglot | CNN | 5 | 1 | -- | -- | 96.06 ± 1.00 | 96.34 ± 1.98 |
+| Omniglot | CNN | 5 | 5 | -- | -- | 98.74 ± 0.48 | |
+| *mini*ImageNet | CNN | 5 | 1 | 46.7 ± 0.4 | -- | 48.31 ± 2.83 | 45.31 ± 1.43 |
+| *mini*ImageNet | CNN | 5 | 5 | 61.5 ± 0.5 | -- | 62.38 ± 1.96 | 61.81 ± 1.2 |
+| CIFAR-FS | CNN | 5 | 1 | -- | -- | 56.19 ± 3.39 | ***30.8 ± 2.5*** |
+| CIFAR-FS | CNN | 5 | 5 | -- | 68.3 ([learn2learn](http://learn2learn.net/)) | 68.60 ± 1.25 | 48.6 |
+| FC100 | CNN | 5 | 1 | -- | -- | 40.69 ± 3.32 | 38.4 ± 1.3 |
+| FC100 | CNN | 5 | 5 | -- | 47.6 ([learn2learn](http://learn2learn.net/)) | 48.01 ± 1.22 | 35.0 |
+| CUB | CNN | 5 | 1 | -- | -- | 53.25 ± 2.18 | -- |
+| CUB | CNN | 5 | 5 | -- | -- | 69.09 ± 1.12 | -- |
+
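+ANIL differs from MAML in that only the classification head is adapted in the inner loop, while the shared feature extractor is updated only by the outer-loop optimizer (see `paddlefsl/metaopt/anil.py`). Below is a minimal sketch of the learner wiring, mirroring `anil_example.py`; the backbone shape, head size and learning rate are illustrative assumptions.
+
+```python
+# Sketch of the ANIL learner wiring from anil_example.py; sizes are illustrative only.
+import paddle
+import paddlefsl
+from paddlefsl.metaopt.anil import ANILLearner
+
+n_way = 5
+# Shared feature extractor: replace its output layer with Flatten so it
+# returns a feature vector instead of class logits.
+feature_model = paddlefsl.backbones.Conv(input_size=(1, 28, 28), output_size=n_way, pooling=False)
+feature_model.output = paddle.nn.Flatten()
+# Task-specific head: the only part updated by learner.adapt() in the inner loop.
+head_layer = paddle.nn.Linear(in_features=feature_model.feature_size, out_features=n_way)
+
+learner = ANILLearner(feature_model=feature_model, head_layer=head_layer, learning_rate=0.5)
+```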
diff --git a/PaddleFSL/examples/optim/anil_example.py b/PaddleFSL/examples/optim/anil_example.py
new file mode 100644
index 0000000..08aa8cc
--- /dev/null
+++ b/PaddleFSL/examples/optim/anil_example.py
@@ -0,0 +1,143 @@
+"""MAML example for optimization"""
+from __future__ import annotations
+import os
+import paddle
+from paddle import nn
+from paddle.optimizer import Adam
+import paddlefsl
+from paddlefsl.metaopt.anil import ANILLearner
+from examples.optim.meta_trainer import Config, Trainer, load_datasets
+
+
+def init_models(config: Config):
+ """Initialize models."""
+ if config.dataset == 'cub':
+ config.meta_lr = 0.002
+ config.inner_lr = 0.01
+ config.test_epoch = 10
+ config.meta_batch_size = 32
+ config.train_inner_adapt_steps = 5
+ config.test_inner_adapt_steps = 10
+ config.epochs = 10000
+
+ if config.k_shot == 5:
+ config.meta_lr = 0.003
+ config.inner_lr = 0.05
+ config.epochs = 10000
+
+ feature_model = paddlefsl.backbones.Conv(input_size=(3, 84, 84), output_size=config.n_way, conv_channels=[32, 32, 32, 32])
+ feature_model.output = paddle.nn.Flatten()
+ head_layer = paddle.nn.Linear(in_features=feature_model.feature_size, out_features=config.n_way,
+ weight_attr=feature_model.init_weight_attr, bias_attr=feature_model.init_bias_attr)
+
+ if config.dataset == 'cifarfs':
+ config.meta_lr = 0.001
+ config.inner_lr = 0.02
+ config.test_epoch = 10
+ config.meta_batch_size = 32
+ config.train_inner_adapt_steps = 5
+ config.test_inner_adapt_steps = 10
+ config.epochs = 20000
+ if config.k_shot == 5:
+ config.meta_lr = 0.001
+ config.inner_lr = 0.08
+
+ feature_model = paddlefsl.backbones.Conv(input_size=(3, 32, 32), output_size=config.n_way, conv_channels=[32, 32, 32, 32])
+ feature_model.output = paddle.nn.Flatten()
+ head_layer = paddle.nn.Linear(in_features=32, out_features=config.n_way,
+ weight_attr=feature_model.init_weight_attr, bias_attr=feature_model.init_bias_attr)
+
+ if config.dataset == 'miniimagenet':
+
+ config.meta_lr = 0.002
+ config.inner_lr = 0.05
+ config.test_epoch = 10
+ config.meta_batch_size = 32
+ config.train_inner_adapt_steps = 5
+ config.test_inner_adapt_steps = 10
+ config.epochs = 30000
+
+ feature_model = paddlefsl.backbones.Conv(input_size=(3, 84, 84), output_size=config.n_way, conv_channels=[32, 32, 32, 32])
+ feature_model.output = paddle.nn.Flatten()
+ head_layer = paddle.nn.Linear(in_features=feature_model.feature_size, out_features=config.n_way,
+ weight_attr=feature_model.init_weight_attr, bias_attr=feature_model.init_bias_attr)
+
+ if config.dataset == 'omniglot':
+ config.meta_lr = 0.005
+ config.inner_lr = 0.5
+
+ if config.k_shot == 5:
+ config.meta_lr = 0.06
+ config.inner_lr = 0.12
+ config.train_inner_adapt_steps = 3
+ config.test_inner_adapt_steps = 5
+
+ config.test_epoch = 10
+ config.meta_batch_size = 32
+ config.train_inner_adapt_steps = 1
+ config.test_inner_adapt_steps = 3
+ config.epochs = 30000
+
+ feature_model = paddlefsl.backbones.Conv(input_size=(1, 28, 28), output_size=config.n_way, pooling=False)
+ feature_model.output = paddle.nn.Flatten()
+ head_layer = paddle.nn.Linear(in_features=feature_model.feature_size, out_features=config.n_way,
+ weight_attr=feature_model.init_weight_attr, bias_attr=feature_model.init_bias_attr)
+
+ if config.dataset == 'fc100':
+ config.meta_lr = 0.005
+ config.inner_lr = 0.1
+ config.test_epoch = 10
+ config.meta_batch_size = 32
+ config.train_inner_adapt_steps = 5
+ config.test_inner_adapt_steps = 10
+ config.epochs = 5000
+ if config.k_shot == 5:
+ config.meta_lr = 0.002
+ config.epochs = 2000
+
+ feature_model = paddlefsl.backbones.Conv(input_size=(3, 32, 32), output_size=config.n_way)
+ feature_model.output = paddle.nn.Flatten()
+ head_layer = paddle.nn.Linear(in_features=feature_model.feature_size, out_features=config.n_way,
+ weight_attr=feature_model.init_weight_attr, bias_attr=feature_model.init_bias_attr)
+
+ return feature_model, head_layer
+
+
+if __name__ == '__main__':
+
+ config = Config().parse_args(known_only=True)
+ config.device = 'gpu'
+ config.k_shot = 1
+
+ # config.dataset = 'omniglot'
+ config.dataset = 'miniimagenet'
+ # config.dataset = 'cifarfs'
+ # config.dataset = 'fc100'
+ # config.dataset = 'cub'
+
+ config.tracking_uri = os.environ.get('TRACKING_URI', None)
+ config.experiment_id = os.environ.get('EXPERIMENT_ID', None)
+
+    # Build the dataset splits and the feature/head models for the selected dataset
+ train_dataset, valid_dataset, test_dataset = load_datasets(config.dataset)
+ feature_model, head_layer = init_models(config)
+
+ criterion = nn.CrossEntropyLoss()
+ learner = ANILLearner(
+ feature_model=feature_model,
+ head_layer=head_layer,
+ learning_rate=config.inner_lr,
+ )
+ scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=config.meta_lr, T_max=config.epochs)
+ optimizer = Adam(parameters=learner.parameters(), learning_rate=scheduler)
+ trainer = Trainer(
+ config=config,
+ train_dataset=train_dataset,
+ dev_dataset=valid_dataset,
+ test_dataset=test_dataset,
+ learner=learner,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ criterion=criterion
+ )
+ trainer.train()
diff --git a/PaddleFSL/examples/optim/anil_text_classification.py b/PaddleFSL/examples/optim/anil_text_classification.py
new file mode 100644
index 0000000..6889032
--- /dev/null
+++ b/PaddleFSL/examples/optim/anil_text_classification.py
@@ -0,0 +1,62 @@
+"""ANIL example for optimization"""
+from __future__ import annotations
+import os
+import paddle
+from paddle import nn
+from paddle.optimizer import Adam
+import paddlefsl
+from paddlefsl.metaopt.anil import ANILLearner
+from paddlenlp.transformers.ernie.modeling import ErnieModel
+from paddlenlp.transformers.ernie.tokenizer import ErnieTokenizer
+
+from examples.optim.meta_trainer import Config, Trainer, load_datasets
+
+class SequenceClassifier(nn.Layer):
+ """Sequence Classifier"""
+ def __init__(self, hidden_size: int, output_size: int, dropout: float = 0.1):
+ super().__init__()
+ self.dropout = nn.Dropout(dropout)
+ self.classifier = nn.Linear(hidden_size, output_size)
+
+ def forward(self, embedding):
+ """handle the main logic"""
+ embedding = self.dropout(embedding)
+ logits = self.classifier(embedding)
+ return logits
+
+
+if __name__ == '__main__':
+
+ config = Config().parse_args(known_only=True)
+ config.device = 'gpu'
+
+ train_dataset = paddlefsl.datasets.few_rel.FewRel('train')
+ valid_dataset = paddlefsl.datasets.few_rel.FewRel('valid')
+    test_dataset = paddlefsl.datasets.few_rel.FewRel('valid')  # the validation split is reused for testing
+
+ config.tracking_uri = os.environ.get('TRACKING_URI', None)
+ config.experiment_id = os.environ.get('EXPERIMENT_ID', None)
+
+    tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
+ feature_model, head_layer = ErnieModel.from_pretrained('ernie-1.0'), SequenceClassifier(hidden_size=768, output_size=config.n_way)
+
+ criterion = nn.CrossEntropyLoss()
+ learner = ANILLearner(
+ feature_model=feature_model,
+ head_layer=head_layer,
+ learning_rate=config.inner_lr,
+ )
+ scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=config.meta_lr, T_max=config.epochs)
+ optimizer = Adam(parameters=learner.parameters(), learning_rate=scheduler)
+ trainer = Trainer(
+ config=config,
+ train_dataset=train_dataset,
+ dev_dataset=valid_dataset,
+ test_dataset=test_dataset,
+ learner=learner,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ criterion=criterion,
+        tokenizer=tokenizer
+ )
+ trainer.train()
diff --git a/PaddleFSL/examples/optim/data_utils.py b/PaddleFSL/examples/optim/data_utils.py
new file mode 100644
index 0000000..a4a6352
--- /dev/null
+++ b/PaddleFSL/examples/optim/data_utils.py
@@ -0,0 +1,48 @@
+"""Data Utils for Meta Optimzations Algorithms"""
+from __future__ import annotations
+from typing import Tuple, Dict
+import paddlefsl
+from paddlefsl.datasets.cv_dataset import CVDataset
+
+
+def load_datasets(name: str) -> Tuple[CVDataset, CVDataset, CVDataset]:
+ """load CV Dataset by name, which can be omniglot, miniimagenet, or cifar10
+
+ Args:
+ name (str): the name of datasets
+
+ Returns:
+ Tuple[CVDataset, CVDataset, CVDataset]: train, dev, test dataset
+ """
+ datasets_map: Dict[str, CVDataset] = {
+ "omniglot": (
+ paddlefsl.datasets.Omniglot(mode='train', image_size=(28, 28)),
+ paddlefsl.datasets.Omniglot(mode='valid', image_size=(28, 28)),
+ paddlefsl.datasets.Omniglot(mode='test', image_size=(28, 28))
+ ),
+ # "miniimagenet": (
+ # paddlefsl.datasets.MiniImageNet(mode='train'),
+ # paddlefsl.datasets.MiniImageNet(mode='valid'),
+ # paddlefsl.datasets.MiniImageNet(mode='test')
+ # ),
+ # "cifarfs": (
+ # paddlefsl.datasets.CifarFS(mode='train', image_size=(28, 28)),
+ # paddlefsl.datasets.CifarFS(mode='valid', image_size=(28, 28)),
+ # paddlefsl.datasets.CifarFS(mode='test', image_size=(28, 28))
+ # ),
+ # "fc100": (
+ # paddlefsl.datasets.FC100(mode='train'),
+ # paddlefsl.datasets.FC100(mode='valid'),
+ # paddlefsl.datasets.FC100(mode='test')
+ # ),
+ # "cub": (
+ # paddlefsl.datasets.CubFS(mode='train'),
+ # paddlefsl.datasets.CubFS(mode='valid'),
+ # paddlefsl.datasets.CubFS(mode='test')
+ # )
+ }
+ if name not in datasets_map:
+ names = ",".join(list(datasets_map.keys()))
+ raise ValueError(f"{name} is not a valid dataset name, which should be in {names}")
+
+ return datasets_map[name]
diff --git a/PaddleFSL/examples/optim/maml_example.py b/PaddleFSL/examples/optim/maml_example.py
new file mode 100644
index 0000000..147e6e0
--- /dev/null
+++ b/PaddleFSL/examples/optim/maml_example.py
@@ -0,0 +1,113 @@
+"""MAML example for optimization"""
+import os
+
+import paddle
+from paddle import nn
+from paddle.optimizer import Adam
+
+import paddlefsl
+from paddlefsl.metaopt.maml import MAMLLearner
+from examples.optim.meta_trainer import Config, Trainer, load_datasets
+
+
+
+def init_models(config: Config):
+ """Initialize models."""
+ if config.dataset == 'cub':
+ config.meta_lr = 0.002
+ config.inner_lr = 0.03
+ config.test_epoch = 10
+ config.meta_batch_size = 32
+ config.train_inner_adapt_steps = 5
+ config.test_inner_adapt_steps = 10
+ config.epochs = 10000
+
+ model = paddlefsl.backbones.Conv(input_size=(3, 84, 84), output_size=config.n_way, conv_channels=[32, 32, 32, 32])
+
+ if config.dataset == 'cifarfs':
+ config.meta_lr = 0.001
+ config.inner_lr = 0.03
+ config.test_epoch = 10
+ config.meta_batch_size = 32
+ config.train_inner_adapt_steps = 5
+ config.test_inner_adapt_steps = 10
+ config.epochs = 30000
+
+ model = paddlefsl.backbones.Conv(input_size=(3, 32, 32), output_size=config.n_way, conv_channels=[32, 32, 32, 32])
+ model.output = nn.Sequential(
+ nn.Flatten(),
+ nn.Linear(in_features=32, out_features=config.n_way,
+ weight_attr=model.init_weight_attr, bias_attr=model.init_bias_attr)
+ )
+
+ if config.dataset == 'miniimagenet':
+ config.meta_lr = 0.002
+ config.inner_lr = 0.03
+ config.test_epoch = 10
+ config.meta_batch_size = 32
+ config.train_inner_adapt_steps = 5
+ config.test_inner_adapt_steps = 10
+ config.epochs = 60000
+
+ model = paddlefsl.backbones.Conv(input_size=(3, 84, 84), output_size=config.n_way, conv_channels=[32, 32, 32, 32])
+
+ if config.dataset == 'omniglot':
+ config.meta_lr = 0.005
+ config.inner_lr = 0.5
+ config.test_epoch = 10
+ config.meta_batch_size = 32
+ config.train_inner_adapt_steps = 1
+ config.test_inner_adapt_steps = 3
+ config.epochs = 30000
+
+ model = paddlefsl.backbones.Conv(input_size=(1, 28, 28), output_size=config.n_way, pooling=False)
+
+ if config.dataset == 'fc100':
+ config.meta_lr = 0.002
+ config.inner_lr = 0.05
+ config.test_epoch = 10
+ config.meta_batch_size = 32
+ config.train_inner_adapt_steps = 5
+ config.test_inner_adapt_steps = 10
+ config.epochs = 5000
+
+ model = paddlefsl.backbones.Conv(input_size=(3, 32, 32), output_size=config.n_way)
+
+ return model
+
+
+if __name__ == '__main__':
+
+ config = Config().parse_args(known_only=True)
+ config.device = 'gpu'
+ if not config.dataset:
+ # config.dataset = 'omniglot'
+ # config.dataset = 'miniimagenet'
+ config.dataset = 'cifarfs'
+ # config.dataset = 'fc100'
+ # config.dataset = 'cub'
+
+ config.tracking_uri = os.environ.get('TRACKING_URI', None)
+ config.experiment_id = os.environ.get('EXPERIMENT_ID', None)
+
+ train_dataset, valid_dataset, test_dataset = load_datasets(config.dataset)
+ model = init_models(config)
+
+ criterion = nn.CrossEntropyLoss()
+ learner = MAMLLearner(
+ module=model,
+ learning_rate=config.inner_lr,
+ )
+ scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=config.meta_lr, T_max=config.epochs)
+ optimizer = Adam(parameters=learner.parameters(), learning_rate=scheduler)
+ trainer = Trainer(
+ config=config,
+ train_dataset=train_dataset,
+ dev_dataset=valid_dataset,
+ test_dataset=test_dataset,
+ learner=learner,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ criterion=criterion
+ )
+ trainer.train()
\ No newline at end of file
diff --git a/PaddleFSL/examples/optim/meta_trainer.py b/PaddleFSL/examples/optim/meta_trainer.py
new file mode 100644
index 0000000..4bb2fb6
--- /dev/null
+++ b/PaddleFSL/examples/optim/meta_trainer.py
@@ -0,0 +1,325 @@
+"""MAML example for optimization"""
+from __future__ import annotations
+import os
+from typing import Optional, Tuple
+import warnings
+from loguru import logger
+
+import paddle
+from paddle.optimizer import Optimizer
+from paddle.optimizer.lr import LRScheduler
+from paddle.nn import Layer
+from paddle.metric.metrics import Accuracy
+from tap import Tap
+from tqdm import tqdm
+from mlflow.tracking import MlflowClient
+from mlflow import set_tag
+import numpy as np
+
+import paddlefsl
+from paddlefsl.datasets.cv_dataset import CVDataset
+from paddlefsl.metaopt.base_learner import BaseLearner
+
+
+def load_datasets(name: str) -> Tuple[CVDataset, CVDataset, CVDataset]:
+ """load CV Dataset by name, which can be omniglot, miniimagenet, or cifar10
+
+ Args:
+ name (str): the name of datasets
+
+ Returns:
+ Tuple[CVDataset, CVDataset, CVDataset]: train, dev, test dataset
+ """
+ if name == "omniglot":
+ return (
+ paddlefsl.datasets.Omniglot(mode='train', image_size=(28, 28)),
+ paddlefsl.datasets.Omniglot(mode='valid', image_size=(28, 28)),
+ paddlefsl.datasets.Omniglot(mode='test', image_size=(28, 28))
+ )
+ if name == "miniimagenet":
+ return (
+ paddlefsl.datasets.MiniImageNet(mode='train'),
+ paddlefsl.datasets.MiniImageNet(mode='valid'),
+ paddlefsl.datasets.MiniImageNet(mode='test')
+ )
+ if name == "cifarfs":
+ return (
+ paddlefsl.datasets.CifarFS(mode='train', image_size=(28, 28)),
+ paddlefsl.datasets.CifarFS(mode='valid', image_size=(28, 28)),
+ paddlefsl.datasets.CifarFS(mode='test', image_size=(28, 28))
+ )
+ if name == "fc100":
+ return (
+ paddlefsl.datasets.FC100(mode='train'),
+ paddlefsl.datasets.FC100(mode='valid'),
+ paddlefsl.datasets.FC100(mode='test')
+ )
+ if name == "cub":
+ return (
+ paddlefsl.datasets.CubFS(mode='train'),
+ paddlefsl.datasets.CubFS(mode='valid'),
+ paddlefsl.datasets.CubFS(mode='test')
+ )
+ raise ValueError(f"the dataset name: <{name}> is not supported")
+
+
+class Config(Tap):
+ """Alernative for Argument Parse"""
+ dataset: str = ''
+ input_size: Optional[str] = None
+ n_way: int = 5
+ k_shot: int = 1
+ meta_lr: float = 0.005
+ inner_lr: float = 0.5
+    epochs: int = 60000  # also referred to as iterations
+ test_epoch: int = 10
+ eval_iters: int = 10
+
+ meta_batch_size: int = 32
+ test_batch_size: int = 10
+
+ train_inner_adapt_steps: int = 1
+ test_inner_adapt_steps: int = 1
+
+ approximate: bool = True
+
+ do_eval_step: int = 30
+ do_test_step: int = 100
+ save_model_iter: int = 5000
+ save_model_root: str = '~/trained_models'
+ test_param_file: str = 'iteration60000.params'
+
+ device: str = 'cpu'
+
+ tracking_uri: str = ''
+ experiment_id: str = '0'
+ run_id: str = ''
+
+    def place(self):
+        """get the default device place for tensors"""
+        return paddle.CUDAPlace(0) if self.device == 'gpu' else paddle.CPUPlace()
+
+    def get_input_size(self) -> Optional[Tuple[int, int, int]]:
+ """get the input size based on the datasets"""
+ if self.dataset in ['omniglot']:
+ return (1, 28, 28)
+
+ if self.dataset in ['cifarfs']:
+ return (3, 32, 32)
+
+ if self.dataset == 'miniimagenet':
+ return (3, 84, 84)
+
+ if self.dataset == 'fc100':
+ return (3, 32, 32)
+
+ if self.dataset == 'cub':
+ return (3, 84, 84)
+
+ if not self.input_size:
+ return None
+
+ return tuple(map(int, self.input_size.split(',')))
+
+class ContextData:
+ """context data to store the training and testing results"""
+ def __init__(self) -> None:
+ self.train_epoch = 0
+ self.epoch = 0
+
+ self.train_loss = 0
+ self.train_acc = 0
+
+ self.dev_loss = 0
+ self.dev_acc = 0
+
+
+class Trainer:
+ """Trainer for meta training epoch"""
+ def __init__(
+ self,
+ config: Config,
+ train_dataset: CVDataset,
+ dev_dataset: CVDataset,
+ test_dataset: CVDataset,
+ learner: BaseLearner,
+ optimizer: Optimizer,
+ scheduler: LRScheduler,
+ criterion: Layer,
+
+ ) -> None:
+ self.config = config
+
+ self.train_dataset = train_dataset
+ self.dev_dataset = dev_dataset
+ self.test_dataset = test_dataset
+
+ self.criterion = criterion
+ self.learner = learner
+
+ self.train_bar = tqdm(total=self.config.epochs, desc='Train Progress')
+ self.context = ContextData()
+
+ self.metric = Accuracy()
+
+ self.scheduler = scheduler
+ self.optimizer = optimizer
+
+ self._set_device()
+ warnings.filterwarnings("ignore")
+
+ if self.config.tracking_uri:
+ self.client = MlflowClient(tracking_uri=self.config.tracking_uri)
+
+ run = self.client.create_run(self.config.experiment_id)
+ self.config.run_id = run.info.run_id
+
+ for key, value in self.config.as_dict().items():
+ self.client.log_param(
+ self.config.run_id,
+ key=key,
+ value=value
+ )
+
+ self.client.log_param(self.config.run_id, 'learner', value=self.learner.__class__.__name__)
+ set_tag("mlflow.runName", self.learner.__class__.__name__)
+
+ learner_name = learner.__class__.__name__
+ file_name = f'{config.dataset}-{learner_name}-{config.run_id}.log'
+ logger.add(os.path.join('logs', file_name))
+ logger.info(self.config)
+
+ def _set_device(self):
+ paddle.device.set_device(self.config.device)
+ if paddle.distributed.get_world_size() > 1:
+ paddle.distributed.init_parallel_env()
+
+ def on_train_epoch_end(self):
+ """handle the end of one epoch"""
+ self.context.epoch += 1
+
+ self.train_bar.update()
+
+ bar_info = f'Epoch: {self.context.epoch}/{self.config.epochs} \t train-loss: {self.context.train_loss} \t\t train-acc: {self.context.train_acc}'
+ self.train_bar.set_description(bar_info)
+
+ if self.config.tracking_uri:
+ self.client.log_metric(self.config.run_id, key='train-loss', value=self.context.train_loss)
+ self.client.log_metric(self.config.run_id, key='train-acc', value=self.context.train_acc)
+
+ def compute_loss(self, input_data, labels, learner: BaseLearner):
+ """compute the loss based on the input_data and labels"""
+ input_data, labels = paddle.to_tensor(input_data, dtype='float32'), paddle.to_tensor(labels, dtype='int64')
+
+ logits = learner(input_data)
+ loss = self.criterion(logits, labels)
+
+ acc = self.metric.compute(logits, labels)
+ acc = self.metric.update(acc)
+
+ return loss, acc
+
+ def train_epoch(self):
+ """train one epoch"""
+ self.learner.train()
+
+ self.context.train_loss = 0
+
+ train_loss, train_acc = 0, 0
+
+ self.metric.reset()
+ self.optimizer.clear_grad()
+ for _ in range(self.config.meta_batch_size):
+ task = self.train_dataset.sample_task_set(
+ ways=self.config.n_way,
+ shots=self.config.k_shot
+ )
+ learner = self.learner.clone()
+
+ # inner loop
+ for _ in range(self.config.train_inner_adapt_steps):
+ inner_loss, _ = self.compute_loss(
+ task.support_data, task.support_labels, learner
+ )
+ learner.adapt(inner_loss)
+
+ # outer loop: compute loss on the validation dataset
+ loss, acc = self.compute_loss(
+ task.query_data, task.query_labels, learner
+ )
+ train_loss += loss
+ train_acc += acc
+
+ self.optimizer.clear_grad()
+ train_loss.backward()
+ self.optimizer.step()
+ self.scheduler.step()
+ self.context.train_loss, self.context.train_acc = train_loss.numpy()[0] / self.config.meta_batch_size, train_acc / self.config.meta_batch_size
+
+ def eval(self, dataset: CVDataset, learner: BaseLearner, mode: str = 'dev'):
+ """eval the model on the dataset
+
+ Args:
+ dataset (CVDataset): the dataset to evaluate
+ learner (BaseLearner): the learner to evaluate the model
+ mode (str): the mode to evaluate, 'dev' or 'test'
+ """
+ logger.info(f'start doing {mode} on the dataset ...')
+ eval_bar = tqdm(total=self.config.test_epoch, desc=f'{mode} Bar')
+ test_loss, test_acc = [], []
+ for _ in range(self.config.test_epoch):
+ val_loss, val_acc = 0.0, 0.0
+ for _ in range(self.config.test_batch_size):
+
+ task = dataset.sample_task_set(
+ ways=self.config.n_way,
+ shots=self.config.k_shot
+ )
+ learner = self.learner.clone()
+
+ # inner loop
+ for _ in range(self.config.test_inner_adapt_steps):
+ inner_loss, _ = self.compute_loss(
+ task.support_data, task.support_labels, learner
+ )
+ learner.adapt(inner_loss)
+
+ # outer loop: compute loss on the validation dataset
+ loss, acc = self.compute_loss(
+ task.query_data, task.query_labels, learner,
+ )
+ val_loss += loss.numpy()[0]
+ val_acc += acc
+
+ test_acc.append(val_acc / self.config.test_batch_size)
+ test_loss.append(val_loss / self.config.test_batch_size)
+
+ eval_bar.update()
+ eval_bar.set_description(
+ f'acc {test_acc[-1]:.6f}'
+ )
+ mean_loss, std_loss = np.mean(test_loss), np.std(test_loss)
+ mean_acc, std_acc = np.mean(test_acc), np.std(test_acc)
+
+ logger.success(f'======================Epoch: {self.context.epoch}/{self.config.epochs}-{mode}======================')
+ logger.success(f'mean-loss: {mean_loss:.6f}, std-loss: {std_loss:.6f}')
+ logger.success(f'mean-acc: {mean_acc:.6f}, std-acc: {std_acc:.6f}')
+ logger.success('================================================')
+ if self.config.tracking_uri:
+ self.client.log_metric(self.config.run_id, f'{mode}-mean-loss', mean_loss)
+ self.client.log_metric(self.config.run_id, f'{mode}-std-loss', std_loss)
+ self.client.log_metric(self.config.run_id, f'{mode}-mean-acc', mean_acc)
+ self.client.log_metric(self.config.run_id, f'{mode}-std-acc', std_acc)
+
+ def train(self):
+ """handle the main train"""
+ for epoch in range(1, self.config.epochs + 1):
+ self.train_epoch()
+ self.on_train_epoch_end()
+ if epoch % self.config.do_eval_step == 0:
+ self.eval(self.dev_dataset, self.learner, 'eval')
+
+            if epoch % self.config.do_test_step == 0:
+ self.eval(self.test_dataset, self.learner, 'test')
diff --git a/PaddleFSL/paddlefsl/metaopt/__init__.py b/PaddleFSL/paddlefsl/metaopt/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/PaddleFSL/paddlefsl/metaopt/anil.py b/PaddleFSL/paddlefsl/metaopt/anil.py
new file mode 100644
index 0000000..4fecb11
--- /dev/null
+++ b/PaddleFSL/paddlefsl/metaopt/anil.py
@@ -0,0 +1,56 @@
+"""ANIL Meta Learner"""
+from __future__ import annotations
+
+from paddle.nn import Layer
+
+from paddlefsl.utils.manual_gradient_descent import manual_gradient_descent
+from paddlefsl.utils.clone_model import clone_model
+from paddlefsl.metaopt.base_learner import BaseLearner
+
+
+class ANILLearner(BaseLearner):
+ """ANIL Meta Learner"""
+ def __init__(self, feature_model: Layer, head_layer: Layer, learning_rate: float, approximate: bool = True) -> None:
+ """The constructor of ANILearner
+
+ Args:
+ module (Layer): the model to be trained
+ optimizer (Optimizer): the optimizer to be used
+ """
+ super().__init__(head_layer)
+ self.feature_model = feature_model
+ self.learning_rate = learning_rate
+ self.approximate = approximate
+
+ def clone(self) -> ANILLearner:
+ """get the cloned model and keep the computation gragh
+
+ Returns:
+ ANILearner: the cloned model
+ """
+ cloned_head_layer = clone_model(self.module)
+ return ANILLearner(
+ feature_model=self.feature_model,
+ head_layer=cloned_head_layer,
+ learning_rate=self.learning_rate,
+ approximate=self.approximate
+ )
+
+ def adapt(self, loss) -> None:
+ """adapt the gradient descent to the module based on the loss
+
+ Args:
+ loss (Tensor): the loss of head layer
+ """
+ manual_gradient_descent(
+ model=self.module,
+ lr=self.learning_rate,
+ loss=loss,
+ approximate=self.approximate
+ )
+
+ def forward(self, inputs):
+ """forward the feature model and the head layer"""
+ y = self.feature_model(inputs)
+ y = self.module(y)
+ return y
\ No newline at end of file
diff --git a/PaddleFSL/paddlefsl/metaopt/base_learner.py b/PaddleFSL/paddlefsl/metaopt/base_learner.py
new file mode 100644
index 0000000..37d2e5d
--- /dev/null
+++ b/PaddleFSL/paddlefsl/metaopt/base_learner.py
@@ -0,0 +1,46 @@
+"""Base Learner"""
+from __future__ import annotations
+from typing import Type, TypeVar
+from abc import abstractmethod
+
+from paddle import Tensor
+from paddle.nn import Layer
+
+
+Learner = TypeVar('Learner')
+
+class BaseLearner(Layer):
+ """Abstract Base Learner Class"""
+ def __init__(self, module: Layer) -> None:
+ """The constructor of BaseLearner
+
+ Args:
+ module (Layer): the model to be trained
+ """
+ super().__init__()
+ self.module = module
+
+ @abstractmethod
+ def adapt(self, loss: Tensor) -> None:
+ """Adapt the model to the current training loss
+
+ Args:
+ loss (Tensor): the current training loss
+ """
+ raise NotImplementedError
+
+ def clone(self: Type[Learner]) -> Learner:
+ """create cloned module and keep the computation gragh
+
+ Args:
+ self (Type[Learner]): the sub-learner
+
+ Returns:
+ Learner: the cloned model
+ """
+ raise NotImplementedError
+
+ def forward(self, *args, **kwargs):
+ return self.module(*args, **kwargs)
diff --git a/PaddleFSL/paddlefsl/metaopt/maml.py b/PaddleFSL/paddlefsl/metaopt/maml.py
new file mode 100644
index 0000000..7e20d45
--- /dev/null
+++ b/PaddleFSL/paddlefsl/metaopt/maml.py
@@ -0,0 +1,50 @@
+"""MAML Meta Learner"""
+from __future__ import annotations
+
+from paddle import Tensor
+from paddle.nn import Layer
+
+from paddlefsl.utils.manual_gradient_descent import manual_gradient_descent
+from paddlefsl.metaopt.base_learner import BaseLearner
+from paddlefsl.utils.clone_model import clone_model
+
+
+class MAMLLearner(BaseLearner):
+ """MAML Meta Learner"""
+ def __init__(self, module: Layer, learning_rate: float, approximate: bool = True) -> None:
+ """The constructor of MAMLLearner
+
+ Args:
+ module (Layer): the model to be trained
+ optimizer (Optimizer): the optimizer to be used
+ """
+ super().__init__(module)
+
+ self.learning_rate = learning_rate
+ self.approximate = approximate
+
+    def clone(self) -> MAMLLearner:
+        """get a cloned learner and keep the computation graph
+
+        Returns:
+            MAMLLearner: the cloned learner
+        """
+ cloned_module = clone_model(self.module)
+
+ return MAMLLearner(
+ module=cloned_module,
+ learning_rate=self.learning_rate,
+ approximate=self.approximate
+ )
+
+ def adapt(self, loss: Tensor) -> None:
+ """adapt the gradient descent to the module based on the loss
+
+ Args:
+ loss (Tensor): _description_
+ """
+ manual_gradient_descent(
+ self.module,
+ lr=self.learning_rate,
+ loss=loss,
+ approximate=self.approximate
+ )