Skip to content

Commit

Permalink
register c2st
Browse files Browse the repository at this point in the history
  • Loading branch information
felixp8 authored Feb 6, 2024
1 parent 0ec8f76 commit 5443057
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions labproject/metrics/c2st.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier

from labproject.metrics.utils import register_metric


# from sbi: https://github.com/sbi-dev/sbi/blob/main/sbi/utils/metrics.py


@register_metric("c2st_nn")
def c2st_nn(
X: Tensor,
Y: Tensor,
Expand Down Expand Up @@ -101,6 +104,7 @@ def c2st_nn(
return value


@register_metric("c2st_rf")
def c2st_rf(
X: Tensor,
Y: Tensor,
Expand Down Expand Up @@ -179,6 +183,7 @@ def c2st_rf(
return value


@register_metric("c2st_knn")
def c2st_knn(
X: Tensor,
Y: Tensor,
Expand Down

0 comments on commit 5443057

Please sign in to comment.