diff --git a/tests/test_hausdorff_distance_morphologic.py b/tests/test_hausdorff_distance_morphologic.py index 82537f360c..3aa96cdd8b 100644 --- a/tests/test_hausdorff_distance_morphologic.py +++ b/tests/test_hausdorff_distance_morphologic.py @@ -21,11 +21,11 @@ device = torch.device("cuda") -dimA = 11 +dimA = 50 -dimAA = 150 -dimBB = 160 -dimCC = 170 +dimAA = 170 +dimBB = 190 +dimCC = 200 # testing single points diffrent dims # dim1 @@ -101,18 +101,18 @@ TEST_CASES = [ [[a, b, 1.0, compare_values], 10], - # [[a1, b1, 1.0, compare_values], 15], - # [[a2, b2, 1.0, compare_values], 140], - # [[a3, b3, 1.0, compare_values_b], 140], - # [[a4, b4, 1.0, compare_values_b], 110], - # [[a5, b5, 1.0, compare_values_b], 110], - # [[a6, b6, 1.0, compare_values_b], 110], - # [[a7, b7, 1.0, compare_values_b], 110], - # [[a8, b8, 1.0, compare_values_b], 110], # testing robust - # [[a6, b6, 0.9, compare_values_b], 110], - # [[a7, b7, 0.85, compare_values_b], 110], - # [[a8, b8, 0.8, compare_values_b], 110], # multi points - # [[a9, b9, 1.0, compare_values_b], 40] + [[a1, b1, 1.0, compare_values], 15], + [[a2, b2, 1.0, compare_values], 140], + [[a3, b3, 1.0, compare_values_b], 140], + [[a4, b4, 1.0, compare_values_b], 110], + [[a5, b5, 1.0, compare_values_b], 110], + [[a6, b6, 1.0, compare_values_b], 110], + [[a7, b7, 1.0, compare_values_b], 110], + [[a8, b8, 1.0, compare_values_b], 110], # testing robust + [[a6, b6, 0.9, compare_values_b], 110], + [[a7, b7, 0.85, compare_values_b], 110], + [[a8, b8, 0.8, compare_values_b], 110], # multi points + [[a9, b9, 1.0, compare_values_b], 40] ] @@ -123,7 +123,7 @@ def test_value(self, input_data, expected_value): if(not version_leq(f"{torch.version.cuda}", "10.100") and not version_leq(f"{torch.version.cuda}", "10.200")): [y_pred, y, percentt, compare_values] = input_data hd_metric = MorphologicalHausdorffDistanceMetric( - compare_values.to(device), percentt, False + compare_values.to(device), percentt, True ) # True only for tests result = hd_metric._compute_tensor(y_pred.to(device), y.to(device)) np.testing.assert_allclose(expected_value, result, rtol=1e-7)