From d3dfffffb4e87249bfdd001795eee64b20d7bac5 Mon Sep 17 00:00:00 2001 From: ia0 Date: Tue, 4 Jun 2024 10:42:54 +0200 Subject: [PATCH] Implement same colored output in Rust CLI as Python In particular, show the group instead of the score by default. --- rust/cli/Cargo.lock | 17 +++++++++++++ rust/cli/Cargo.toml | 1 + rust/cli/src/main.rs | 57 ++++++++++++++++++++++++++++++++++++-------- 3 files changed, 65 insertions(+), 10 deletions(-) diff --git a/rust/cli/Cargo.lock b/rust/cli/Cargo.lock index 3f69c0b4..6fd14992 100644 --- a/rust/cli/Cargo.lock +++ b/rust/cli/Cargo.lock @@ -196,6 +196,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "colored" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbf2150cce219b664a8a70df7a1f933836724b503f8a413af9365b4dcc4d90b8" +dependencies = [ + "lazy_static", + "windows-sys 0.48.0", +] + [[package]] name = "concurrent-queue" version = "2.4.0" @@ -382,6 +392,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.153" @@ -427,6 +443,7 @@ dependencies = [ "anyhow", "async-channel", "clap", + "colored", "magika", "ort", "tokio", diff --git a/rust/cli/Cargo.toml b/rust/cli/Cargo.toml index 766930fe..a9b2293d 100644 --- a/rust/cli/Cargo.toml +++ b/rust/cli/Cargo.toml @@ -18,6 +18,7 @@ path = "src/main.rs" anyhow = "1.0.82" async-channel = "2.2.0" clap = { version = "4.5.1", features = ["derive"] } +colored = "2.1.0" magika = { version = "0.1.0-dev", path = "../lib", features = ["tokio"] } ort = "2.0.0-rc.1" tokio = { version = "1.37.0", features = ["full"] } diff --git a/rust/cli/src/main.rs b/rust/cli/src/main.rs index 
7cd4371a..fa9b022f 100644 --- a/rust/cli/src/main.rs +++ b/rust/cli/src/main.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use anyhow::{bail, ensure, Result}; use clap::Parser; +use colored::{Color, Colorize}; use ort::GraphOptimizationLevel; use tokio::fs::File; @@ -28,6 +29,18 @@ struct Flags { /// List of paths to the files to analyze. path: Vec<PathBuf>, + /// Forces usage of colors. + /// + /// Colors are automatically enabled if the terminal supports them. + #[arg(long)] + colors: bool, + + /// Disables usage of colors. + /// + /// Colors are automatically disabled if the terminal doesn't support them. + #[arg(long, conflicts_with = "colors")] + no_colors: bool, + /// Format string (use --help for details). /// /// The following placeholders are supported: @@ -41,7 +54,7 @@ struct Flags { /// %e The file extensions of the content type /// %s The score of the content type for the file /// %% A literal % - #[arg(long, default_value = "%D (confidence: %s)", verbatim_doc_comment)] + #[arg(long, default_value = "%D (%g)", verbatim_doc_comment)] format: String, /// Number of files to identify in a single inference. 
@@ -76,6 +89,12 @@ async fn main() -> Result<()> { let flags = Arc::new(Flags::parse()); ensure!(0 < flags.batch_size, "--batch-size cannot be zero"); ensure!(0 < flags.num_tasks, "--num-tasks cannot be zero"); + if flags.colors { + colored::control::set_override(true); + } + if flags.no_colors { + colored::control::set_override(false); + } let (result_sender, mut result_receiver) = tokio::sync::mpsc::channel::<Result<Response>>(flags.num_tasks * flags.batch_size); let (batch_sender, batch_receiver) = async_channel::bounded::<Batch>(flags.num_tasks); @@ -106,13 +125,17 @@ async fn main() -> Result<()> { let mut results = vec![None; flags.path.len()]; drop(result_sender); while let Some(response) = result_receiver.recv().await { - let Response { index, result } = response?; - results[index] = Some(result); + let Response { index, result, color } = response?; + results[index] = Some((result, color)); } for (path, result) in flags.path.iter().zip(results.into_iter()) { let path = path.display(); - let result = result.unwrap(); - println!("{path}: {result}"); + let (result, color) = result.unwrap(); + let mut output = format!("{path}: {result}").bold(); + if let Some(color) = color { + output = output.color(color); + } + println!("{output}"); } Ok(()) } @@ -140,7 +163,7 @@ async fn extract_features( Err(error) => result = Some(format!("{error}")), } if let Some(result) = result { - result_sender.send(Ok(Response { index, result })).await?; + result_sender.send(Ok(Response { index, result, color: None })).await?; continue; } let file = File::open(path).await?; @@ -190,14 +213,14 @@ async fn infer_batch( let batch = magika.identify_many_async(&features).await?; assert_eq!(batch.len(), indices.len()); for (&index, output) in indices.iter().zip(batch.into_iter()) { - let result = format(flags, output); - sender.send(Ok(Response { index, result })).await?; + let (result, color) = format(flags, output); + sender.send(Ok(Response { index, result, color })).await?; } } Ok(()) } -fn format(flags: 
&Flags, output: magika::Output) -> String { +fn format(flags: &Flags, output: magika::Output) -> (String, Option<Color>) { let mut result = String::new(); let mut format = flags.format.chars(); let label = output.label(); @@ -219,7 +242,20 @@ fn format(flags: &Flags, output: magika::Output) -> String { None => break, } } - result + (result, label.group().iter().find_map(|x| group_color(x))) +} + +fn group_color(group: &str) -> Option<Color> { + Some(match group { + "document" => Color::BrightMagenta, + "executable" => Color::BrightGreen, + "archive" => Color::BrightRed, + "audio" => Color::Yellow, + "image" => Color::Yellow, + "video" => Color::Yellow, + "code" => Color::BrightBlue, + _ => return None, + }) } struct Batch { @@ -230,4 +266,5 @@ struct Batch { struct Response { index: usize, result: String, + color: Option<Color>, }