From 54430579357f235fbf3bfa57165df7fe7900386a Mon Sep 17 00:00:00 2001 From: Felix Pei <64850082+felixp8@users.noreply.github.com> Date: Tue, 6 Feb 2024 12:17:45 +0100 Subject: [PATCH] register c2st --- labproject/metrics/c2st.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/labproject/metrics/c2st.py b/labproject/metrics/c2st.py index b7d594d..a55ffc4 100644 --- a/labproject/metrics/c2st.py +++ b/labproject/metrics/c2st.py @@ -10,10 +10,13 @@ from sklearn.neighbors import KNeighborsClassifier from sklearn.neural_network import MLPClassifier +from labproject.metrics.utils import register_metric + # from sbi: https://github.com/sbi-dev/sbi/blob/main/sbi/utils/metrics.py +@register_metric("c2st_nn") def c2st_nn( X: Tensor, Y: Tensor, @@ -101,6 +104,7 @@ def c2st_nn( return value +@register_metric("c2st_rf") def c2st_rf( X: Tensor, Y: Tensor, @@ -179,6 +183,7 @@ def c2st_rf( return value +@register_metric("c2st_knn") def c2st_knn( X: Tensor, Y: Tensor,