diff --git a/Cargo.toml b/Cargo.toml index 579f62f..fa98e56 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ name = "libmdns" version = "0.2.4" authors = ["Will Stott "] +edition = "2018" description = "mDNS Responder library for building discoverable LAN services in Rust" repository = "https://github.com/librespot-org/libmdns" @@ -9,16 +10,16 @@ readme = "README.md" license = "MIT" [dependencies] -byteorder = "1.2" -futures = "0.1" +byteorder = "1.3" get_if_addrs = "0.5" -hostname = "0.2" +hostname = "0.3" log = "0.4" -multimap = "0.4" +multimap = "0.8" net2 = "0.2" -rand = "0.5" -tokio-core = "0.1" +rand = "0.7" +futures-util = "0.3" +tokio = { version = "0.2.16", features = ["sync","udp","stream","rt-core"] } quick-error = "1.2" [dev-dependencies] -env_logger = "0.5" +env_logger = { version = "0.7", default-features = false, features = ["termcolor","humantime","atty"] } diff --git a/examples/register.rs b/examples/register.rs index 851b99b..7800153 100644 --- a/examples/register.rs +++ b/examples/register.rs @@ -2,7 +2,9 @@ extern crate env_logger; extern crate libmdns; pub fn main() { - env_logger::init(); + let mut builder = env_logger::Builder::new(); + builder.parse_filters("libmdns=debug"); + builder.init(); let responder = libmdns::Responder::new().unwrap(); let _svc = responder.register( diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..32a9786 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +edition = "2018" diff --git a/src/dns_parser/builder.rs b/src/dns_parser/builder.rs index da81356..2b192fc 100644 --- a/src/dns_parser/builder.rs +++ b/src/dns_parser/builder.rs @@ -1,16 +1,17 @@ use std::marker::PhantomData; -use byteorder::{ByteOrder, BigEndian, WriteBytesExt}; +use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; -use super::{Opcode, ResponseCode, Header, Name, RRData, QueryType, QueryClass}; +use super::{Header, Name, Opcode, QueryClass, QueryType, RRData, ResponseCode}; pub enum Questions {} pub enum Answers {} -#[allow(dead_code)] pub enum Nameservers {} +#[allow(dead_code)] +pub enum Nameservers {} pub enum Additional {} -pub trait MoveTo { } -impl MoveTo for T {} +pub trait MoveTo {} +impl MoveTo for T {} impl MoveTo for Questions {} @@ -55,7 +56,11 @@ impl Builder { }; buf.extend([0u8; 12].iter()); head.write(&mut buf[..12]); - Builder { buf: buf, max_size: Some(512), _state: PhantomData } + Builder { + buf: buf, + max_size: Some(512), + _state: PhantomData, + } } pub fn new_response(id: u16, recursion: bool, authoritative: bool) -> Builder { @@ -76,14 +81,16 @@ impl Builder { }; buf.extend([0u8; 12].iter()); head.write(&mut buf[..12]); - Builder { buf: buf, max_size: Some(512), _state: PhantomData } + Builder { + buf: buf, + max_size: Some(512), + _state: PhantomData, + } } } -impl Builder { - fn write_rr(&mut self, name: &Name, - cls: QueryClass, ttl: u32, data: &RRData) { - +impl Builder { + fn write_rr(&mut self, name: &Name, cls: QueryClass, ttl: u32, data: &RRData) { name.write_to(&mut self.buf).unwrap(); self.buf.write_u16::(data.typ() as u16).unwrap(); self.buf.write_u16::(cls as u16).unwrap(); @@ -96,7 +103,10 @@ impl Builder { data.write_to(&mut self.buf).unwrap(); let data_size = self.buf.len() - data_offset; - BigEndian::write_u16(&mut self.buf[size_offset..size_offset+2], data_size as u16); + BigEndian::write_u16( + &mut self.buf[size_offset..size_offset + 2], + data_size as u16, + ); } /// Returns the final packet @@ -113,7 +123,7 @@ impl Builder { /// appropriate. // TODO(tailhook) does the truncation make sense for TCP, and how // to treat it for EDNS0? - pub fn build(mut self) -> Result,Vec> { + pub fn build(mut self) -> Result, Vec> { // TODO(tailhook) optimize labels match self.max_size { Some(max_size) if self.buf.len() > max_size => { @@ -124,8 +134,15 @@ impl Builder { } } - pub fn move_to(self) -> Builder where T: MoveTo { - Builder { buf: self.buf, max_size: self.max_size, _state: PhantomData } + pub fn move_to(self) -> Builder + where + T: MoveTo, + { + Builder { + buf: self.buf, + max_size: self.max_size, + _state: PhantomData, + } } pub fn set_max_size(&mut self, max_size: Option) { @@ -133,61 +150,66 @@ impl Builder { } pub fn is_empty(&self) -> bool { - Header::question_count(&self.buf) == 0 && - Header::answer_count(&self.buf) == 0 && - Header::nameserver_count(&self.buf) == 0 && - Header::additional_count(&self.buf) == 0 + Header::question_count(&self.buf) == 0 + && Header::answer_count(&self.buf) == 0 + && Header::nameserver_count(&self.buf) == 0 + && Header::additional_count(&self.buf) == 0 } } -impl > Builder { +impl> Builder { /// Adds a question to the packet /// /// # Panics /// /// * There are already 65535 questions in the buffer. #[allow(dead_code)] - pub fn add_question(self, qname: &Name, - qtype: QueryType, qclass: QueryClass) - -> Builder - { + pub fn add_question( + self, + qname: &Name, + qtype: QueryType, + qclass: QueryClass, + ) -> Builder { let mut builder = self.move_to::(); qname.write_to(&mut builder.buf).unwrap(); builder.buf.write_u16::(qtype as u16).unwrap(); builder.buf.write_u16::(qclass as u16).unwrap(); - Header::inc_questions(&mut builder.buf) - .expect("Too many questions"); + Header::inc_questions(&mut builder.buf).expect("Too many questions"); builder } } -impl > Builder { - pub fn add_answer(self, name: &Name, - cls: QueryClass, ttl: u32, data: &RRData) - -> Builder - { +impl> Builder { + pub fn add_answer( + self, + name: &Name, + cls: QueryClass, + ttl: u32, + data: &RRData, + ) -> Builder { let mut builder = self.move_to::(); builder.write_rr(name, cls, ttl, data); - Header::inc_answers(&mut builder.buf) - .expect("Too many answers"); + Header::inc_answers(&mut builder.buf).expect("Too many answers"); builder } } -impl > Builder { +impl> Builder { #[allow(dead_code)] - pub fn add_nameserver(self, name: &Name, - cls: QueryClass, ttl: u32, data: &RRData) - -> Builder - { + pub fn add_nameserver( + self, + name: &Name, + cls: QueryClass, + ttl: u32, + data: &RRData, + ) -> Builder { let mut builder = self.move_to::(); builder.write_rr(name, cls, ttl, data); - Header::inc_nameservers(&mut builder.buf) - .expect("Too many nameservers"); + Header::inc_nameservers(&mut builder.buf).expect("Too many nameservers"); builder } @@ -195,15 +217,17 @@ impl > Builder { impl Builder { #[allow(dead_code)] - pub fn add_additional(self, name: &Name, - cls: QueryClass, ttl: u32, data: &RRData) - -> Builder - { + pub fn add_additional( + self, + name: &Name, + cls: QueryClass, + ttl: u32, + data: &RRData, + ) -> Builder { let mut builder = self.move_to::(); builder.write_rr(name, cls, ttl, data); - Header::inc_nameservers(&mut builder.buf) - .expect("Too many additional answers"); + Header::inc_nameservers(&mut builder.buf).expect("Too many additional answers"); builder } @@ -211,10 +235,10 @@ impl Builder { #[cfg(test)] mod test { - use super::QueryType as QT; - use super::QueryClass as QC; - use super::Name; use super::Builder; + use super::Name; + use super::QueryClass as QC; + use super::QueryType as QT; #[test] fn build_query() { diff --git a/src/dns_parser/enums.rs b/src/dns_parser/enums.rs index 55f63b3..e2c7b1c 100644 --- a/src/dns_parser/enums.rs +++ b/src/dns_parser/enums.rs @@ -92,7 +92,6 @@ pub enum QueryType { All = 255, } - /// The CLASS value according to RFC 1035 #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] pub enum Class { @@ -171,13 +170,13 @@ impl From for ResponseCode { fn from(code: u8) -> ResponseCode { use self::ResponseCode::*; match code { - 0 => NoError, - 1 => FormatError, - 2 => ServerFailure, - 3 => NameError, - 4 => NotImplemented, - 5 => Refused, - 6...15 => Reserved(code), + 0 => NoError, + 1 => FormatError, + 2 => ServerFailure, + 3 => NameError, + 4 => NotImplemented, + 5 => Refused, + 6..=15 => Reserved(code), x => panic!("Invalid response code {}", x), } } @@ -201,23 +200,23 @@ impl QueryType { pub fn parse(code: u16) -> Result { use self::QueryType::*; match code { - 1 => Ok(A), - 2 => Ok(NS), - 4 => Ok(MF), - 5 => Ok(CNAME), - 6 => Ok(SOA), - 7 => Ok(MB), - 8 => Ok(MG), - 9 => Ok(MR), - 10 => Ok(NULL), - 11 => Ok(WKS), - 12 => Ok(PTR), - 13 => Ok(HINFO), - 14 => Ok(MINFO), - 15 => Ok(MX), - 16 => Ok(TXT), - 28 => Ok(AAAA), - 33 => Ok(SRV), + 1 => Ok(A), + 2 => Ok(NS), + 4 => Ok(MF), + 5 => Ok(CNAME), + 6 => Ok(SOA), + 7 => Ok(MB), + 8 => Ok(MG), + 9 => Ok(MR), + 10 => Ok(NULL), + 11 => Ok(WKS), + 12 => Ok(PTR), + 13 => Ok(HINFO), + 14 => Ok(MINFO), + 15 => Ok(MX), + 16 => Ok(TXT), + 28 => Ok(AAAA), + 33 => Ok(SRV), 252 => Ok(AXFR), 253 => Ok(MAILB), 254 => Ok(MAILA), @@ -231,10 +230,10 @@ impl QueryClass { pub fn parse(code: u16) -> Result { use self::QueryClass::*; match code { - 1 => Ok(IN), - 2 => Ok(CS), - 3 => Ok(CH), - 4 => Ok(HS), + 1 => Ok(IN), + 2 => Ok(CS), + 3 => Ok(CH), + 4 => Ok(HS), 255 => Ok(Any), x => Err(Error::InvalidQueryClass(x)), } @@ -245,24 +244,24 @@ impl Type { pub fn parse(code: u16) -> Result { use self::Type::*; match code { - 1 => Ok(A), - 2 => Ok(NS), - 4 => Ok(MF), - 5 => Ok(CNAME), - 6 => Ok(SOA), - 7 => Ok(MB), - 8 => Ok(MG), - 9 => Ok(MR), - 10 => Ok(NULL), - 11 => Ok(WKS), - 12 => Ok(PTR), - 13 => Ok(HINFO), - 14 => Ok(MINFO), - 15 => Ok(MX), - 16 => Ok(TXT), - 28 => Ok(AAAA), - 33 => Ok(SRV), - 41 => Ok(OPT), + 1 => Ok(A), + 2 => Ok(NS), + 4 => Ok(MF), + 5 => Ok(CNAME), + 6 => Ok(SOA), + 7 => Ok(MB), + 8 => Ok(MG), + 9 => Ok(MR), + 10 => Ok(NULL), + 11 => Ok(WKS), + 12 => Ok(PTR), + 13 => Ok(HINFO), + 14 => Ok(MINFO), + 15 => Ok(MX), + 16 => Ok(TXT), + 28 => Ok(AAAA), + 33 => Ok(SRV), + 41 => Ok(OPT), x => Err(Error::InvalidType(x)), } } @@ -272,10 +271,10 @@ impl Class { pub fn parse(code: u16) -> Result { use self::Class::*; match code { - 1 => Ok(IN), - 2 => Ok(CS), - 3 => Ok(CH), - 4 => Ok(HS), + 1 => Ok(IN), + 2 => Ok(CS), + 3 => Ok(CH), + 4 => Ok(HS), x => Err(Error::InvalidClass(x)), } } diff --git a/src/dns_parser/header.rs b/src/dns_parser/header.rs index ddc908f..263d002 100644 --- a/src/dns_parser/header.rs +++ b/src/dns_parser/header.rs @@ -1,16 +1,16 @@ use byteorder::{BigEndian, ByteOrder}; -use super::{Error, ResponseCode, Opcode}; +use super::{Error, Opcode, ResponseCode}; mod flag { - pub const QUERY: u16 = 0b1000_0000_0000_0000; - pub const OPCODE_MASK: u16 = 0b0111_1000_0000_0000; - pub const AUTHORITATIVE: u16 = 0b0000_0100_0000_0000; - pub const TRUNCATED: u16 = 0b0000_0010_0000_0000; - pub const RECURSION_DESIRED: u16 = 0b0000_0001_0000_0000; + pub const QUERY: u16 = 0b1000_0000_0000_0000; + pub const OPCODE_MASK: u16 = 0b0111_1000_0000_0000; + pub const AUTHORITATIVE: u16 = 0b0000_0100_0000_0000; + pub const TRUNCATED: u16 = 0b0000_0010_0000_0000; + pub const RECURSION_DESIRED: u16 = 0b0000_0001_0000_0000; pub const RECURSION_AVAILABLE: u16 = 0b0000_0000_1000_0000; - pub const RESERVED_MASK: u16 = 0b0000_0000_0111_0000; - pub const RESPONSE_CODE_MASK: u16 = 0b0000_0000_0000_1111; + pub const RESERVED_MASK: u16 = 0b0000_0000_0111_0000; + pub const RESPONSE_CODE_MASK: u16 = 0b0000_0000_0000_1111; } /// Represents parsed header of the packet @@ -42,13 +42,12 @@ impl Header { let header = Header { id: BigEndian::read_u16(&data[..2]), query: flags & flag::QUERY == 0, - opcode: (flags & flag::OPCODE_MASK - >> flag::OPCODE_MASK.trailing_zeros()).into(), + opcode: (flags & flag::OPCODE_MASK >> flag::OPCODE_MASK.trailing_zeros()).into(), authoritative: flags & flag::AUTHORITATIVE != 0, truncated: flags & flag::TRUNCATED != 0, recursion_desired: flags & flag::RECURSION_DESIRED != 0, recursion_available: flags & flag::RECURSION_AVAILABLE != 0, - response_code: From::from((flags&flag::RESPONSE_CODE_MASK) as u8), + response_code: From::from((flags & flag::RESPONSE_CODE_MASK) as u8), questions: BigEndian::read_u16(&data[4..6]), answers: BigEndian::read_u16(&data[6..8]), nameservers: BigEndian::read_u16(&data[8..10]), @@ -66,14 +65,23 @@ impl Header { panic!("Header size is exactly 12 bytes"); } let mut flags = 0u16; - flags |= Into::::into(self.opcode) - << flag::OPCODE_MASK.trailing_zeros(); + flags |= Into::::into(self.opcode) << flag::OPCODE_MASK.trailing_zeros(); flags |= Into::::into(self.response_code) as u16; - if !self.query { flags |= flag::QUERY; } - if self.authoritative { flags |= flag::AUTHORITATIVE; } - if self.recursion_desired { flags |= flag::RECURSION_DESIRED; } - if self.recursion_available { flags |= flag::RECURSION_AVAILABLE; } - if self.truncated { flags |= flag::TRUNCATED; } + if !self.query { + flags |= flag::QUERY; + } + if self.authoritative { + flags |= flag::AUTHORITATIVE; + } + if self.recursion_desired { + flags |= flag::RECURSION_DESIRED; + } + if self.recursion_available { + flags |= flag::RECURSION_AVAILABLE; + } + if self.truncated { + flags |= flag::TRUNCATED; + } BigEndian::write_u16(&mut data[..2], self.id); BigEndian::write_u16(&mut data[2..4], flags); BigEndian::write_u16(&mut data[4..6], self.questions); @@ -105,7 +113,7 @@ impl Header { pub fn inc_questions(data: &mut [u8]) -> Option { let oldq = BigEndian::read_u16(&data[4..6]); if oldq < 65535 { - BigEndian::write_u16(&mut data[4..6], oldq+1); + BigEndian::write_u16(&mut data[4..6], oldq + 1); Some(oldq + 1) } else { None @@ -115,7 +123,7 @@ impl Header { pub fn inc_answers(data: &mut [u8]) -> Option { let oldq = BigEndian::read_u16(&data[6..8]); if oldq < 65535 { - BigEndian::write_u16(&mut data[6..8], oldq+1); + BigEndian::write_u16(&mut data[6..8], oldq + 1); Some(oldq + 1) } else { None @@ -125,7 +133,7 @@ impl Header { pub fn inc_nameservers(data: &mut [u8]) -> Option { let oldq = BigEndian::read_u16(&data[8..10]); if oldq < 65535 { - BigEndian::write_u16(&mut data[8..10], oldq+1); + BigEndian::write_u16(&mut data[8..10], oldq + 1); Some(oldq + 1) } else { None @@ -136,16 +144,17 @@ impl Header { pub fn inc_additional(data: &mut [u8]) -> Option { let oldq = BigEndian::read_u16(&data[10..12]); if oldq < 65535 { - BigEndian::write_u16(&mut data[10..12], oldq+1); + BigEndian::write_u16(&mut data[10..12], oldq + 1); Some(oldq + 1) } else { None } } - pub fn size() -> usize { 12 } + pub fn size() -> usize { + 12 + } } - #[cfg(test)] mod test { @@ -158,20 +167,23 @@ mod test { let query = b"\x06%\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\ \x07example\x03com\x00\x00\x01\x00\x01"; let header = Header::parse(query).unwrap(); - assert_eq!(header, Header { - id: 1573, - query: true, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: false, - response_code: NoError, - questions: 1, - answers: 0, - nameservers: 0, - additional: 0, - }); + assert_eq!( + header, + Header { + id: 1573, + query: true, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: false, + response_code: NoError, + questions: 1, + answers: 0, + nameservers: 0, + additional: 0, + } + ); } #[test] @@ -181,19 +193,22 @@ mod test { \xc0\x0c\x00\x01\x00\x01\x00\x00\x04\xf8\ \x00\x04]\xb8\xd8\""; let header = Header::parse(response).unwrap(); - assert_eq!(header, Header { - id: 1573, - query: false, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: true, - response_code: NoError, - questions: 1, - answers: 1, - nameservers: 0, - additional: 0, - }); + assert_eq!( + header, + Header { + id: 1573, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 1, + nameservers: 0, + additional: 0, + } + ); } } diff --git a/src/dns_parser/mod.rs b/src/dns_parser/mod.rs index a05cb31..23234a3 100644 --- a/src/dns_parser/mod.rs +++ b/src/dns_parser/mod.rs @@ -1,15 +1,15 @@ mod error; -pub use self::error::{Error}; +pub use self::error::Error; mod enums; -pub use self::enums::{Type, QueryType, Class, QueryClass, ResponseCode, Opcode}; +pub use self::enums::{Class, Opcode, QueryClass, QueryType, ResponseCode, Type}; mod structs; -pub use self::structs::{Question, ResourceRecord, Packet}; +pub use self::structs::{Packet, Question, ResourceRecord}; mod name; -pub use self::name::{Name}; -mod parser; +pub use self::name::Name; mod header; -pub use self::header::{Header}; +mod parser; +pub use self::header::Header; mod rrdata; -pub use self::rrdata::{RRData}; +pub use self::rrdata::RRData; mod builder; -pub use self::builder::{Builder, Questions, Answers}; +pub use self::builder::{Answers, Builder, Questions}; diff --git a/src/dns_parser/name.rs b/src/dns_parser/name.rs index eec408e..40a4157 100644 --- a/src/dns_parser/name.rs +++ b/src/dns_parser/name.rs @@ -1,13 +1,13 @@ -use std::io; +use std::borrow::Cow; use std::fmt; use std::fmt::Write; -use std::str::from_utf8; -use std::borrow::Cow; use std::hash; +use std::io; +use std::str::from_utf8; use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; -use super::{Error}; +use super::Error; /// The DNS name as stored in the original packet /// @@ -26,7 +26,7 @@ pub enum Name<'a> { } impl<'a> Name<'a> { - pub fn scan(data: &'a[u8], original: &'a[u8]) -> Result<(Name<'a>, usize), Error> { + pub fn scan(data: &'a [u8], original: &'a [u8]) -> Result<(Name<'a>, usize), Error> { let mut pos = 0; loop { if data.len() <= pos { @@ -34,22 +34,34 @@ impl<'a> Name<'a> { } let byte = data[pos]; if byte == 0 { - return Ok((Name::FromPacket { labels: &data[..pos+1], original: original }, pos + 1)); + return Ok(( + Name::FromPacket { + labels: &data[..pos + 1], + original: original, + }, + pos + 1, + )); } else if byte & 0b1100_0000 == 0b1100_0000 { - if data.len() < pos+2 { + if data.len() < pos + 2 { return Err(Error::UnexpectedEOF); } - let off = (BigEndian::read_u16(&data[pos..pos+2]) - & !0b1100_0000_0000_0000) as usize; + let off = + (BigEndian::read_u16(&data[pos..pos + 2]) & !0b1100_0000_0000_0000) as usize; if off >= original.len() { return Err(Error::UnexpectedEOF); } // Validate referred to location - try!(Name::scan(&original[off..], original)); - return Ok((Name::FromPacket { labels: &data[..pos+2], original: original }, pos + 2)); + Name::scan(&original[off..], original)?; + return Ok(( + Name::FromPacket { + labels: &data[..pos + 2], + original: original, + }, + pos + 2, + )); } else if byte & 0b1100_0000 == 0 { let end = pos + byte as usize + 1; - if from_utf8(&data[pos+1..end]).is_err() { + if from_utf8(&data[pos + 1..end]).is_err() { return Err(Error::LabelIsNotAscii); } pos = end; @@ -74,15 +86,18 @@ impl<'a> Name<'a> { loop { let byte = labels[pos]; if byte == 0 { - try!(writer.write_u8(0)); + writer.write_u8(0)?; return Ok(()); } else if byte & 0b1100_0000 == 0b1100_0000 { - let off = (BigEndian::read_u16(&labels[pos..pos+2]) - & !0b1100_0000_0000_0000) as usize; - return Name::scan(&original[off..], original).unwrap().0.write_to(writer) + let off = (BigEndian::read_u16(&labels[pos..pos + 2]) + & !0b1100_0000_0000_0000) as usize; + return Name::scan(&original[off..], original) + .unwrap() + .0 + .write_to(writer); } else if byte & 0b1100_0000 == 0 { let end = pos + byte as usize + 1; - try!(writer.write(&labels[pos..end])); + writer.write(&labels[pos..end])?; pos = end; continue; } else { @@ -95,10 +110,10 @@ impl<'a> Name<'a> { for part in name.split('.') { assert!(part.len() < 63); let ln = part.len() as u8; - try!(writer.write_u8(ln)); - try!(writer.write(part.as_bytes())); + writer.write_u8(ln)?; + writer.write(part.as_bytes())?; } - try!(writer.write_u8(0)); + writer.write_u8(0)?; Ok(()) } @@ -116,19 +131,21 @@ impl<'a> fmt::Display for Name<'a> { if byte == 0 { return Ok(()); } else if byte & 0b1100_0000 == 0b1100_0000 { - let off = (BigEndian::read_u16(&labels[pos..pos+2]) - & !0b1100_0000_0000_0000) as usize; + let off = (BigEndian::read_u16(&labels[pos..pos + 2]) + & !0b1100_0000_0000_0000) as usize; if pos != 0 { - try!(fmt.write_char('.')); + fmt.write_char('.')?; } return fmt::Display::fmt( - &Name::scan(&original[off..], original).unwrap().0, fmt) + &Name::scan(&original[off..], original).unwrap().0, + fmt, + ); } else if byte & 0b1100_0000 == 0 { if pos != 0 { - try!(fmt.write_char('.')); + fmt.write_char('.')?; } let end = pos + byte as usize + 1; - try!(fmt.write_str(from_utf8(&labels[pos+1..end]).unwrap())); + fmt.write_str(from_utf8(&labels[pos + 1..end]).unwrap())?; pos = end; continue; } else { @@ -137,20 +154,23 @@ impl<'a> fmt::Display for Name<'a> { } } - Name::FromStr(ref name) => fmt.write_str(&name) + Name::FromStr(ref name) => fmt.write_str(&name), } } } -impl <'a> hash::Hash for Name<'a> { - fn hash(&self, state: &mut H) where H: hash::Hasher { +impl<'a> hash::Hash for Name<'a> { + fn hash(&self, state: &mut H) + where + H: hash::Hasher, + { let mut buffer = Vec::new(); self.write_to(&mut buffer).unwrap(); hash::Hash::hash(&buffer, state) } } -impl <'a> PartialEq for Name<'a> { +impl<'a> PartialEq for Name<'a> { fn eq(&self, other: &Name) -> bool { let mut buffer = Vec::new(); self.write_to(&mut buffer).unwrap(); @@ -162,4 +182,4 @@ impl <'a> PartialEq for Name<'a> { } } -impl <'a> Eq for Name<'a> {} +impl<'a> Eq for Name<'a> {} diff --git a/src/dns_parser/parser.rs b/src/dns_parser/parser.rs index d264f14..9705aec 100644 --- a/src/dns_parser/parser.rs +++ b/src/dns_parser/parser.rs @@ -2,26 +2,24 @@ use std::i32; use byteorder::{BigEndian, ByteOrder}; -use super::{Header, Packet, Error, Question, Name, QueryType, QueryClass}; -use super::{Type, Class, ResourceRecord, RRData}; - +use super::{Class, RRData, ResourceRecord, Type}; +use super::{Error, Header, Name, Packet, QueryClass, QueryType, Question}; impl<'a> Packet<'a> { pub fn parse(data: &[u8]) -> Result { - let header = try!(Header::parse(data)); + let header = Header::parse(data)?; let mut offset = Header::size(); let mut questions = Vec::with_capacity(header.questions as usize); for _ in 0..header.questions { - let (name, name_size) = try!(Name::scan(&data[offset..], data)); + let (name, name_size) = Name::scan(&data[offset..], data)?; offset += name_size; if offset + 4 > data.len() { return Err(Error::UnexpectedEOF); } - let qtype = try!(QueryType::parse( - BigEndian::read_u16(&data[offset..offset+2]))); + let qtype = QueryType::parse(BigEndian::read_u16(&data[offset..offset + 2]))?; offset += 2; - let qclass_qu = BigEndian::read_u16(&data[offset..offset+2]); - let qclass = try!(QueryClass::parse(qclass_qu & 0x7fff)); + let qclass_qu = BigEndian::read_u16(&data[offset..offset + 2]); + let qclass = QueryClass::parse(qclass_qu & 0x7fff)?; let qu = (qclass_qu & 0x8000) != 0; offset += 2; @@ -34,11 +32,11 @@ impl<'a> Packet<'a> { } let mut answers = Vec::with_capacity(header.answers as usize); for _ in 0..header.answers { - answers.push(try!(parse_record(data, &mut offset))); + answers.push(parse_record(data, &mut offset)?); } let mut nameservers = Vec::with_capacity(header.nameservers as usize); for _ in 0..header.nameservers { - nameservers.push(try!(parse_record(data, &mut offset))); + nameservers.push(parse_record(data, &mut offset)?); } Ok(Packet { header: header, @@ -52,29 +50,26 @@ impl<'a> Packet<'a> { // Generic function to parse answer, nameservers, and additional records. fn parse_record<'a>(data: &'a [u8], offset: &mut usize) -> Result, Error> { - let (name, name_size) = try!(Name::scan(&data[*offset..], data)); + let (name, name_size) = Name::scan(&data[*offset..], data)?; *offset += name_size; if *offset + 10 > data.len() { return Err(Error::UnexpectedEOF); } - let typ = try!(Type::parse( - BigEndian::read_u16(&data[*offset..*offset+2]))); + let typ = Type::parse(BigEndian::read_u16(&data[*offset..*offset + 2]))?; *offset += 2; - let cls = try!(Class::parse( - BigEndian::read_u16(&data[*offset..*offset+2]) & 0x7fff )); + let cls = Class::parse(BigEndian::read_u16(&data[*offset..*offset + 2]) & 0x7fff)?; *offset += 2; - let mut ttl = BigEndian::read_u32(&data[*offset..*offset+4]); + let mut ttl = BigEndian::read_u32(&data[*offset..*offset + 4]); if ttl > i32::MAX as u32 { ttl = 0; } *offset += 4; - let rdlen = BigEndian::read_u16(&data[*offset..*offset+2]) as usize; + let rdlen = BigEndian::read_u16(&data[*offset..*offset + 2]) as usize; *offset += 2; if *offset + rdlen > data.len() { return Err(Error::UnexpectedEOF); } - let data = try!(RRData::parse(typ, - &data[*offset..*offset+rdlen], data)); + let data = RRData::parse(typ, &data[*offset..*offset + rdlen], data)?; *offset += rdlen; Ok(ResourceRecord { name: name, @@ -87,34 +82,37 @@ fn parse_record<'a>(data: &'a [u8], offset: &mut usize) -> Result { - assert_eq!(&cname.to_string()[..], "livecms.trafficmanager.net"); - } - ref x => panic!("Wrong rdata {:?}", x), - } - assert_eq!(packet.nameservers.len(), 1); - assert_eq!(&packet.nameservers[0].name.to_string()[..], "net"); - assert_eq!(packet.nameservers[0].cls, C::IN); - assert_eq!(packet.nameservers[0].ttl, 120275); - match packet.nameservers[0].data { - RRData::NS(ref ns) => { - assert_eq!(&ns.to_string()[..], "g.gtld-servers.net"); - } - ref x => panic!("Wrong rdata {:?}", x), - } - } + let packet = Packet::parse(response).unwrap(); + assert_eq!( + packet.header, + Header { + id: 19184, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 1, + nameservers: 1, + additional: 0, + } + ); + assert_eq!(packet.questions.len(), 1); + assert_eq!(packet.questions[0].qtype, QT::A); + assert_eq!(packet.questions[0].qclass, QC::IN); + assert_eq!(&packet.questions[0].qname.to_string()[..], "www.skype.com"); + assert_eq!(packet.answers.len(), 1); + assert_eq!(&packet.answers[0].name.to_string()[..], "www.skype.com"); + assert_eq!(packet.answers[0].cls, C::IN); + assert_eq!(packet.answers[0].ttl, 3600); + match packet.answers[0].data { + RRData::CNAME(ref cname) => { + assert_eq!(&cname.to_string()[..], "livecms.trafficmanager.net"); + } + ref x => panic!("Wrong rdata {:?}", x), + } + assert_eq!(packet.nameservers.len(), 1); + assert_eq!(&packet.nameservers[0].name.to_string()[..], "net"); + assert_eq!(packet.nameservers[0].cls, C::IN); + assert_eq!(packet.nameservers[0].ttl, 120275); + match packet.nameservers[0].data { + RRData::NS(ref ns) => { + assert_eq!(&ns.to_string()[..], "g.gtld-servers.net"); + } + ref x => panic!("Wrong rdata {:?}", x), + } + } #[test] fn parse_multiple_answers() { @@ -224,20 +228,23 @@ mod test { \xe9\xa4e\xc0\x0c\x00\x01\x00\x01\x00\x00\x00\xef\ \x00\x04@\xe9\xa4\x8a"; let packet = Packet::parse(response).unwrap(); - assert_eq!(packet.header, Header { - id: 40425, - query: false, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: true, - response_code: NoError, - questions: 1, - answers: 6, - nameservers: 0, - additional: 0, - }); + assert_eq!( + packet.header, + Header { + id: 40425, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 6, + nameservers: 0, + additional: 0, + } + ); assert_eq!(packet.questions.len(), 1); assert_eq!(packet.questions[0].qtype, QT::A); assert_eq!(packet.questions[0].qclass, QC::IN); @@ -269,25 +276,30 @@ mod test { let query = b"[\xd9\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\ \x0c_xmpp-server\x04_tcp\x05gmail\x03com\x00\x00!\x00\x01"; let packet = Packet::parse(query).unwrap(); - assert_eq!(packet.header, Header { - id: 23513, - query: true, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: false, - response_code: NoError, - questions: 1, - answers: 0, - nameservers: 0, - additional: 0, - }); + assert_eq!( + packet.header, + Header { + id: 23513, + query: true, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: false, + response_code: NoError, + questions: 1, + answers: 0, + nameservers: 0, + additional: 0, + } + ); assert_eq!(packet.questions.len(), 1); assert_eq!(packet.questions[0].qtype, QT::SRV); assert_eq!(packet.questions[0].qclass, QC::IN); - assert_eq!(&packet.questions[0].qname.to_string()[..], - "_xmpp-server._tcp.gmail.com"); + assert_eq!( + &packet.questions[0].qname.to_string()[..], + "_xmpp-server._tcp.gmail.com" + ); assert_eq!(packet.answers.len(), 0); } @@ -306,25 +318,30 @@ mod test { \xc0\x0c\x00!\x00\x01\x00\x00\x03\x84\x00%\x00\x14\x00\x00\ \x14\x95\x04alt4\x0bxmpp-server\x01l\x06google\x03com\x00"; let packet = Packet::parse(response).unwrap(); - assert_eq!(packet.header, Header { - id: 23513, - query: false, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: true, - response_code: NoError, - questions: 1, - answers: 5, - nameservers: 0, - additional: 0, - }); + assert_eq!( + packet.header, + Header { + id: 23513, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 5, + nameservers: 0, + additional: 0, + } + ); assert_eq!(packet.questions.len(), 1); assert_eq!(packet.questions[0].qtype, QT::SRV); assert_eq!(packet.questions[0].qclass, QC::IN); - assert_eq!(&packet.questions[0].qname.to_string()[..], - "_xmpp-server._tcp.gmail.com"); + assert_eq!( + &packet.questions[0].qname.to_string()[..], + "_xmpp-server._tcp.gmail.com" + ); assert_eq!(packet.answers.len(), 5); let items = vec![ (5, 0, 5269, "xmpp-server.l.google.com"), @@ -334,12 +351,19 @@ mod test { (20, 0, 5269, "alt4.xmpp-server.l.google.com"), ]; for i in 0..5 { - assert_eq!(&packet.answers[i].name.to_string()[..], - "_xmpp-server._tcp.gmail.com"); + assert_eq!( + &packet.answers[i].name.to_string()[..], + "_xmpp-server._tcp.gmail.com" + ); assert_eq!(packet.answers[i].cls, C::IN); assert_eq!(packet.answers[i].ttl, 900); match *&packet.answers[i].data { - RRData::SRV { priority, weight, port, ref target } => { + RRData::SRV { + priority, + weight, + port, + ref target, + } => { assert_eq!(priority, items[i].0); assert_eq!(weight, items[i].1); assert_eq!(port, items[i].2); @@ -361,40 +385,44 @@ mod test { \x00\x04|\x00\t\x00\x14\x04alt2\xc0)\xc0\x0c\x00\x0f\ \x00\x01\x00\x00\x04|\x00\t\x00\x1e\x04alt3\xc0)"; let packet = Packet::parse(response).unwrap(); - assert_eq!(packet.header, Header { - id: 58344, - query: false, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: true, - response_code: NoError, - questions: 1, - answers: 5, - nameservers: 0, - additional: 0, - }); + assert_eq!( + packet.header, + Header { + id: 58344, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 5, + nameservers: 0, + additional: 0, + } + ); assert_eq!(packet.questions.len(), 1); assert_eq!(packet.questions[0].qtype, QT::MX); assert_eq!(packet.questions[0].qclass, QC::IN); - assert_eq!(&packet.questions[0].qname.to_string()[..], - "gmail.com"); + assert_eq!(&packet.questions[0].qname.to_string()[..], "gmail.com"); assert_eq!(packet.answers.len(), 5); let items = vec![ - ( 5, "gmail-smtp-in.l.google.com"), + (5, "gmail-smtp-in.l.google.com"), (10, "alt1.gmail-smtp-in.l.google.com"), (40, "alt4.gmail-smtp-in.l.google.com"), (20, "alt2.gmail-smtp-in.l.google.com"), (30, "alt3.gmail-smtp-in.l.google.com"), ]; for i in 0..5 { - assert_eq!(&packet.answers[i].name.to_string()[..], - "gmail.com"); + assert_eq!(&packet.answers[i].name.to_string()[..], "gmail.com"); assert_eq!(packet.answers[i].cls, C::IN); assert_eq!(packet.answers[i].ttl, 1148); match *&packet.answers[i].data { - RRData::MX { preference, ref exchange } => { + RRData::MX { + preference, + ref exchange, + } => { assert_eq!(preference, items[i].0); assert_eq!(exchange.to_string(), (items[i].1).to_string()); } @@ -410,20 +438,23 @@ mod test { \x00\x8b\x00\x10*\x00\x14P@\t\x08\x12\x00\x00\x00\x00\x00\x00 \x0e"; let packet = Packet::parse(response).unwrap(); - assert_eq!(packet.header, Header { - id: 43481, - query: false, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: true, - response_code: NoError, - questions: 1, - answers: 1, - nameservers: 0, - additional: 0, - }); + assert_eq!( + packet.header, + Header { + id: 43481, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 1, + nameservers: 0, + additional: 0, + } + ); assert_eq!(packet.questions.len(), 1); assert_eq!(packet.questions[0].qtype, QT::AAAA); @@ -435,8 +466,9 @@ mod test { assert_eq!(packet.answers[0].ttl, 139); match packet.answers[0].data { RRData::AAAA(addr) => { - assert_eq!(addr, Ipv6Addr::new( - 0x2A00, 0x1450, 0x4009, 0x812, 0, 0, 0, 0x200e) + assert_eq!( + addr, + Ipv6Addr::new(0x2A00, 0x1450, 0x4009, 0x812, 0, 0, 0, 0x200e) ); } ref x => panic!("Wrong rdata {:?}", x), @@ -458,25 +490,31 @@ mod test { \x00\x99L\x00\x04\xad\xf5;\x04"; let packet = Packet::parse(response).unwrap(); - assert_eq!(packet.header, Header { - id: 64669, - query: false, - opcode: StandardQuery, - authoritative: false, - truncated: false, - recursion_desired: true, - recursion_available: true, - response_code: NoError, - questions: 1, - answers: 6, - nameservers: 2, - additional: 2, - }); + assert_eq!( + packet.header, + Header { + id: 64669, + query: false, + opcode: StandardQuery, + authoritative: false, + truncated: false, + recursion_desired: true, + recursion_available: true, + response_code: NoError, + questions: 1, + answers: 6, + nameservers: 2, + additional: 2, + } + ); assert_eq!(packet.questions.len(), 1); assert_eq!(packet.questions[0].qtype, QT::A); assert_eq!(packet.questions[0].qclass, QC::IN); - assert_eq!(&packet.questions[0].qname.to_string()[..], "cdn.sstatic.net"); + assert_eq!( + &packet.questions[0].qname.to_string()[..], + "cdn.sstatic.net" + ); assert_eq!(packet.answers.len(), 6); assert_eq!(&packet.answers[0].name.to_string()[..], "cdn.sstatic.net"); assert_eq!(packet.answers[0].cls, C::IN); @@ -501,7 +539,7 @@ mod test { assert_eq!(packet.answers[i].ttl, 102); match packet.answers[i].data { RRData::A(addr) => { - assert_eq!(addr, ips[i-1]); + assert_eq!(addr, ips[i - 1]); } ref x => panic!("Wrong rdata {:?}", x), } diff --git a/src/dns_parser/rrdata.rs b/src/dns_parser/rrdata.rs index 67ca2f8..0c4f40b 100644 --- a/src/dns_parser/rrdata.rs +++ b/src/dns_parser/rrdata.rs @@ -3,8 +3,7 @@ use std::net::{Ipv4Addr, Ipv6Addr}; use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; -use super::{Name, Type, Error}; - +use super::{Error, Name, Type}; /// The enumeration that represents known types of DNS resource records data #[derive(Debug, Clone)] @@ -14,11 +13,22 @@ pub enum RRData<'a> { PTR(Name<'a>), A(Ipv4Addr), AAAA(Ipv6Addr), - SRV { priority: u16, weight: u16, port: u16, target: Name<'a> }, - MX { preference: u16, exchange: Name<'a> }, + SRV { + priority: u16, + weight: u16, + port: u16, + target: Name<'a>, + }, + MX { + preference: u16, + exchange: Name<'a>, + }, TXT(&'a [u8]), // Anything that can't be parsed yet - Unknown { typ: Type, data: &'a [u8] }, + Unknown { + typ: Type, + data: &'a [u8], + }, } impl<'a> RRData<'a> { @@ -38,26 +48,34 @@ impl<'a> RRData<'a> { pub fn write_to(&self, writer: &mut T) -> io::Result<()> { match *self { - RRData::CNAME(ref name) | - RRData::NS(ref name) | - RRData::PTR(ref name) => name.write_to(writer), + RRData::CNAME(ref name) | RRData::NS(ref name) | RRData::PTR(ref name) => { + name.write_to(writer) + } RRData::A(ip) => writer.write_u32::(ip.into()), RRData::AAAA(ip) => { - for segment in ip.segments().into_iter() { - try!(writer.write_u16::(*segment)); + for segment in ip.segments().iter() { + writer.write_u16::(*segment)?; } Ok(()) } - RRData::SRV { priority, weight, port, ref target } => { - try!(writer.write_u16::(priority)); - try!(writer.write_u16::(weight)); - try!(writer.write_u16::(port)); + RRData::SRV { + priority, + weight, + port, + ref target, + } => { + writer.write_u16::(priority)?; + writer.write_u16::(weight)?; + writer.write_u16::(port)?; target.write_to(writer) } - RRData::MX { preference, ref exchange } => { - try!(writer.write_u16::(preference)); + RRData::MX { + preference, + ref exchange, + } => { + writer.write_u16::(preference)?; exchange.write_to(writer) } RRData::TXT(data) => writer.write_all(data), @@ -65,16 +83,13 @@ impl<'a> RRData<'a> { } } - pub fn parse(typ: Type, rdata: &'a [u8], original: &'a [u8]) - -> Result, Error> - { + pub fn parse(typ: Type, rdata: &'a [u8], original: &'a [u8]) -> Result, Error> { match typ { Type::A => { if rdata.len() != 4 { return Err(Error::WrongRdataLength); } - Ok(RRData::A( - Ipv4Addr::from(BigEndian::read_u32(rdata)))) + Ok(RRData::A(Ipv4Addr::from(BigEndian::read_u32(rdata)))) } Type::AAAA => { if rdata.len() != 16 { @@ -91,22 +106,16 @@ impl<'a> RRData<'a> { BigEndian::read_u16(&rdata[14..16]), ))) } - Type::CNAME => { - Ok(RRData::CNAME(try!(Name::scan(rdata, original)).0)) - } - Type::NS => { - Ok(RRData::NS(try!(Name::scan(rdata, original)).0)) - } - Type::PTR => { - Ok(RRData::PTR(try!(Name::scan(rdata, original)).0)) - } + Type::CNAME => Ok(RRData::CNAME(Name::scan(rdata, original)?.0)), + Type::NS => Ok(RRData::NS(Name::scan(rdata, original)?.0)), + Type::PTR => Ok(RRData::PTR(Name::scan(rdata, original)?.0)), Type::MX => { if rdata.len() < 3 { return Err(Error::WrongRdataLength); } Ok(RRData::MX { preference: BigEndian::read_u16(&rdata[..2]), - exchange: try!(Name::scan(&rdata[2..], original)).0, + exchange: Name::scan(&rdata[2..], original)?.0, }) } Type::SRV => { @@ -117,16 +126,14 @@ impl<'a> RRData<'a> { priority: BigEndian::read_u16(&rdata[..2]), weight: BigEndian::read_u16(&rdata[2..4]), port: BigEndian::read_u16(&rdata[4..6]), - target: try!(Name::scan(&rdata[6..], original)).0, + target: Name::scan(&rdata[6..], original)?.0, }) } Type::TXT => Ok(RRData::TXT(rdata)), - typ => { - Ok(RRData::Unknown { - typ: typ, - data: rdata - }) - } + typ => Ok(RRData::Unknown { + typ: typ, + data: rdata, + }), } } } diff --git a/src/dns_parser/structs.rs b/src/dns_parser/structs.rs index ac0d125..50bdd44 100644 --- a/src/dns_parser/structs.rs +++ b/src/dns_parser/structs.rs @@ -1,5 +1,4 @@ -use super::{QueryType, QueryClass, Name, Class, Header, RRData}; - +use super::{Class, Header, Name, QueryClass, QueryType, RRData}; /// Parsed DNS packet #[derive(Debug)] diff --git a/src/fsm.rs b/src/fsm.rs index 429d703..5834a7c 100644 --- a/src/fsm.rs +++ b/src/fsm.rs @@ -1,18 +1,22 @@ -use dns_parser::{self, Name, QueryClass, QueryType, RRData}; -use futures::sync::mpsc; -use futures::{Async, Future, Poll, Stream}; +use crate::dns_parser::{self, Name, QueryClass, QueryType, RRData}; + use get_if_addrs::get_if_addrs; use std::collections::VecDeque; use std::io; use std::io::ErrorKind::WouldBlock; use std::marker::PhantomData; use std::net::{IpAddr, SocketAddr}; -use tokio::net::UdpSocket; -use tokio::reactor::Handle; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use tokio::{net::UdpSocket, stream::Stream, sync::mpsc}; use super::{DEFAULT_TTL, MDNS_PORT}; -use address_family::AddressFamily; -use services::{ServiceData, Services}; +use crate::address_family::AddressFamily; +use crate::services::{ServiceData, Services}; pub type AnswerBuilder = dns_parser::Builder; @@ -26,6 +30,16 @@ pub enum Command { Shutdown, } +quick_error! { +#[derive(Debug)] +pub enum Error { + BufferTooSmall(bytes: usize, size: usize) { + description("buffer isn't big enough") + display("Incoming packet {}>{}", bytes, size) + } +} +} + pub struct FSM { socket: UdpSocket, services: Services, @@ -35,13 +49,12 @@ pub struct FSM { } impl FSM { - pub fn new( - handle: &Handle, - services: &Services, - ) -> io::Result<(FSM, mpsc::UnboundedSender)> { + // Will panic if called from outside the context of a runtime + pub fn new(services: &Services) -> io::Result<(FSM, mpsc::UnboundedSender)> { let std_socket = AF::bind()?; - let socket = UdpSocket::from_socket(std_socket, handle)?; - let (tx, rx) = mpsc::unbounded(); + let socket = UdpSocket::from_std(std_socket)?; + + let (tx, rx) = mpsc::unbounded_channel(); let fsm = FSM { socket: socket, @@ -54,22 +67,25 @@ impl FSM { Ok((fsm, tx)) } - fn recv_packets(&mut self) -> io::Result<()> { + fn recv_packets(&mut self, cx: &mut Context) -> io::Result<()> { let mut buf = [0u8; 4096]; loop { - let (bytes, addr) = match self.socket.recv_from(&mut buf) { - Ok((bytes, addr)) => (bytes, addr), - Err(ref ioerr) if ioerr.kind() == WouldBlock => break, - Err(err) => return Err(err), + let (bytes, addr) = match self.socket.poll_recv_from(cx, &mut buf) { + Poll::Ready(Ok((bytes, addr))) => (bytes, addr), + Poll::Ready(Err(err)) => return Err(err), + Poll::Pending => break, }; - + // Is moot for certain platforms (Windows will throw a <10040> error from poll_recv) if bytes >= buf.len() { warn!("buffer too small for packet from {:?}", addr); - continue; + return Err(io::Error::new( + io::ErrorKind::Other, + Error::BufferTooSmall(bytes, buf.len()), + )); } - self.handle_packet(&buf[..bytes], addr); } + Ok(()) } @@ -218,47 +234,44 @@ impl FSM { } } -impl Future for FSM { - type Item = (); - type Error = io::Error; - fn poll(&mut self) -> Poll<(), io::Error> { - while let Async::Ready(cmd) = self.commands.poll().unwrap() { +impl Future for FSM { + type Output = (); + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> { + let pinned = Pin::get_mut(self); + while let Poll::Ready(cmd) = Pin::new(&mut pinned.commands).poll_next(cx) { match cmd { - Some(Command::Shutdown) => return Ok(Async::Ready(())), + Some(Command::Shutdown) => return Poll::Ready(()), Some(Command::SendUnsolicited { svc, ttl, include_ip, }) => { - self.send_unsolicited(&svc, ttl, include_ip); + pinned.send_unsolicited(&svc, ttl, include_ip); } None => { warn!("responder disconnected without shutdown"); - return Ok(Async::Ready(())); + return Poll::Ready(()); } } } - while let Async::Ready(()) = self.socket.poll_read() { - self.recv_packets()?; + match pinned.recv_packets(cx) { + Ok(_) => (), + Err(e) => error!("ResponderRecvPacket Error: {:?}", e), } - loop { - if let Some(&(ref response, ref addr)) = self.outgoing.front() { - trace!("sending packet to {:?}", addr); + while let Some(&(ref response, ref addr)) = pinned.outgoing.front() { + trace!("sending packet to {:?}", addr); - match self.socket.send_to(response, addr) { - Ok(_) => (), - Err(ref ioerr) if ioerr.kind() == WouldBlock => break, - Err(err) => warn!("error sending packet {:?}", err), - } - } else { - break; + match pinned.socket.poll_send_to(cx, response, addr) { + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(ref ioerr)) if ioerr.kind() == WouldBlock => break, + Poll::Ready(Err(err)) => warn!("error sending packet {:?}", err), + Poll::Pending => (break), } - - self.outgoing.pop_front(); } + pinned.outgoing.pop_front(); - Ok(Async::NotReady) + Poll::Pending } } diff --git a/src/lib.rs b/src/lib.rs index 89e1ee1..848ccf2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,32 +5,36 @@ extern crate quick_error; extern crate log; extern crate byteorder; -extern crate futures; extern crate get_if_addrs; extern crate hostname; extern crate multimap; extern crate net2; extern crate rand; -extern crate tokio_core as tokio; +extern crate tokio; -use futures::sync::mpsc; -use futures::Future; +use futures_util::{future, future::FutureExt}; use std::cell::RefCell; +use std::future::Future; use std::io; +use std::marker::Unpin; use std::sync::{Arc, RwLock}; + use std::thread; -use tokio::reactor::{Core, Handle}; +use tokio::{ + runtime::{Handle, Runtime}, + sync::mpsc, +}; mod dns_parser; -use dns_parser::Name; +use crate::dns_parser::Name; mod address_family; mod fsm; mod services; -use address_family::{Inet, Inet6}; -use fsm::{Command, FSM}; -use services::{ServiceData, Services, ServicesInner}; +use crate::address_family::{Inet, Inet6}; +use crate::fsm::{Command, FSM}; +use crate::services::{ServiceData, Services, ServicesInner}; const DEFAULT_TTL: u32 = 60; const MDNS_PORT: u16 = 5353; @@ -48,42 +52,32 @@ pub struct Service { _shutdown: Arc, } -type ResponderTask = Box + Send>; +type ResponderTask = Box + Send + Unpin>; impl Responder { - fn setup_core() -> io::Result<(Core, ResponderTask, Responder)> { - let core = Core::new()?; - let (responder, task) = Self::with_handle(&core.handle())?; - Ok((core, task, responder)) - } - pub fn new() -> io::Result { let (tx, rx) = std::sync::mpsc::sync_channel(0); thread::Builder::new() .name("mdns-responder".to_owned()) - .spawn(move || match Self::setup_core() { - Ok((mut core, task, responder)) => { + .spawn(move || { + let mut rt = Runtime::new().unwrap(); + rt.block_on(async { + let (responder, task) = Self::with_default_handle()?; tx.send(Ok(responder)).expect("tx responder channel closed"); - core.run(task).expect("mdns thread failed"); - } - Err(err) => { - tx.send(Err(err)).expect("tx responder channel closed"); - } + task.await; + Ok::<(), io::Error>(()) + }) })?; - rx.recv().expect("rx responder channel closed") } pub fn spawn(handle: &Handle) -> io::Result { - let (responder, task) = Responder::with_handle(handle)?; - handle.spawn(task.map_err(|e| { - warn!("mdns error {:?}", e); - () - })); + let (responder, task) = Self::with_default_handle()?; + handle.spawn(task); Ok(responder) } - pub fn with_handle(handle: &Handle) -> io::Result<(Responder, ResponderTask)> { + pub fn with_default_handle() -> io::Result<(Responder, ResponderTask)> { let mut hostname = match hostname::get() { Ok(s) => match s.into_string() { Ok(s) => s, @@ -102,15 +96,13 @@ impl Responder { let services = Arc::new(RwLock::new(ServicesInner::new(hostname))); - let v4 = FSM::::new(handle, &services); - let v6 = FSM::::new(handle, &services); + let v4 = FSM::::new(&services); + let v6 = FSM::::new(&services); let (task, commands): (ResponderTask, _) = match (v4, v6) { (Ok((v4_task, v4_command)), Ok((v6_task, v6_command))) => { - let task = v4_task.join(v6_task).map(|((), ())| ()); - let task = Box::new(task); - let commands = vec![v4_command, v6_command]; - (task, commands) + let tasks = future::join(v4_task, v6_task).map(|((), ())| ()); + (Box::new(tasks), vec![v4_command, v6_command]) } (Ok((v4_task, v4_command)), Err(err)) => { @@ -191,7 +183,7 @@ struct CommandSender(Vec>); impl CommandSender { fn send(&mut self, cmd: Command) { for tx in self.0.iter_mut() { - tx.unbounded_send(cmd.clone()).expect("responder died"); + tx.send(cmd.clone()).expect("responder died"); } } diff --git a/src/services.rs b/src/services.rs index 11fb264..f0b8c0e 100644 --- a/src/services.rs +++ b/src/services.rs @@ -1,4 +1,4 @@ -use dns_parser::{self, Name, QueryClass, RRData}; +use crate::dns_parser::{self, Name, QueryClass, RRData}; use multimap::MultiMap; use rand::{thread_rng, Rng}; use std::collections::HashMap;