Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement same colored output in Rust CLI as Python #490

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions rust/cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ path = "src/main.rs"
anyhow = "1.0.82"
async-channel = "2.2.0"
clap = { version = "4.5.1", features = ["derive"] }
colored = "2.1.0"
magika = { version = "0.1.0-dev", path = "../lib", features = ["tokio"] }
ort = "2.0.0-rc.1"
tokio = { version = "1.37.0", features = ["full"] }
57 changes: 47 additions & 10 deletions rust/cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use std::sync::Arc;

use anyhow::{bail, ensure, Result};
use clap::Parser;
use colored::{Color, Colorize};
use ort::GraphOptimizationLevel;
use tokio::fs::File;

Expand All @@ -28,6 +29,18 @@ struct Flags {
/// List of paths to the files to analyze.
path: Vec<PathBuf>,

/// Forces usage of colors.
///
/// Colors are automatically enabled if the terminal supports them.
#[arg(long)]
colors: bool,

/// Disables usage of colors.
///
/// Colors are automatically disabled if the terminal doesn't support them.
#[arg(long, conflicts_with = "colors")]
no_colors: bool,

/// Format string (use --help for details).
///
/// The following placeholders are supported:
Expand All @@ -41,7 +54,7 @@ struct Flags {
/// %e The file extensions of the content type
/// %s The score of the content type for the file
/// %% A literal %
#[arg(long, default_value = "%D (confidence: %s)", verbatim_doc_comment)]
#[arg(long, default_value = "%D (%g)", verbatim_doc_comment)]
format: String,

/// Number of files to identify in a single inference.
Expand Down Expand Up @@ -76,6 +89,12 @@ async fn main() -> Result<()> {
let flags = Arc::new(Flags::parse());
ensure!(0 < flags.batch_size, "--batch-size cannot be zero");
ensure!(0 < flags.num_tasks, "--num-tasks cannot be zero");
if flags.colors {
colored::control::set_override(true);
}
if flags.no_colors {
colored::control::set_override(false);
}
let (result_sender, mut result_receiver) =
tokio::sync::mpsc::channel::<Result<Response>>(flags.num_tasks * flags.batch_size);
let (batch_sender, batch_receiver) = async_channel::bounded::<Batch>(flags.num_tasks);
Expand Down Expand Up @@ -106,13 +125,17 @@ async fn main() -> Result<()> {
let mut results = vec![None; flags.path.len()];
drop(result_sender);
while let Some(response) = result_receiver.recv().await {
let Response { index, result } = response?;
results[index] = Some(result);
let Response { index, result, color } = response?;
results[index] = Some((result, color));
}
for (path, result) in flags.path.iter().zip(results.into_iter()) {
let path = path.display();
let result = result.unwrap();
println!("{path}: {result}");
let (result, color) = result.unwrap();
let mut output = format!("{path}: {result}").bold();
if let Some(color) = color {
output = output.color(color);
}
println!("{output}");
}
Ok(())
}
Expand Down Expand Up @@ -140,7 +163,7 @@ async fn extract_features(
Err(error) => result = Some(format!("{error}")),
}
if let Some(result) = result {
result_sender.send(Ok(Response { index, result })).await?;
result_sender.send(Ok(Response { index, result, color: None })).await?;
continue;
}
let file = File::open(path).await?;
Expand Down Expand Up @@ -190,14 +213,14 @@ async fn infer_batch(
let batch = magika.identify_many_async(&features).await?;
assert_eq!(batch.len(), indices.len());
for (&index, output) in indices.iter().zip(batch.into_iter()) {
let result = format(flags, output);
sender.send(Ok(Response { index, result })).await?;
let (result, color) = format(flags, output);
sender.send(Ok(Response { index, result, color })).await?;
}
}
Ok(())
}

fn format(flags: &Flags, output: magika::Output) -> String {
fn format(flags: &Flags, output: magika::Output) -> (String, Option<Color>) {
let mut result = String::new();
let mut format = flags.format.chars();
let label = output.label();
Expand All @@ -219,7 +242,20 @@ fn format(flags: &Flags, output: magika::Output) -> String {
None => break,
}
}
result
(result, label.group().iter().find_map(|x| group_color(x)))
}

fn group_color(group: &str) -> Option<Color> {
Some(match group {
"document" => Color::BrightMagenta,
"executable" => Color::BrightGreen,
"archive" => Color::BrightRed,
"audio" => Color::Yellow,
"image" => Color::Yellow,
"video" => Color::Yellow,
"code" => Color::BrightBlue,
_ => return None,
})
}

struct Batch {
Expand All @@ -230,4 +266,5 @@ struct Batch {
struct Response {
index: usize,
result: String,
color: Option<Color>,
}
Loading