From 092361f95bc1a91f63db5576cd9eef17d612d282 Mon Sep 17 00:00:00 2001 From: 4o3F <4o3f@proton.me> Date: Tue, 8 Oct 2024 12:56:19 +0800 Subject: [PATCH] fix: fix confusion matrix calc error --- src/common/metric.rs | 109 +++++++++++++++++++------------------------ 1 file changed, 48 insertions(+), 61 deletions(-) diff --git a/src/common/metric.rs b/src/common/metric.rs index 9c60bab..b8728ff 100644 --- a/src/common/metric.rs +++ b/src/common/metric.rs @@ -23,8 +23,6 @@ pub fn calc_iou(target_img: &String, gt_img: &String) { return; } - let intersection: Arc>> = Arc::new(Mutex::new(HashMap::new())); - let union: Arc>> = Arc::new(Mutex::new(HashMap::new())); let confusion_matrix: Arc>>> = Arc::new(Mutex::new(HashMap::new())); @@ -35,87 +33,43 @@ pub fn calc_iou(target_img: &String, gt_img: &String) { let row_progress = row_iter.items_processed(); let row_total = row_iter.len(); row_iter.for_each(|i| { - let mut row_intersection: HashMap<(u8, u8, u8), i64> = HashMap::new(); - let mut row_union: HashMap<(u8, u8, u8), i64> = HashMap::new(); let mut row_confusion_matrix: HashMap<(u8, u8, u8), HashMap<(u8, u8, u8), i64>> = HashMap::new(); for j in 0..cols { - let pixel1 = target_img + let predicted_pixel = target_img .at_2d::(i, j) .expect_or_log("Get output pixel error"); - let pixel2 = gt_img + let true_pixel = gt_img .at_2d::(i, j) .expect_or_log("Get ground truth pixel error"); - let color1 = (pixel1[0], pixel1[1], pixel1[2]); - let color2 = (pixel2[0], pixel2[1], pixel2[2]); + let predicted_color = (predicted_pixel[0], predicted_pixel[1], predicted_pixel[2]); + let true_color = (true_pixel[0], true_pixel[1], true_pixel[2]); { - *row_union.entry(color1).or_insert(0) += 1; - *row_union.entry(color2).or_insert(0) += 1; - - if pixel1 == pixel2 { - *row_intersection.entry(color1).or_insert(0) += 1; - } - let entry = row_confusion_matrix - .entry(color2) + .entry(true_color) .or_insert_with(HashMap::new); - *entry.entry(color1).or_insert(0) += 1; + + *entry.entry(predicted_color).or_insert(0) += 1; } } - - let mut intersection = intersection.lock().unwrap(); - let mut union = union.lock().unwrap(); let mut confusion_matrix = confusion_matrix.lock().unwrap(); - for (color, value) in row_intersection.into_iter() { - *intersection.entry(color).or_insert(0) += value; - } - for (color, value) in row_union.into_iter() { - *union.entry(color).or_insert(0) += value; - } - for (color, value) in row_confusion_matrix.into_iter() { - *confusion_matrix.entry(color).or_insert_with(HashMap::new) = value; + for (true_color, value) in row_confusion_matrix.into_iter() { + let entry = confusion_matrix + .entry(true_color) + .or_insert_with(HashMap::new); + for (predicted_color, count) in value.into_iter() { + *entry.entry(predicted_color).or_insert(0) += count; + } } - if row_progress.get() != 0 && row_progress.get() % 1000 == 0 { + if row_progress.get() != 0 && row_progress.get() % 1000 == 0 { tracing::info!("Row {} / {} done", row_progress.get(), row_total); } }); - let mut iou = HashMap::new(); - let mut total_iou = 0.0; - let mut num_categories = 0; - - let intersection = intersection.lock().unwrap(); - let union = union.lock().unwrap(); let confusion_matrix = confusion_matrix.lock().unwrap(); - for (&color, &inter) in &*intersection { - let uni = union.get(&color).unwrap_or(&0); - if *uni > 0 { - let iou_value = inter as f64 / *uni as f64; - iou.insert(color, iou_value); - total_iou += iou_value; - num_categories += 1; - } - } - - let mean_iou = if num_categories > 0 { - total_iou / num_categories as f64 - } else { - 0.0 - }; - - for (color, &iou_value) in &iou { - tracing::info!( - "IoU for color RGB({},{},{}): {}", - color.0, - color.1, - color.2, - iou_value - ); - } - tracing::info!("Mean IoU: {}", mean_iou); tracing::info!("Confusion Matrix:"); for (true_color, predictions) in &*confusion_matrix { for (predicted_color, count) in predictions { @@ -131,4 +85,37 @@ pub fn calc_iou(target_img: &String, gt_img: &String) { ); } } + + let mut iou_results: HashMap<(u8, u8, u8), f64> = HashMap::new(); + let mut total_intersection: HashMap<(u8, u8, u8), i64> = HashMap::new(); + let mut total_union: HashMap<(u8, u8, u8), i64> = HashMap::new(); + + for (true_color, predictions) in &*confusion_matrix { + let mut intersection = 0; + let mut union = 0; + + for (predicted_color, count) in predictions { + intersection += *count; + union += count + *total_union.entry(predicted_color.clone()).or_insert(0); + } + + total_intersection.insert(*true_color, intersection); + total_union.insert(*true_color, union); + + if union > 0 { + let iou = intersection as f64 / union as f64; + iou_results.insert(*true_color, iou); + } + } + + tracing::info!("IoU Results:"); + for (color, iou) in &iou_results { + tracing::info!( + "Color RGB({},{},{}) IoU: {:.4}", + color.0, + color.1, + color.2, + iou + ); + } }