diff --git a/sdmetrics/single_table/data_augmentation/base.py b/sdmetrics/single_table/data_augmentation/base.py index f2cf75a7..9a6d6eaa 100644 --- a/sdmetrics/single_table/data_augmentation/base.py +++ b/sdmetrics/single_table/data_augmentation/base.py @@ -207,6 +207,7 @@ def compute_breakdown( fixed_recall_value, cls.metric_name, ) + metric_to_fix = 'recall' if cls.metric_name == 'precision' else 'precision' result = { 'real_data_baseline': trainer.get_scores( preprocessed_tables['real_training_data'], @@ -223,7 +224,7 @@ def compute_breakdown( 'prediction_column_name': trainer.prediction_column_name, 'minority_class_label': trainer.minority_class_label, 'classifier': trainer._classifier_name, - 'fixed_recall_value': trainer.fixed_value, + f'fixed_{metric_to_fix}_value': trainer.fixed_value, }, } result['score'] = max(