Skip to content

Commit

Permalink
feat: boost class2rgb & add thread pool size arg
Browse files Browse the repository at this point in the history
  • Loading branch information
4o3F committed Oct 8, 2024
1 parent 270bab8 commit 47471b5
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 71 deletions.
13 changes: 6 additions & 7 deletions src/common/augment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use tokio::{sync::Semaphore, task::JoinSet};
use opencv::{core, imgcodecs, prelude::*};
use tracing_unwrap::{OptionExt, ResultExt};

use crate::THREAD_POOL;

pub async fn split_images(dataset_path: &String, target_height: &u32, target_width: &u32) {
let entries = fs::read_dir(dataset_path).expect_or_log("Failed to read directory");
let mut threads = JoinSet::new();
Expand Down Expand Up @@ -426,7 +428,9 @@ pub async fn split_images_with_filter(
);
}

let sem = Arc::new(Semaphore::new(5));
let sem = Arc::new(Semaphore::new(
(*THREAD_POOL.read().expect_or_log("Get pool error")).into(),
));
let valid_id = Arc::new(RwLock::new(Vec::<String>::new()));
let mut threads = tokio::task::JoinSet::new();

Expand Down Expand Up @@ -762,12 +766,7 @@ pub async fn stich_images(splited_images: &String, target_height: &i32, target_w
tracing::trace!("RTL x {} y {}", x, y);
let mut roi = Mat::roi_mut(
&mut result_mat,
core::Rect::new(
x,
y,
size.unwrap().0,
size.unwrap().1,
),
core::Rect::new(x, y, size.unwrap().0, size.unwrap().1),
)
.expect_or_log("Failed to create roi");
img.copy_to(&mut roi).expect_or_log("Failed to copy image");
Expand Down
7 changes: 6 additions & 1 deletion src/common/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ use cocotools::{
};
use image::{Rgb, RgbImage};
use tokio::{sync::Semaphore, task::JoinSet};
use tracing_unwrap::ResultExt;

use crate::THREAD_POOL;

fn mask_image_array(image: &RgbImage, rgb: Rgb<u8>) -> ndarray::Array2<u8> {
let mut mask = ndarray::Array2::<u8>::zeros((image.height() as usize, image.width() as usize));
Expand Down Expand Up @@ -55,7 +58,9 @@ pub async fn rgb2rle(dataset_path: &String, rgb_list: &str) {
}

let mut threads = JoinSet::new();
let sem = Arc::new(Semaphore::new(10));
let sem = Arc::new(Semaphore::new(
(*THREAD_POOL.read().expect_or_log("Get pool error")).into(),
));

let image_count = Arc::new(Mutex::new(0));
let annotation_count = Arc::new(Mutex::new(0));
Expand Down
14 changes: 11 additions & 3 deletions src/common/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@ use opencv::{
use tokio::{sync::Semaphore, task::JoinSet};
use tracing_unwrap::{OptionExt, ResultExt};

use crate::THREAD_POOL;

pub async fn split_dataset(dataset_path: &String, train_ratio: &f32) {
let entries = fs::read_dir(dataset_path).unwrap();
let mut threads = JoinSet::new();
let sem = Arc::new(Semaphore::new(10));
let sem = Arc::new(Semaphore::new(
(*THREAD_POOL.read().expect_or_log("Get pool error")).into(),
));
let result = Arc::new(Mutex::new(Vec::<String>::new()));
for entry in entries {
let entry = entry.unwrap();
Expand Down Expand Up @@ -59,7 +63,9 @@ pub async fn count_classes(dataset_path: &String) {
let entries = fs::read_dir(dataset_path).unwrap();

let type_map = Arc::new(Mutex::new(HashMap::<u8, i32>::new()));
let sem = Arc::new(Semaphore::new(5));
let sem = Arc::new(Semaphore::new(
(*THREAD_POOL.read().expect_or_log("Get pool error")).into(),
));
let mut threads = JoinSet::new();
for entry in entries {
let entry = entry.unwrap();
Expand Down Expand Up @@ -127,7 +133,9 @@ pub async fn calc_mean_std(dataset_path: &String) {
let min_value = Arc::new(Mutex::new(f64::MAX));
let max_value = Arc::new(Mutex::new(f64::MIN));

let sem = Arc::new(Semaphore::new(1));
let sem = Arc::new(Semaphore::new(
(*THREAD_POOL.read().expect_or_log("Get pool error")).into(),
));
for entry in entries {
let entry = entry.expect_or_log("Failed to iterate entries");
if !entry.path().is_file() {
Expand Down
6 changes: 5 additions & 1 deletion src/common/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use std::{fs, io::Cursor, sync::Arc};
use tokio::{fs::File, io::AsyncWriteExt, sync::Semaphore, task::JoinSet};
use tracing_unwrap::{OptionExt, ResultExt};

use crate::THREAD_POOL;

pub async fn resize_images(
dataset_path: &String,
target_height: &u32,
Expand All @@ -27,7 +29,9 @@ pub async fn resize_images(

let entries = fs::read_dir(dataset_path).unwrap();
let mut threads = JoinSet::new();
let sem = Arc::new(Semaphore::new(10));
let sem = Arc::new(Semaphore::new(
(*THREAD_POOL.read().expect_or_log("Get pool error")).into(),
));
fs::create_dir_all(format!("{}\\..\\resized\\", dataset_path)).unwrap();
for entry in entries {
let entry = entry.unwrap();
Expand Down
134 changes: 83 additions & 51 deletions src/common/remap.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use image::{GrayImage, Luma, Rgb};
use image::Rgb;
use opencv::{
core::{Mat, MatTrait, MatTraitConst, Vec3b, Vector, CV_8U},
core::{Mat, MatTrait, MatTraitConst, Vec3b, Vector, CV_8U, CV_8UC3},
imgcodecs::{self, imread, imwrite},
};

Expand All @@ -14,6 +14,8 @@ use std::{
use tokio::{sync::Semaphore, task::JoinSet};
use tracing_unwrap::ResultExt;

use crate::THREAD_POOL;

pub fn remap_color(original_color: &str, new_color: &str, image_path: &String, save_path: &String) {
let mut original_color_vec: Vec<u8> = vec![];
for splited in original_color.split(',') {
Expand Down Expand Up @@ -65,7 +67,9 @@ pub async fn remap_color_dir(
) {
let entries = fs::read_dir(path).unwrap();
let mut threads = JoinSet::new();
let sem = Arc::new(Semaphore::new(10));
let sem = Arc::new(Semaphore::new(
(*THREAD_POOL.read().expect_or_log("Get pool error")).into(),
));
let original_color = Arc::new(original_color.deref().to_string());
let new_color = Arc::new(new_color.deref().to_string());
let path = Arc::new(path.deref().to_string());
Expand Down Expand Up @@ -206,7 +210,9 @@ pub async fn class2rgb(dataset_path: &String, rgb_list: &str) {
}
}
let mut threads = JoinSet::new();
let sem = Arc::new(Semaphore::new(10));
let sem = Arc::new(Semaphore::new(
(*THREAD_POOL.read().expect_or_log("Get pool error")).into(),
));
for entry in entries {
let sem = Arc::clone(&sem);
let transform_map = Arc::clone(&transform_map);
Expand Down Expand Up @@ -252,14 +258,31 @@ pub async fn class2rgb(dataset_path: &String, rgb_list: &str) {
});
}
while threads.join_next().await.is_some() {}
tracing::info!("All done");
tracing::info!("Saved to {}\\output\\", dataset_path.to_str().unwrap());
}

pub async fn rgb2class(dataset_path: &String, rgb_list: &str) {
let entries = fs::read_dir(dataset_path).unwrap();
let mut entries: Vec<PathBuf> = Vec::new();
let dataset_path = PathBuf::from(dataset_path.as_str());
if dataset_path.is_file() {
entries.push(dataset_path.clone());
fs::create_dir_all(format!(
"{}\\output\\",
dataset_path.parent().unwrap().to_str().unwrap()
))
.expect_or_log("Failed to create directory");
} else {
entries = fs::read_dir(dataset_path.clone())
.unwrap()
.map(|x| x.unwrap().path())
.filter(|x| x.is_file())
.collect();
fs::create_dir_all(format!("{}\\output\\", dataset_path.to_str().unwrap()))
.expect_or_log("Failed to create directory");
}

fs::create_dir_all(format!("{}\\..\\output\\", dataset_path)).unwrap();
let transform_map: Arc<RwLock<HashMap<Rgb<u8>, Luma<u8>>>> =
Arc::new(RwLock::new(HashMap::<Rgb<u8>, Luma<u8>>::new()));
let transform_map = Arc::new(RwLock::new(HashMap::<[u8; 3], u8>::new()));
{
// Split RGB list
for (class_id, rgb) in rgb_list.split(";").enumerate() {
Expand All @@ -269,57 +292,66 @@ pub async fn rgb2class(dataset_path: &String, rgb_list: &str) {
rgb_vec.push(splited);
}

let rgb = Rgb([rgb_vec[0], rgb_vec[1], rgb_vec[2]]);
let gray = Luma([class_id as u8]);
transform_map.write().unwrap().insert(rgb, gray);
transform_map
.write()
.unwrap()
.insert([rgb_vec[0], rgb_vec[1], rgb_vec[2]], class_id as u8);
}
}

let mut threads = JoinSet::new();
let sem = Arc::new(Semaphore::new(10));
let sem = Arc::new(Semaphore::new(
(*THREAD_POOL.read().expect_or_log("Get pool error")).into(),
));
for entry in entries {
let entry = entry.unwrap();
let sem = Arc::clone(&sem);
let transform_map = Arc::clone(&transform_map);
let dataset_path = dataset_path.clone();

threads.spawn(async move {
let _ = sem.acquire().await.unwrap();
let img = image::open(entry.path()).unwrap();
let original_img = img.into_rgb8();
let mut mapped_img = GrayImage::new(original_img.width(), original_img.height());
for ((original_x, original_y, original_pixel), (mapped_x, mapped_y, mapped_pixel)) in
original_img
.enumerate_pixels()
.zip(mapped_img.enumerate_pixels_mut())
{
if original_x != mapped_x || original_y != mapped_y {
tracing::error!("Pixel coordinate mismatch");
return;
}
let Rgb([r, g, b]) = original_pixel;
let transform_map = transform_map.read().unwrap();
let new_color = match transform_map.get(&Rgb([*r, *g, *b])) {
Some(color) => color,
None => {
tracing::error!(
"Unknown color {},{},{} in {}",
r,
g,
b,
entry.path().as_os_str().to_str().unwrap()
);
panic!()
}
};
mapped_pixel.0[0] = new_color.0[0];
let img = imread(&entry.to_str().unwrap(), imgcodecs::IMREAD_COLOR).unwrap();

let mut lut =
Mat::new_rows_cols_with_default(1, 256, CV_8UC3, opencv::core::Scalar::all(0.))
.expect_or_log("Create LUT error");

for i in 0..=255u8 {
let lut_value = lut
.at_2d_mut::<Vec3b>(0, i.into())
.expect_or_log("Get LUT value error");

*lut_value = Vec3b::from_array([i, i, i]);
}
mapped_img
.save(format!(
"{}\\..\\output\\{}",
dataset_path,
entry.file_name().into_string().unwrap()
))
.unwrap();
tracing::info!("{} finished", entry.file_name().into_string().unwrap());
transform_map.read().unwrap().iter().for_each(|(k, v)| {
let lut_value = lut
.at_2d_mut::<Vec3b>(0, *v as i32)
.expect_or_log("Get LUT value error");
*lut_value = Vec3b::from_array([k[0], k[1], k[2]]);
});
let mut result = Mat::default();
opencv::core::lut(&img, &lut, &mut result).unwrap();

tracing::trace!(
"Write to {}",
format!(
"{}\\output\\{}",
entry.parent().unwrap().to_str().unwrap(),
entry.file_name().unwrap().to_str().unwrap()
)
);
imwrite(
format!(
"{}\\output\\{}",
entry.parent().unwrap().to_str().unwrap(),
entry.file_name().unwrap().to_str().unwrap()
)
.as_str(),
&result,
&Vector::new(),
)
.unwrap();

tracing::info!("{} finished", entry.file_name().unwrap().to_str().unwrap());
});
}
while let Some(result) = threads.join_next().await {
Expand All @@ -333,5 +365,5 @@ pub async fn rgb2class(dataset_path: &String, rgb_list: &str) {
}
}
tracing::info!("All done");
tracing::info!("Saved to {}\\..\\output\\", dataset_path);
tracing::info!("Saved to {}\\output\\", dataset_path.to_str().unwrap());
}
27 changes: 21 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::sync::{LazyLock, RwLock};

use clap::{Parser, Subcommand};
use common::operation::EdgePosition;
use tracing::Level;
use tracing_unwrap::ResultExt;

mod common;
mod yolo;
// mod geo;

static THREAD_POOL: LazyLock<RwLock<u16>> = LazyLock::new(|| RwLock::new(10));

#[derive(Parser)]
#[command(version, about, long_about = None)]
Expand All @@ -14,8 +17,8 @@ struct Cli {
#[command(subcommand)]
command: Option<Commands>,

#[arg(long, default_value = "100", help = "Thread pool size")]
thread: usize,
#[arg(long, default_value = "10", help = "Thread pool size")]
thread: u16,
}

#[derive(Subcommand)]
Expand Down Expand Up @@ -165,7 +168,11 @@ enum CommonCommands {
/// Map 8 bit grayscale PNG class image to RGB image
#[command(name = "class2rgb")]
Class2RGB {
#[arg(short, long, help = "The path for the folder containing images / The path of the image")]
#[arg(
short,
long,
help = "The path for the folder containing images / The path of the image"
)]
dataset_path: String,

#[arg(short, long, help = "List of RGB colors, in R0,G0,B0;R1,G1,B1 format")]
Expand All @@ -175,7 +182,11 @@ enum CommonCommands {
/// Map RGB image to 8 bit grayscale PNG class image
#[command(name = "rgb2class")]
RGB2Class {
#[arg(short, long, help = "The path for the folder containing images")]
#[arg(
short,
long,
help = "The path for the folder containing images / The path of the image"
)]
dataset_path: String,

#[arg(short, long, help = "List of RGB colors, in R0,G0,B0;R1,G1,B1 format")]
Expand Down Expand Up @@ -344,10 +355,14 @@ async fn main() {
let cli = Cli::parse();

rayon::ThreadPoolBuilder::new()
.num_threads(cli.thread)
.num_threads(cli.thread.into())
.build_global()
.unwrap();

*THREAD_POOL
.write()
.expect_or_log("Get thread pool lock failed") = cli.thread;

match &cli.command {
Some(Commands::Common { command }) => match command {
CommonCommands::CropRectangle {
Expand Down
6 changes: 5 additions & 1 deletion src/yolo/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use opencv::{core::MatTrait, imgproc};
use tokio::{fs::File, io::AsyncWriteExt, sync::Semaphore, task::JoinSet};
use tracing_unwrap::ResultExt;

use crate::THREAD_POOL;

pub async fn rgb2yolo(dataset_path: &String, rgb_list: &str) {
let mut color_class_map = HashMap::<Rgb<u8>, u32>::new();
// 卫星数据
Expand Down Expand Up @@ -52,7 +54,9 @@ pub async fn rgb2yolo(dataset_path: &String, rgb_list: &str) {
.expect_or_log("Create output dir error");

let mut threads = JoinSet::new();
let sem = Arc::new(Semaphore::new(10));
let sem = Arc::new(Semaphore::new(
(*THREAD_POOL.read().expect_or_log("Get pool error")).into(),
));

// Walk through all images in BASE_PATH
let entries = fs::read_dir(dataset_path).unwrap();
Expand Down
Loading

0 comments on commit 47471b5

Please sign in to comment.