diff --git a/jetnet/evaluation/gen_metrics.py b/jetnet/evaluation/gen_metrics.py index 30092bf..0052282 100644 --- a/jetnet/evaluation/gen_metrics.py +++ b/jetnet/evaluation/gen_metrics.py @@ -707,10 +707,10 @@ def fpd( stacklevel=2, ) - real_features, gen_features = _check_get_ndarray(real_features, gen_features) + X, Y = _check_get_ndarray(real_features, gen_features) if normalise: - X, Y = _normalise_features(real_features, gen_features) + X, Y = _normalise_features(X, Y) # regular intervals in 1/N batches = (1 / np.linspace(1.0 / min_samples, 1.0 / max_samples, num_points)).astype("int32") @@ -836,10 +836,10 @@ def kpd( Returns: Tuple[float, float]: median and error of KPD. """ - real_features, gen_features = _check_get_ndarray(real_features, gen_features) + X, Y = _check_get_ndarray(real_features, gen_features) if normalise: - X, Y = _normalise_features(real_features, gen_features) + X, Y = _normalise_features(X, Y) if num_threads is None: vals_point = _kpd_batches(X, Y, num_batches, batch_size, seed) diff --git a/tests/evaluation/test_gen_metrics.py b/tests/evaluation/test_gen_metrics.py index 285dbc8..533a1f9 100644 --- a/tests/evaluation/test_gen_metrics.py +++ b/tests/evaluation/test_gen_metrics.py @@ -7,6 +7,7 @@ test_zeros = np.zeros((50_000, 2)) test_ones = np.ones((50_000, 2)) +test_twos = np.ones((50_000, 2)) * 2 def test_fpd(): @@ -14,12 +15,25 @@ def test_fpd(): assert val == approx(0, abs=0.01) assert err < 1e-3 - val, err = evaluation.fpd(test_zeros, test_ones) - assert val == approx(2, rel=0.01) + val, err = evaluation.fpd(test_twos, test_zeros) + assert val == approx(2, rel=0.01) # 1^2 + 1^2 + assert err < 1e-3 + + # test normalization + val, err = evaluation.fpd(test_zeros, test_zeros, normalise=False) # should have no effect + assert val == approx(0, abs=0.01) + assert err < 1e-3 + + val, err = evaluation.fpd(test_twos, test_zeros, normalise=False) + assert val == approx(8, rel=0.01) # 2^2 + 2^2 assert err < 1e-3 @pytest.mark.parametrize("num_threads", [None, 2]) # test numba parallelization def test_kpd(num_threads): assert evaluation.kpd(test_zeros, test_zeros, num_threads=num_threads) == approx([0, 0]) - assert evaluation.kpd(test_zeros, test_ones, num_threads=num_threads) == approx([15, 0]) + assert evaluation.kpd(test_twos, test_zeros, num_threads=num_threads) == approx([15, 0]) + + # test normalization + assert evaluation.kpd(test_zeros, test_zeros, normalise=False, num_threads=num_threads) == approx([0, 0]) + assert evaluation.kpd(test_twos, test_zeros, normalise=False, num_threads=num_threads) == approx([624, 0]) \ No newline at end of file