From 70162f907b00fdb299f78f3a95bcd94d1b0f615f Mon Sep 17 00:00:00 2001 From: YISH Date: Sat, 23 Nov 2024 15:20:52 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20Optimize=20parameter=20parsing?= =?UTF-8?q?=20of=20resolve=20command?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/cli.rs | 16 ++- src/resolver.rs | 371 ++++++++++++++++++++++++++++++++++-------------- 2 files changed, 280 insertions(+), 107 deletions(-) diff --git a/src/cli.rs b/src/cli.rs index 5b6dd0a2..c51c86cb 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -56,10 +56,18 @@ impl Cli { let itr = itr.into_iter().collect::>(); match Self::try_parse_from(itr.clone()) { Ok(cli) => cli, - Err(e) => match CompatibleCli::try_parse_from(itr) { - Ok(cli) => cli.into(), - Err(_) => e.exit(), - }, + Err(e) => { + #[cfg(feature = "resolve-cli")] + if let Ok(resolve_command) = ResolveCommand::try_parse_from(itr.clone()) { + return resolve_command.into(); + } + + if let Ok(cli) = CompatibleCli::try_parse_from(itr) { + return cli.into(); + } + + e.exit() + } } } diff --git a/src/resolver.rs b/src/resolver.rs index d894780e..d2c34894 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::ffi::OsString; use std::path::Path; use std::{ops::Deref, str::FromStr, time::Duration}; @@ -8,7 +9,7 @@ use console::{style, StyledObject}; use crate::libdns::proto::{ op::Message, - rr::{DNSClass as QueryClass, Name as Domain, Record, RecordData, RecordType}, + rr::{DNSClass, DNSClass as QueryClass, Name as Domain, Record, RecordData, RecordType}, xfer::Protocol as DnsOverProtocol, }; @@ -27,7 +28,7 @@ impl ResolveCommand { s.set_proto(proto) } } - let domain = self.domain().clone(); + let domains = self.domains(); let query_types = self.q_type(); let palette = Colours::pretty(); @@ -48,18 +49,20 @@ impl ResolveCommand { DnsClient::builder().build().await }; - for query_type in query_types { - let options = LookupOptions { - record_type: *query_type, - ..Default::default() - }; - - match dns_client.lookup(domain.clone(), options).await { - Ok(res) => { - print(&res, &palette); - } - Err(err) => { - println!("{}", err); + for domain in domains { + for query_type in query_types { + let options = LookupOptions { + record_type: *query_type, + ..Default::default() + }; + + match dns_client.lookup(domain.clone(), options).await { + Ok(res) => { + print(&res, &palette); + } + Err(err) => { + println!("{}", err); + } } } } @@ -67,9 +70,37 @@ impl ResolveCommand { } } -#[derive(Parser, Debug)] +#[derive(Parser, Debug, Default, PartialEq, Eq)] #[command(after_help=include_str!("../RESOLVE_EXAMPLES.txt"))] pub struct ResolveCommand { + #[command(flatten)] + proto: ProtocolType, + + #[arg(short = 'J', long)] + json: bool, + + #[arg(short = '1', long)] + short: bool, + + /// is in the Domain Name System + #[arg(value_name = "domain", num_args = 1, value_parser = Variant::parse::)] + domains: Vec, + + /// is one of (a,any,mx,ns,soa,hinfo,axfr,txt,...) + #[arg(value_name = "q-type", num_args = 1, value_parser = Variant::parse::)] + record_types: Vec, + + /// is one of (in,hs,ch,...) + #[arg(value_name = "q-class", value_parser = Variant::parse::)] + q_class: Option, + + /// is the global nameserver + #[arg(value_name = "@global-server", last = true, value_parser = Variant::parse::)] + global_server: Option, +} + +#[derive(Parser, Debug, Default, PartialEq, Eq)] +struct ProtocolType { /// Use the DNS protocol over UDP #[arg(short = 'U', long, group = "proto")] udp: bool, @@ -93,22 +124,6 @@ pub struct ResolveCommand { /// Use the DNS-over-HTTPS/3 protocol #[arg(long, group = "proto")] h3: bool, - - /// is in the Domain Name System - #[arg(value_name = "domain")] - domain: Domain, - - /// is one of (a,any,mx,ns,soa,hinfo,axfr,txt,...) - #[arg(value_name = "q-type", default_value = "a", value_parser = Self::parse_query_type)] - q_type: QueryTypes, - - /// is one of (in,hs,ch,...) - #[arg(value_name = "q-class", default_value = "in", value_parser = Self::parse_query_class)] - q_class: QueryClass, - - /// is the global nameserver - #[arg(value_name = "@global-server")] - global_server: Option, } impl ResolveCommand { @@ -125,15 +140,25 @@ impl ResolveCommand { } pub fn try_parse() -> Result { + Self::try_parse_from(std::env::args()) + } + + pub fn try_parse_from(itr: I) -> Result + where + I: IntoIterator, + T: Into + Clone, + { use DnsOverProtocol::*; let mut proto = None; let mut q_types = vec![]; let mut q_class = None; let mut domain = None; let mut global_server = None; - let mut prev_parsing_qtype = false; - for arg in std::env::args().skip(1) { + for arg in itr.into_iter().skip(1).map(Into::::into) { + let arg = arg + .into_string() + .expect("Failed to convert OsString to String"); if arg == "resolve" { continue; } @@ -172,34 +197,35 @@ impl ResolveCommand { continue; } - if q_types.is_empty() { - if let Ok(t) = Self::parse_query_type(&arg) { - q_types = t.0; - prev_parsing_qtype = true; - continue; - } - } else if prev_parsing_qtype { - if let Ok(t) = Self::parse_query_type(&arg) { - q_types.extend(t.0); - continue; - } - prev_parsing_qtype = false; - } - - if q_class.is_none() { - if let Ok(t) = Self::parse_query_class(arg.as_str()) { - q_class = Some(t); + if arg.contains('+') { + let record_types = arg + .split('+') + .map(|p| p.to_uppercase()) + .flat_map(|s| RecordType::from_str(&s)) + .collect::>(); + if !record_types.is_empty() { + q_types.extend(record_types); continue; } } - if domain.is_none() { - if let Ok(t) = Domain::from_str(arg.as_str()) { - domain = Some(t); - continue; + if let Ok(v) = Variant::from_str(&arg) { + match v { + Variant::Domain(d) => { + domain = Some(d); + } + Variant::RecordType(t) => { + q_types.push(t); + } + Variant::DNSClass(c) => { + q_class = Some(c); + } + Variant::Server(s) => { + global_server = Some(s); + } } + continue; } - return Err(format!("Invalid argument {arg}")); } @@ -210,19 +236,21 @@ impl ResolveCommand { if q_types.is_empty() { q_types.push(RecordType::A); } - let q_class = q_class.unwrap_or(QueryClass::IN); Ok(Self { - udp: matches!(proto, Some(Udp)), - tcp: matches!(proto, Some(Tcp)), - tls: matches!(proto, Some(Tls)), - quic: matches!(proto, Some(Quic)), - https: matches!(proto, Some(Https)), - h3: matches!(proto, Some(H3)), + proto: ProtocolType { + udp: matches!(proto, Some(Udp)), + tcp: matches!(proto, Some(Tcp)), + tls: matches!(proto, Some(Tls)), + quic: matches!(proto, Some(Quic)), + https: matches!(proto, Some(Https)), + h3: matches!(proto, Some(H3)), + }, global_server, - domain, - q_type: QueryTypes(q_types), + domains: vec![domain], + record_types: q_types, q_class, + ..Default::default() }) } @@ -244,17 +272,18 @@ impl ResolveCommand { pub fn proto(&self) -> Option { use DnsOverProtocol::*; - if self.udp { + let proto = &self.proto; + if proto.udp { Some(Udp) - } else if self.tcp { + } else if proto.tcp { Some(Tcp) - } else if self.tls { + } else if proto.tls { Some(Tls) - } else if self.quic { + } else if proto.quic { Some(Quic) - } else if self.https { + } else if proto.https { Some(Https) - } else if self.h3 { + } else if proto.h3 { Some(H3) } else { None @@ -265,61 +294,104 @@ impl ResolveCommand { self.global_server.as_deref() } - pub fn domain(&self) -> &Domain { - &self.domain + pub fn domains(&self) -> &[Domain] { + &self.domains } pub fn q_type(&self) -> &[RecordType] { - &self.q_type.0 + &self.record_types } pub fn q_class(&self) -> QueryClass { - self.q_class + self.q_class.unwrap_or(QueryClass::IN) } +} - fn parse_global_server(s: &str) -> Result { - if let Some(s) = s.strip_prefix('@') { - Ok(s.to_string()) - } else { - Err(format!("Invalid global server: {}", s)) +enum Variant { + Domain(Domain), + RecordType(RecordType), + DNSClass(DNSClass), + Server(String), +} + +impl Variant { + fn parse>(s: &str) -> Result { + Self::from_str(s).and_then(|s| s.try_into()) + } +} + +impl TryFrom for Domain { + type Error = String; + fn try_from(s: Variant) -> Result { + match s { + Variant::Domain(domain) => Ok(domain), + _ => Err("Expected a domain".to_string()), } } - fn parse_query_type(s: &str) -> Result { - if s.contains("+") { - let mut types = Vec::new(); - let mut last_err = None; - for t in s.split('+') { - match RecordType::from_str(t.to_uppercase().as_str()) { - Ok(t) => types.push(t), - Err(err) => last_err = Some(err), - } - } +} - if types.is_empty() { - return Err(last_err - .map(|e| e.to_string()) - .unwrap_or("Invalid query type".to_string())); - } +impl TryFrom for RecordType { + type Error = String; + fn try_from(s: Variant) -> Result { + match s { + Variant::RecordType(record_type) => Ok(record_type), + _ => Err("Expected a record type".to_string()), + } + } +} - Ok(QueryTypes(types)) - } else { - RecordType::from_str(s.to_uppercase().as_str()) - .map(|q| QueryTypes(vec![q])) - .map_err(|e| e.to_string()) +impl TryFrom for DNSClass { + type Error = String; + fn try_from(s: Variant) -> Result { + match s { + Variant::DNSClass(dns_class) => Ok(dns_class), + _ => Err("Expected a DNS class".to_string()), } } - fn parse_query_class(s: &str) -> Result { - QueryClass::from_str(s.to_uppercase().as_str()).map_err(|e| e.to_string()) +} + +impl TryFrom for String { + type Error = String; + fn try_from(s: Variant) -> Result { + match s { + Variant::Server(server) => Ok(server), + _ => Err("Expected a server".to_string()), + } } } -#[derive(Debug, Clone)] -struct QueryTypes(Vec); +impl FromStr for Variant { + type Err = String; + + fn from_str(s: &str) -> Result { + if s.starts_with('@') { + return Ok(Self::Server(s[1..].to_string())); + } + + let upper = s.to_uppercase(); + + if let Ok(record_type) = RecordType::from_str(&upper) { + return Ok(Self::RecordType(record_type)); + } + if let Ok(dns_class) = DNSClass::from_str(&upper) { + return Ok(Self::DNSClass(dns_class)); + } + + if let Ok(name) = Domain::from_str(s) { + return Ok(Self::Domain(name)); + } + + Err(format!("Invalid query variant: {}", s)) + } +} fn print(message: &Message, palette: &Colours) { for r in message.answers() { print_record(&r, palette); } + for r in message.additionals() { + print_record(&r, palette); + } } fn print_record>>(r: &R, palette: &Colours) { @@ -394,3 +466,96 @@ impl Colours { Self::default() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cli_parse() { + assert_eq!( + ResolveCommand::try_parse_from(["dig", "example.com", "a"]).unwrap(), + ResolveCommand { + domains: vec!["example.com".parse().unwrap()], + record_types: vec!["A"] + .iter() + .map(|s| s.parse()) + .collect::, _>>() + .unwrap(), + ..Default::default() + } + ); + + assert_eq!( + ResolveCommand::try_parse_from(["dig", "example.com", "a+aaaa"]).unwrap(), + ResolveCommand { + domains: vec!["example.com".parse().unwrap()], + record_types: vec!["A", "AAAA"] + .iter() + .map(|s| s.parse()) + .collect::, _>>() + .unwrap(), + ..Default::default() + } + ); + + assert_eq!( + ResolveCommand::try_parse_from(["dig", "example.com", "a", "aaaa", "TXT"]).unwrap(), + ResolveCommand { + domains: vec!["example.com".parse().unwrap()], + record_types: vec!["A", "AAAA", "TXT"] + .iter() + .map(|s| s.parse()) + .collect::, _>>() + .unwrap(), + ..Default::default() + } + ); + + assert_eq!( + ResolveCommand::try_parse_from(["dig", "example.com", "a", "aaaa", "in"]).unwrap(), + ResolveCommand { + domains: vec!["example.com".parse().unwrap()], + record_types: vec!["A", "AAAA"] + .iter() + .map(|s| s.parse()) + .collect::, _>>() + .unwrap(), + q_class: Some(DNSClass::IN), + ..Default::default() + } + ); + + assert_eq!( + ResolveCommand::try_parse_from(["dig", "example.com", "a", "aaaa", "in", "@1.1.1.1"]) + .unwrap(), + ResolveCommand { + domains: vec!["example.com".parse().unwrap()], + record_types: vec!["A", "AAAA"] + .iter() + .map(|s| s.parse()) + .collect::, _>>() + .unwrap(), + global_server: Some("1.1.1.1".to_string()), + q_class: Some(DNSClass::IN), + ..Default::default() + } + ); + + assert_eq!( + ResolveCommand::try_parse_from(["dig", "@1.1.1.1", "example.com", "a", "aaaa", "in"]) + .unwrap(), + ResolveCommand { + domains: vec!["example.com".parse().unwrap()], + record_types: vec!["A", "AAAA"] + .iter() + .map(|s| s.parse()) + .collect::, _>>() + .unwrap(), + global_server: Some("1.1.1.1".to_string()), + q_class: Some(DNSClass::IN), + ..Default::default() + } + ); + } +}