From 7d332097574f9f73e33eb7af296d6d46608df8d8 Mon Sep 17 00:00:00 2001 From: Alan Van Omen <46762315+alanjvano@users.noreply.github.com> Date: Mon, 15 Jan 2024 09:58:06 -0700 Subject: [PATCH] Update ARAD constructor test, change default to v2. --- riid/models/neural_nets/arad.py | 2 +- tests/model_tests.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/riid/models/neural_nets/arad.py b/riid/models/neural_nets/arad.py index 7a447a7..eaccfbe 100644 --- a/riid/models/neural_nets/arad.py +++ b/riid/models/neural_nets/arad.py @@ -238,7 +238,7 @@ def call(self, x): class ARAD(PyRIIDModel): """PyRIID-compatible ARAD model to work with SampleSets. """ - def __init__(self, arad_version: str = "v1", latent_dim=None): + def __init__(self, arad_version: str = "v2", latent_dim=None): """ Args: arad_version: version of ARAD to use, can be "v1" or "v2". diff --git a/tests/model_tests.py b/tests/model_tests.py index 7b14589..aa2b88c 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -88,10 +88,8 @@ def test_all_constructors(self): _ = MLPClassifier() _ = LabelProportionEstimator() _ = MultiEventClassifier() - arad_v1 = ARADv1TF() - _ = ARAD(arad_v1) - arad_v2 = ARADv2TF() - _ = ARAD(arad_v2) + _ = ARAD(arad_version="v1") + _ = ARAD(arad_version="v2") if __name__ == "__main__":