Skip to content

Commit

Permalink
fix: fix confusion matrix calc error
Browse files Browse the repository at this point in the history
  • Loading branch information
4o3F committed Oct 8, 2024
1 parent 0c0bbaa commit 092361f
Showing 1 changed file with 48 additions and 61 deletions.
109 changes: 48 additions & 61 deletions src/common/metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ pub fn calc_iou(target_img: &String, gt_img: &String) {
return;
}

let intersection: Arc<Mutex<HashMap<(u8, u8, u8), i64>>> = Arc::new(Mutex::new(HashMap::new()));
let union: Arc<Mutex<HashMap<(u8, u8, u8), i64>>> = Arc::new(Mutex::new(HashMap::new()));
let confusion_matrix: Arc<Mutex<HashMap<(u8, u8, u8), HashMap<(u8, u8, u8), i64>>>> =
Arc::new(Mutex::new(HashMap::new()));

Expand All @@ -35,87 +33,43 @@ pub fn calc_iou(target_img: &String, gt_img: &String) {
let row_progress = row_iter.items_processed();
let row_total = row_iter.len();
row_iter.for_each(|i| {
let mut row_intersection: HashMap<(u8, u8, u8), i64> = HashMap::new();
let mut row_union: HashMap<(u8, u8, u8), i64> = HashMap::new();
let mut row_confusion_matrix: HashMap<(u8, u8, u8), HashMap<(u8, u8, u8), i64>> =
HashMap::new();
for j in 0..cols {
let pixel1 = target_img
let predicted_pixel = target_img
.at_2d::<core::Vec3b>(i, j)
.expect_or_log("Get output pixel error");
let pixel2 = gt_img
let true_pixel = gt_img
.at_2d::<core::Vec3b>(i, j)
.expect_or_log("Get ground truth pixel error");

let color1 = (pixel1[0], pixel1[1], pixel1[2]);
let color2 = (pixel2[0], pixel2[1], pixel2[2]);
let predicted_color = (predicted_pixel[0], predicted_pixel[1], predicted_pixel[2]);
let true_color = (true_pixel[0], true_pixel[1], true_pixel[2]);

{
*row_union.entry(color1).or_insert(0) += 1;
*row_union.entry(color2).or_insert(0) += 1;

if pixel1 == pixel2 {
*row_intersection.entry(color1).or_insert(0) += 1;
}

let entry = row_confusion_matrix
.entry(color2)
.entry(true_color)
.or_insert_with(HashMap::new);
*entry.entry(color1).or_insert(0) += 1;

*entry.entry(predicted_color).or_insert(0) += 1;
}
}

let mut intersection = intersection.lock().unwrap();
let mut union = union.lock().unwrap();
let mut confusion_matrix = confusion_matrix.lock().unwrap();

for (color, value) in row_intersection.into_iter() {
*intersection.entry(color).or_insert(0) += value;
}
for (color, value) in row_union.into_iter() {
*union.entry(color).or_insert(0) += value;
}
for (color, value) in row_confusion_matrix.into_iter() {
*confusion_matrix.entry(color).or_insert_with(HashMap::new) = value;
for (true_color, value) in row_confusion_matrix.into_iter() {
let entry = confusion_matrix
.entry(true_color)
.or_insert_with(HashMap::new);
for (predicted_color, count) in value.into_iter() {
*entry.entry(predicted_color).or_insert(0) += count;
}
}
if row_progress.get() != 0 && row_progress.get() % 1000 == 0 {
if row_progress.get() != 0 && row_progress.get() % 1000 == 0 {
tracing::info!("Row {} / {} done", row_progress.get(), row_total);
}
});

let mut iou = HashMap::new();
let mut total_iou = 0.0;
let mut num_categories = 0;

let intersection = intersection.lock().unwrap();
let union = union.lock().unwrap();
let confusion_matrix = confusion_matrix.lock().unwrap();
for (&color, &inter) in &*intersection {
let uni = union.get(&color).unwrap_or(&0);
if *uni > 0 {
let iou_value = inter as f64 / *uni as f64;
iou.insert(color, iou_value);
total_iou += iou_value;
num_categories += 1;
}
}

let mean_iou = if num_categories > 0 {
total_iou / num_categories as f64
} else {
0.0
};

for (color, &iou_value) in &iou {
tracing::info!(
"IoU for color RGB({},{},{}): {}",
color.0,
color.1,
color.2,
iou_value
);
}
tracing::info!("Mean IoU: {}", mean_iou);
tracing::info!("Confusion Matrix:");
for (true_color, predictions) in &*confusion_matrix {
for (predicted_color, count) in predictions {
Expand All @@ -131,4 +85,37 @@ pub fn calc_iou(target_img: &String, gt_img: &String) {
);
}
}

let mut iou_results: HashMap<(u8, u8, u8), f64> = HashMap::new();
let mut total_intersection: HashMap<(u8, u8, u8), i64> = HashMap::new();
let mut total_union: HashMap<(u8, u8, u8), i64> = HashMap::new();

for (true_color, predictions) in &*confusion_matrix {
let mut intersection = 0;
let mut union = 0;

for (predicted_color, count) in predictions {
intersection += *count;
union += count + *total_union.entry(predicted_color.clone()).or_insert(0);
}

total_intersection.insert(*true_color, intersection);
total_union.insert(*true_color, union);

if union > 0 {
let iou = intersection as f64 / union as f64;
iou_results.insert(*true_color, iou);
}
}

tracing::info!("IoU Results:");
for (color, iou) in &iou_results {
tracing::info!(
"Color RGB({},{},{}) IoU: {:.4}",
color.0,
color.1,
color.2,
iou
);
}
}

0 comments on commit 092361f

Please sign in to comment.