From 47471b5b5170c71a47b37b7ba60e8ef5d67c219d Mon Sep 17 00:00:00 2001 From: 4o3F <4o3f@proton.me> Date: Tue, 8 Oct 2024 19:59:52 +0800 Subject: [PATCH] feat: boost class2rgb & add thread pool size arg --- src/common/augment.rs | 13 ++-- src/common/convert.rs | 7 ++- src/common/dataset.rs | 14 ++++- src/common/operation.rs | 6 +- src/common/remap.rs | 134 +++++++++++++++++++++++++--------------- src/main.rs | 27 ++++++-- src/yolo/convert.rs | 6 +- src/yolo/dataset.rs | 7 ++- 8 files changed, 143 insertions(+), 71 deletions(-) diff --git a/src/common/augment.rs b/src/common/augment.rs index a5e1c73..8d68396 100644 --- a/src/common/augment.rs +++ b/src/common/augment.rs @@ -9,6 +9,8 @@ use tokio::{sync::Semaphore, task::JoinSet}; use opencv::{core, imgcodecs, prelude::*}; use tracing_unwrap::{OptionExt, ResultExt}; +use crate::THREAD_POOL; + pub async fn split_images(dataset_path: &String, target_height: &u32, target_width: &u32) { let entries = fs::read_dir(dataset_path).expect_or_log("Failed to read directory"); let mut threads = JoinSet::new(); @@ -426,7 +428,9 @@ pub async fn split_images_with_filter( ); } - let sem = Arc::new(Semaphore::new(5)); + let sem = Arc::new(Semaphore::new( + (*THREAD_POOL.read().expect_or_log("Get pool error")).into(), + )); let valid_id = Arc::new(RwLock::new(Vec::::new())); let mut threads = tokio::task::JoinSet::new(); @@ -762,12 +766,7 @@ pub async fn stich_images(splited_images: &String, target_height: &i32, target_w tracing::trace!("RTL x {} y {}", x, y); let mut roi = Mat::roi_mut( &mut result_mat, - core::Rect::new( - x, - y, - size.unwrap().0, - size.unwrap().1, - ), + core::Rect::new(x, y, size.unwrap().0, size.unwrap().1), ) .expect_or_log("Failed to create roi"); img.copy_to(&mut roi).expect_or_log("Failed to copy image"); diff --git a/src/common/convert.rs b/src/common/convert.rs index d83be74..feb6b44 100644 --- a/src/common/convert.rs +++ b/src/common/convert.rs @@ -10,6 +10,9 @@ use cocotools::{ }; use image::{Rgb, RgbImage}; use tokio::{sync::Semaphore, task::JoinSet}; +use tracing_unwrap::ResultExt; + +use crate::THREAD_POOL; fn mask_image_array(image: &RgbImage, rgb: Rgb) -> ndarray::Array2 { let mut mask = ndarray::Array2::::zeros((image.height() as usize, image.width() as usize)); @@ -55,7 +58,9 @@ pub async fn rgb2rle(dataset_path: &String, rgb_list: &str) { } let mut threads = JoinSet::new(); - let sem = Arc::new(Semaphore::new(10)); + let sem = Arc::new(Semaphore::new( + (*THREAD_POOL.read().expect_or_log("Get pool error")).into(), + )); let image_count = Arc::new(Mutex::new(0)); let annotation_count = Arc::new(Mutex::new(0)); diff --git a/src/common/dataset.rs b/src/common/dataset.rs index c5b62b1..48fe982 100644 --- a/src/common/dataset.rs +++ b/src/common/dataset.rs @@ -11,10 +11,14 @@ use opencv::{ use tokio::{sync::Semaphore, task::JoinSet}; use tracing_unwrap::{OptionExt, ResultExt}; +use crate::THREAD_POOL; + pub async fn split_dataset(dataset_path: &String, train_ratio: &f32) { let entries = fs::read_dir(dataset_path).unwrap(); let mut threads = JoinSet::new(); - let sem = Arc::new(Semaphore::new(10)); + let sem = Arc::new(Semaphore::new( + (*THREAD_POOL.read().expect_or_log("Get pool error")).into(), + )); let result = Arc::new(Mutex::new(Vec::::new())); for entry in entries { let entry = entry.unwrap(); @@ -59,7 +63,9 @@ pub async fn count_classes(dataset_path: &String) { let entries = fs::read_dir(dataset_path).unwrap(); let type_map = Arc::new(Mutex::new(HashMap::::new())); - let sem = Arc::new(Semaphore::new(5)); + let sem = Arc::new(Semaphore::new( + (*THREAD_POOL.read().expect_or_log("Get pool error")).into(), + )); let mut threads = JoinSet::new(); for entry in entries { let entry = entry.unwrap(); @@ -127,7 +133,9 @@ pub async fn calc_mean_std(dataset_path: &String) { let min_value = Arc::new(Mutex::new(f64::MAX)); let max_value = Arc::new(Mutex::new(f64::MIN)); - let sem = Arc::new(Semaphore::new(1)); + let sem = Arc::new(Semaphore::new( + (*THREAD_POOL.read().expect_or_log("Get pool error")).into(), + )); for entry in entries { let entry = entry.expect_or_log("Failed to iterate entries"); if !entry.path().is_file() { diff --git a/src/common/operation.rs b/src/common/operation.rs index d57605f..6bb1706 100644 --- a/src/common/operation.rs +++ b/src/common/operation.rs @@ -7,6 +7,8 @@ use std::{fs, io::Cursor, sync::Arc}; use tokio::{fs::File, io::AsyncWriteExt, sync::Semaphore, task::JoinSet}; use tracing_unwrap::{OptionExt, ResultExt}; +use crate::THREAD_POOL; + pub async fn resize_images( dataset_path: &String, target_height: &u32, @@ -27,7 +29,9 @@ pub async fn resize_images( let entries = fs::read_dir(dataset_path).unwrap(); let mut threads = JoinSet::new(); - let sem = Arc::new(Semaphore::new(10)); + let sem = Arc::new(Semaphore::new( + (*THREAD_POOL.read().expect_or_log("Get pool error")).into(), + )); fs::create_dir_all(format!("{}\\..\\resized\\", dataset_path)).unwrap(); for entry in entries { let entry = entry.unwrap(); diff --git a/src/common/remap.rs b/src/common/remap.rs index c541d52..0527899 100644 --- a/src/common/remap.rs +++ b/src/common/remap.rs @@ -1,6 +1,6 @@ -use image::{GrayImage, Luma, Rgb}; +use image::Rgb; use opencv::{ - core::{Mat, MatTrait, MatTraitConst, Vec3b, Vector, CV_8U}, + core::{Mat, MatTrait, MatTraitConst, Vec3b, Vector, CV_8U, CV_8UC3}, imgcodecs::{self, imread, imwrite}, }; @@ -14,6 +14,8 @@ use std::{ use tokio::{sync::Semaphore, task::JoinSet}; use tracing_unwrap::ResultExt; +use crate::THREAD_POOL; + pub fn remap_color(original_color: &str, new_color: &str, image_path: &String, save_path: &String) { let mut original_color_vec: Vec = vec![]; for splited in original_color.split(',') { @@ -65,7 +67,9 @@ pub async fn remap_color_dir( ) { let entries = fs::read_dir(path).unwrap(); let mut threads = JoinSet::new(); - let sem = Arc::new(Semaphore::new(10)); + let sem = Arc::new(Semaphore::new( + (*THREAD_POOL.read().expect_or_log("Get pool error")).into(), + )); let original_color = Arc::new(original_color.deref().to_string()); let new_color = Arc::new(new_color.deref().to_string()); let path = Arc::new(path.deref().to_string()); @@ -206,7 +210,9 @@ pub async fn class2rgb(dataset_path: &String, rgb_list: &str) { } } let mut threads = JoinSet::new(); - let sem = Arc::new(Semaphore::new(10)); + let sem = Arc::new(Semaphore::new( + (*THREAD_POOL.read().expect_or_log("Get pool error")).into(), + )); for entry in entries { let sem = Arc::clone(&sem); let transform_map = Arc::clone(&transform_map); @@ -252,14 +258,31 @@ pub async fn class2rgb(dataset_path: &String, rgb_list: &str) { }); } while threads.join_next().await.is_some() {} + tracing::info!("All done"); + tracing::info!("Saved to {}\\output\\", dataset_path.to_str().unwrap()); } pub async fn rgb2class(dataset_path: &String, rgb_list: &str) { - let entries = fs::read_dir(dataset_path).unwrap(); + let mut entries: Vec = Vec::new(); + let dataset_path = PathBuf::from(dataset_path.as_str()); + if dataset_path.is_file() { + entries.push(dataset_path.clone()); + fs::create_dir_all(format!( + "{}\\output\\", + dataset_path.parent().unwrap().to_str().unwrap() + )) + .expect_or_log("Failed to create directory"); + } else { + entries = fs::read_dir(dataset_path.clone()) + .unwrap() + .map(|x| x.unwrap().path()) + .filter(|x| x.is_file()) + .collect(); + fs::create_dir_all(format!("{}\\output\\", dataset_path.to_str().unwrap())) + .expect_or_log("Failed to create directory"); + } - fs::create_dir_all(format!("{}\\..\\output\\", dataset_path)).unwrap(); - let transform_map: Arc, Luma>>> = - Arc::new(RwLock::new(HashMap::, Luma>::new())); + let transform_map = Arc::new(RwLock::new(HashMap::<[u8; 3], u8>::new())); { // Split RGB list for (class_id, rgb) in rgb_list.split(";").enumerate() { @@ -269,57 +292,66 @@ pub async fn rgb2class(dataset_path: &String, rgb_list: &str) { rgb_vec.push(splited); } - let rgb = Rgb([rgb_vec[0], rgb_vec[1], rgb_vec[2]]); - let gray = Luma([class_id as u8]); - transform_map.write().unwrap().insert(rgb, gray); + transform_map + .write() + .unwrap() + .insert([rgb_vec[0], rgb_vec[1], rgb_vec[2]], class_id as u8); } } + let mut threads = JoinSet::new(); - let sem = Arc::new(Semaphore::new(10)); + let sem = Arc::new(Semaphore::new( + (*THREAD_POOL.read().expect_or_log("Get pool error")).into(), + )); for entry in entries { - let entry = entry.unwrap(); let sem = Arc::clone(&sem); let transform_map = Arc::clone(&transform_map); - let dataset_path = dataset_path.clone(); + threads.spawn(async move { let _ = sem.acquire().await.unwrap(); - let img = image::open(entry.path()).unwrap(); - let original_img = img.into_rgb8(); - let mut mapped_img = GrayImage::new(original_img.width(), original_img.height()); - for ((original_x, original_y, original_pixel), (mapped_x, mapped_y, mapped_pixel)) in - original_img - .enumerate_pixels() - .zip(mapped_img.enumerate_pixels_mut()) - { - if original_x != mapped_x || original_y != mapped_y { - tracing::error!("Pixel coordinate mismatch"); - return; - } - let Rgb([r, g, b]) = original_pixel; - let transform_map = transform_map.read().unwrap(); - let new_color = match transform_map.get(&Rgb([*r, *g, *b])) { - Some(color) => color, - None => { - tracing::error!( - "Unknown color {},{},{} in {}", - r, - g, - b, - entry.path().as_os_str().to_str().unwrap() - ); - panic!() - } - }; - mapped_pixel.0[0] = new_color.0[0]; + let img = imread(&entry.to_str().unwrap(), imgcodecs::IMREAD_COLOR).unwrap(); + + let mut lut = + Mat::new_rows_cols_with_default(1, 256, CV_8UC3, opencv::core::Scalar::all(0.)) + .expect_or_log("Create LUT error"); + + for i in 0..=255u8 { + let lut_value = lut + .at_2d_mut::(0, i.into()) + .expect_or_log("Get LUT value error"); + + *lut_value = Vec3b::from_array([i, i, i]); } - mapped_img - .save(format!( - "{}\\..\\output\\{}", - dataset_path, - entry.file_name().into_string().unwrap() - )) - .unwrap(); - tracing::info!("{} finished", entry.file_name().into_string().unwrap()); + transform_map.read().unwrap().iter().for_each(|(k, v)| { + let lut_value = lut + .at_2d_mut::(0, *v as i32) + .expect_or_log("Get LUT value error"); + *lut_value = Vec3b::from_array([k[0], k[1], k[2]]); + }); + let mut result = Mat::default(); + opencv::core::lut(&img, &lut, &mut result).unwrap(); + + tracing::trace!( + "Write to {}", + format!( + "{}\\output\\{}", + entry.parent().unwrap().to_str().unwrap(), + entry.file_name().unwrap().to_str().unwrap() + ) + ); + imwrite( + format!( + "{}\\output\\{}", + entry.parent().unwrap().to_str().unwrap(), + entry.file_name().unwrap().to_str().unwrap() + ) + .as_str(), + &result, + &Vector::new(), + ) + .unwrap(); + + tracing::info!("{} finished", entry.file_name().unwrap().to_str().unwrap()); }); } while let Some(result) = threads.join_next().await { @@ -333,5 +365,5 @@ pub async fn rgb2class(dataset_path: &String, rgb_list: &str) { } } tracing::info!("All done"); - tracing::info!("Saved to {}\\..\\output\\", dataset_path); + tracing::info!("Saved to {}\\output\\", dataset_path.to_str().unwrap()); } diff --git a/src/main.rs b/src/main.rs index cb3e2c7..f7d12f7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +use std::sync::{LazyLock, RwLock}; + use clap::{Parser, Subcommand}; use common::operation::EdgePosition; use tracing::Level; @@ -5,7 +7,8 @@ use tracing_unwrap::ResultExt; mod common; mod yolo; -// mod geo; + +static THREAD_POOL: LazyLock> = LazyLock::new(|| RwLock::new(10)); #[derive(Parser)] #[command(version, about, long_about = None)] @@ -14,8 +17,8 @@ struct Cli { #[command(subcommand)] command: Option, - #[arg(long, default_value = "100", help = "Thread pool size")] - thread: usize, + #[arg(long, default_value = "10", help = "Thread pool size")] + thread: u16, } #[derive(Subcommand)] @@ -165,7 +168,11 @@ enum CommonCommands { /// Map 8 bit grayscale PNG class image to RGB image #[command(name = "class2rgb")] Class2RGB { - #[arg(short, long, help = "The path for the folder containing images / The path of the image")] + #[arg( + short, + long, + help = "The path for the folder containing images / The path of the image" + )] dataset_path: String, #[arg(short, long, help = "List of RGB colors, in R0,G0,B0;R1,G1,B1 format")] @@ -175,7 +182,11 @@ enum CommonCommands { /// Map RGB image to 8 bit grayscale PNG class image #[command(name = "rgb2class")] RGB2Class { - #[arg(short, long, help = "The path for the folder containing images")] + #[arg( + short, + long, + help = "The path for the folder containing images / The path of the image" + )] dataset_path: String, #[arg(short, long, help = "List of RGB colors, in R0,G0,B0;R1,G1,B1 format")] @@ -344,10 +355,14 @@ async fn main() { let cli = Cli::parse(); rayon::ThreadPoolBuilder::new() - .num_threads(cli.thread) + .num_threads(cli.thread.into()) .build_global() .unwrap(); + *THREAD_POOL + .write() + .expect_or_log("Get thread pool lock failed") = cli.thread; + match &cli.command { Some(Commands::Common { command }) => match command { CommonCommands::CropRectangle { diff --git a/src/yolo/convert.rs b/src/yolo/convert.rs index 31b8eb6..79ffb73 100644 --- a/src/yolo/convert.rs +++ b/src/yolo/convert.rs @@ -5,6 +5,8 @@ use opencv::{core::MatTrait, imgproc}; use tokio::{fs::File, io::AsyncWriteExt, sync::Semaphore, task::JoinSet}; use tracing_unwrap::ResultExt; +use crate::THREAD_POOL; + pub async fn rgb2yolo(dataset_path: &String, rgb_list: &str) { let mut color_class_map = HashMap::, u32>::new(); // 卫星数据 @@ -52,7 +54,9 @@ pub async fn rgb2yolo(dataset_path: &String, rgb_list: &str) { .expect_or_log("Create output dir error"); let mut threads = JoinSet::new(); - let sem = Arc::new(Semaphore::new(10)); + let sem = Arc::new(Semaphore::new( + (*THREAD_POOL.read().expect_or_log("Get pool error")).into(), + )); // Walk through all images in BASE_PATH let entries = fs::read_dir(dataset_path).unwrap(); diff --git a/src/yolo/dataset.rs b/src/yolo/dataset.rs index cd002f3..392b20f 100644 --- a/src/yolo/dataset.rs +++ b/src/yolo/dataset.rs @@ -6,11 +6,16 @@ use std::{ use itertools::Itertools; use tokio::{sync::Semaphore, task::JoinSet}; +use tracing_unwrap::ResultExt; + +use crate::THREAD_POOL; pub async fn split_dataset(dataset_path: &String, train_ratio: &f32) { let entries = fs::read_dir(dataset_path).unwrap(); let mut threads = JoinSet::new(); - let sem = Arc::new(Semaphore::new(10)); + let sem = Arc::new(Semaphore::new( + (*THREAD_POOL.read().expect_or_log("Get pool error")).into(), + )); let result = Arc::new(Mutex::new(Vec::::new())); for entry in entries { let entry = entry.unwrap();