feat(neuralop): add the rest of tests
Showing 10 changed files with 408 additions and 0 deletions.
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file.
@@ -0,0 +1,57 @@
default: &DEFAULT

  # General
  verbose: True
  arch: 'tfno2d'

  # FNO related
  tfno2d:
    data_channels: 3
    n_modes_height: 8
    n_modes_width: 8
    hidden_channels: 32
    projection_channels: 32
    n_layers: 2
    domain_padding: 0
    domain_padding_mode: 'symmetric'
    fft_norm: 'forward'
    norm: None
    skip: 'soft-gating'
    implementation: 'factorized'

    use_mlp: 1
    mlp:
      expansion: 0.5
      dropout: 0

    factorization: None
    rank: 1.0
    fixed_rank_modes: None
    dropout: 0.0
    tensor_lasso_penalty: 0.0
    joint_factorization: False

  data:
    batch_size: 4
    n_train: 10
    size: 32

  # Optimizer
  opt:
    n_epochs: 500
    learning_rate: 1e-3
    training_loss: 'h1'
    weight_decay: 1e-4
    amp_autocast: True

    scheduler_T_max: 500    # For cosine only, typically take n_epochs
    scheduler_patience: 5   # For ReduceLROnPlateau only
    scheduler: 'StepLR'     # Or 'CosineAnnealingLR' OR 'ReduceLROnPlateau'
    step_size: 100
    gamma: 0.5

  # Patching
  patching:
    levels: 0
    padding: 0  # .1
    stitching: True
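
A minimal sketch of how this config is meant to be consumed (the file name test_config.yaml, the configmypy calls, and the attribute-style access all mirror the test added further down in this commit):

# Sketch only: read the config above with configmypy and access nested values.
from pathlib import Path
from configmypy import ConfigPipeline, YamlConfig

config_folder = Path(__file__).parent.as_posix()
pipe = ConfigPipeline([YamlConfig('./test_config.yaml',
                                  config_name='default',
                                  config_folder=config_folder)])
config = pipe.read_conf()

# Nested keys are exposed as attributes, e.g. the architecture block selected
# by config.arch ('tfno2d') and the optimizer settings.
print(config.arch, config.tfno2d.n_modes_height, config.opt.learning_rate)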
@@ -0,0 +1 @@
my_secret_key
@@ -0,0 +1,44 @@

import paddle
import time
from tensorly import tenalg
tenalg.set_backend('einsum')
from pathlib import Path

from configmypy import ConfigPipeline, YamlConfig
from neuralop import get_model


def test_from_config():
    """Test forward/backward from a config file"""
    # Read the configuration
    config_name = 'default'
    config_path = Path(__file__).parent.as_posix()
    pipe = ConfigPipeline([YamlConfig('./test_config.yaml', config_name=config_name,
                                      config_folder=config_path)])
    config = pipe.read_conf()
    config_name = pipe.steps[-1].config_name

    batch_size = config.data.batch_size
    size = config.data.size

    # Paddle names the CUDA device 'gpu' (not 'cuda')
    if paddle.device.cuda.device_count() >= 1:
        device = 'gpu'
    else:
        device = 'cpu'

    paddle.device.set_device(device=device)

    # Parameters are created on the device selected above
    model = get_model(config)

    in_data = paddle.randn([batch_size, 3, size, size])
    print(model.__class__)
    print(model)

    # Time a forward pass, then check that the backward pass runs
    t1 = time.time()
    out = model(in_data)
    t = time.time() - t1
    print(f'Output of size {out.shape} in {t}.')

    loss = out.sum()
    loss.backward()
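
A small hypothetical extension of the check above, not part of the commit: after loss.backward(), every trainable parameter should carry a gradient, which can be asserted directly in Paddle:

# Hypothetical follow-up check (not in the commit): verify that backward()
# populated gradients for all trainable parameters.
n_missing = sum(1 for p in model.parameters()
                if not p.stop_gradient and p.grad is None)
assert n_missing == 0, f'{n_missing} trainable parameters received no gradient'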
@@ -0,0 +1,103 @@
from ..utils import get_wandb_api_key, wandb_login
from ..utils import count_model_params, count_tensor_params
from pathlib import Path
import pytest
import wandb
import os
import paddle
from paddle import nn


def test_count_model_params():
    # A nested dummy model to make sure all parameters are counted
    class DumyModel(nn.Layer):
        def __init__(self, n_submodels=0, dtype=paddle.float32):
            super().__init__()

            self.n_submodels = n_submodels
            self.param = paddle.base.framework.EagerParamBase.from_tensor(
                paddle.randn((2, 3, 4), dtype=dtype))
            if n_submodels:
                self.model = DumyModel(n_submodels - 1, dtype=dtype)

    n_submodels = 2
    model = DumyModel(n_submodels=n_submodels)
    n_params = count_model_params(model)
    print(model)
    # One (2, 3, 4) parameter per model: the top model plus each submodel
    assert n_params == (n_submodels + 1) * 2 * 3 * 4

    # Complex parameters count twice (real and imaginary parts)
    model = DumyModel(n_submodels=n_submodels, dtype=paddle.complex64)
    n_params = count_model_params(model)
    print(model)
    assert n_params == 2 * (n_submodels + 1) * 2 * 3 * 4


def test_count_tensor_params():
    # Case 1: real tensor
    x = paddle.randn((2, 3, 4, 5, 6), dtype=paddle.float32)

    # dims=None: count all params
    n_params = count_tensor_params(x)
    assert n_params == 2 * 3 * 4 * 5 * 6
    # Only certain dims
    n_params = count_tensor_params(x, dims=[1, 3])
    assert n_params == 3 * 5

    # Case 2: complex tensor counts twice as many params
    x = paddle.randn((2, 3, 4, 5, 6), dtype=paddle.complex64)

    # dims=None: count all params
    n_params = count_tensor_params(x)
    assert n_params == 2 * 3 * 4 * 5 * 6 * 2
    # Only certain dims
    n_params = count_tensor_params(x, dims=[1, 3])
    assert n_params == 3 * 5 * 2


def test_get_wandb_api_key():
    # Make sure no env var key is set
    os.environ.pop("WANDB_API_KEY", None)

    # Read from file
    filepath = Path(__file__).parent.joinpath('test_config_key.txt').as_posix()
    key = get_wandb_api_key(filepath)
    assert key == 'my_secret_key'

    # Read from env var
    os.environ["WANDB_API_KEY"] = 'key_from_env'
    key = get_wandb_api_key(filepath)
    assert key == 'key_from_env'

    # Read from env var with incorrect file
    os.environ["WANDB_API_KEY"] = 'key_from_env'
    key = get_wandb_api_key('wrong_path')
    assert key == 'key_from_env'


def test_ArgparseConfig(monkeypatch):
    # Exercise wandb_login with wandb.login monkeypatched so no network call is made
    def login(key):
        if key == 'my_secret_key':
            return True

        raise ValueError('Wrong key')

    monkeypatch.setattr(wandb, "login", login)

    # Make sure no env var key is set
    os.environ.pop("WANDB_API_KEY", None)

    # Read from file
    filepath = Path(__file__).parent.joinpath('test_config_key.txt').as_posix()
    assert wandb_login(filepath) is None

    # Read from env var
    os.environ["WANDB_API_KEY"] = 'my_secret_key'
    assert wandb_login() is None

    # An explicit key argument overrides the env var
    os.environ["WANDB_API_KEY"] = 'wrong_key'
    assert wandb_login(key='my_secret_key') is None

    # A wrong key in the env var raises
    os.environ["WANDB_API_KEY"] = 'wrong_key'
    with pytest.raises(ValueError):
        wandb_login()
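
For context, a hedged sketch of how these helpers fit together outside the tests; the neuralop.utils import path and the key-file path are assumptions, not taken from this commit:

# Sketch only: resolve a W&B key and log in before training.
# As the tests above exercise, get_wandb_api_key prefers the WANDB_API_KEY
# env var and otherwise falls back to the given key file.
from neuralop.utils import get_wandb_api_key, wandb_login  # assumed public path

key = get_wandb_api_key('config/wandb_api_key.txt')  # hypothetical file path
wandb_login(key=key)  # returns None on success, raises if wandb.login rejects the key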