Skip to content

Commit

Permalink
✨Stop domain name prefetching when idle for more than 30 minutes (#275)
Browse files Browse the repository at this point in the history
  • Loading branch information
mokeyish authored May 13, 2024
1 parent ba860fd commit d43357b
Show file tree
Hide file tree
Showing 17 changed files with 234 additions and 298 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "smartdns"
version = "0.7.2"
authors = ["YISH <[email protected]>"]
edition = "2021"
rust-version = "1.68.0"
rust-version = "1.70.0"

keywords = ["DNS", "BIND", "dig", "named", "dnssec", "SmartDNS", "Dnsmasq"]
categories = ["network-programming"]
Expand Down
30 changes: 25 additions & 5 deletions src/app.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
use std::{collections::HashMap, ops::DerefMut, path::PathBuf, sync::Arc, time::Duration};
use std::{
collections::HashMap,
ops::DerefMut,
path::PathBuf,
sync::Arc,
time::{Duration, Instant},
};
use tokio::{
runtime::{Handle, Runtime},
sync::RwLock,
Expand Down Expand Up @@ -165,7 +171,7 @@ impl App {

// check if cache enabled.
if cfg.cache_size() > 0 {
let cache_middleware = DnsCacheMiddleware::new(&cfg);
let cache_middleware = DnsCacheMiddleware::new(&cfg, self.dns_handle.clone());
*self.cache.write().await = Some(cache_middleware.cache().clone());
middleware_builder = middleware_builder.with(cache_middleware);
}
Expand Down Expand Up @@ -244,12 +250,26 @@ pub fn bootstrap(conf: Option<PathBuf>) {

let mut inner_join_set = JoinSet::new();

let mut last_activity = Instant::now();

const MAX_IDLE: Duration = Duration::from_secs(30 * 60); // 30 min

while let Some((message, server_opts, sender)) = incoming_request.recv().await {
let handler = app.mw_handler.read().await.clone();

inner_join_set.spawn(async move {
let _ = sender.send(process(handler, message, server_opts).await);
});
if server_opts.is_background {
if Instant::now() - last_activity < MAX_IDLE {
inner_join_set.spawn(async move {
let _ = sender.send(process(handler, message, server_opts).await);
});
}
} else {
last_activity = Instant::now();
inner_join_set.spawn(async move {
let _ = sender.send(process(handler, message, server_opts).await);
});
}

reap_tasks(&mut inner_join_set);
}
});
Expand Down
12 changes: 6 additions & 6 deletions src/config/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,15 @@ pub enum ListenerAddress {
V6(Ipv6Addr),
}

impl ToString for ListenerAddress {
fn to_string(&self) -> String {
impl std::fmt::Display for ListenerAddress {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use ListenerAddress::*;

match self {
Localhost => "localhost".to_string(),
All => "*".to_string(),
V4(ip) => ip.to_string(),
V6(ip) => format!("[{ip}]"),
Localhost => write!(f, "localhost"),
All => write!(f, "*"),
V4(ip) => write!(f, "{ip}"),
V6(ip) => write!(f, "[{ip}]"),
}
}
}
Expand Down
28 changes: 14 additions & 14 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,11 +272,11 @@ impl From<Name> for Domain {
}
}

impl ToString for Domain {
fn to_string(&self) -> String {
impl std::fmt::Display for Domain {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Domain::Name(n) => n.to_string(),
Domain::Set(n) => format!("domain-set:{n}"),
Domain::Name(n) => write!(f, "{n}"),
Domain::Set(n) => write!(f, "domain-set:{n}"),
}
}
}
Expand Down Expand Up @@ -328,18 +328,18 @@ pub enum DomainAddress {
IPv6(Ipv6Addr),
}

impl ToString for DomainAddress {
fn to_string(&self) -> String {
impl std::fmt::Display for DomainAddress {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use DomainAddress::*;
match self {
SOA => "#".to_string(),
SOAv4 => "#4".to_string(),
SOAv6 => "#6".to_string(),
IGN => "-".to_string(),
IGNv4 => "-4".to_string(),
IGNv6 => "-6".to_string(),
IPv4(ip) => format!("{ip}"),
IPv6(ip) => format!("{ip}"),
SOA => write!(f, "#"),
SOAv4 => write!(f, "#4"),
SOAv6 => write!(f, "#6"),
IGN => write!(f, "-"),
IGNv4 => write!(f, "-4"),
IGNv6 => write!(f, "-6"),
IPv4(ip) => write!(f, "{ip}"),
IPv6(ip) => write!(f, "{ip}"),
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/config/server_opts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ pub struct ServerOpts {
/// do not serve expired
#[serde(skip_serializing_if = "Option::is_none")]
pub no_serve_expired: Option<bool>,

/// Indicates whether the query task is a background task.
pub is_background: bool,
}

impl ServerOpts {
Expand Down Expand Up @@ -116,6 +119,7 @@ impl ServerOpts {
no_dualstack_selection,
force_aaaa_soa,
no_serve_expired,
is_background: _,
} = other;

if self.group.is_none() {
Expand Down
30 changes: 21 additions & 9 deletions src/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ pub struct DnsContext {
pub fastest_speed: Duration,
pub source: LookupFrom,
pub no_cache: bool,
pub background: bool,
}

impl DnsContext {
Expand All @@ -49,7 +48,6 @@ impl DnsContext {
fastest_speed: Default::default(),
source: Default::default(),
no_cache,
background: false,
}
}

Expand Down Expand Up @@ -113,11 +111,14 @@ impl Default for LookupFrom {

mod serial_message {

use crate::libdns::proto::op::Message;
use crate::dns_error::LookupError;
use crate::libdns::proto::error::ProtoError;
use crate::libdns::Protocol;
use crate::{config::ServerOpts, libdns::proto::op::Message};
use bytes::Bytes;
use hickory_proto::error::ProtoError;
use std::net::SocketAddr;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

use super::{DnsRequest, DnsResponse};

pub enum SerialMessage {
Raw(Message, SocketAddr, Protocol),
Expand Down Expand Up @@ -151,6 +152,16 @@ mod serial_message {
}
}

impl From<Message> for SerialMessage {
fn from(message: Message) -> Self {
Self::raw(
message,
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 0)),
Protocol::Udp,
)
}
}

impl TryFrom<SerialMessage> for crate::libdns::proto::xfer::SerialMessage {
type Error = ProtoError;
fn try_from(value: SerialMessage) -> Result<Self, Self::Error> {
Expand Down Expand Up @@ -285,10 +296,13 @@ mod request {
fn from(query: Query) -> Self {
use std::net::{Ipv4Addr, SocketAddrV4};

let mut message = Message::new();
message.add_query(query.clone());

Self {
id: rand::random(),
query: query.into(),
message: Arc::new(Message::default()),
message: Arc::new(message),
src: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 53)),
protocol: Protocol::Udp,
}
Expand Down Expand Up @@ -316,11 +330,9 @@ mod request {

mod response {

use hickory_proto::op;

use crate::dns_client::MAX_TTL;
use crate::libdns::proto::{
op::{Header, Message, Query},
op::{self, Header, Message, Query},
rr::{RData, Record},
};
use crate::libdns::resolver::TtlClip as _;
Expand Down
2 changes: 1 addition & 1 deletion src/dns_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ impl NameServerFactory {
let key = format!(
"{}: {}{:?}#{}@{}",
url.proto(),
url.to_string(),
**url,
proxy.as_ref().map(|s| s.to_string()),
so_mark.unwrap_or_default(),
device.as_deref().unwrap_or_default(),
Expand Down
6 changes: 4 additions & 2 deletions src/dns_error.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::dns::DnsResponse;
use crate::libdns::proto::{error::ProtoError, op::ResponseCode};
use crate::libdns::proto::{
error::{ProtoError, ProtoErrorKind},
op::ResponseCode,
};
use crate::libdns::resolver::error::{ResolveError, ResolveErrorKind};
use hickory_proto::error::ProtoErrorKind;
use std::{io, sync::Arc};
use thiserror::Error;

Expand Down
38 changes: 0 additions & 38 deletions src/dns_mw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,44 +15,6 @@ use crate::{

pub type DnsMiddlewareHost = MiddlewareHost<DnsContext, DnsRequest, DnsResponse, DnsError>;

pub struct BackgroundQueryTask {
pub ctx: DnsContext,
pub req: DnsRequest,
client: Arc<DnsMiddlewareHost>,
}

impl BackgroundQueryTask {
pub fn new(ctx: &DnsContext, req: &DnsRequest, client: Arc<DnsMiddlewareHost>) -> Self {
Self {
ctx: ctx.clone(),
req: req.clone(),
client,
}
}

pub fn from_query(
query: Query,
cfg: Arc<RuntimeConfig>,
client: Arc<DnsMiddlewareHost>,
) -> Self {
let ctx = DnsContext::new(query.name(), cfg, Default::default());
let req = query.into();
Self { ctx, req, client }
}

pub fn spawn(self) -> tokio::task::JoinHandle<(Self, Result<DnsResponse, DnsError>)> {
tokio::spawn(async move {
let Self {
mut ctx,
req,
client,
} = self;
let res = client.execute(&mut ctx, &req).await;
(Self { ctx, req, client }, res)
})
}
}

pub struct DnsMiddlewareHandler {
cfg: Arc<RuntimeConfig>,
host: DnsMiddlewareHost,
Expand Down
9 changes: 5 additions & 4 deletions src/dns_mw_audit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,10 @@ impl DnsAuditRecord {
}
}

impl ToString for DnsAuditRecord {
fn to_string(&self) -> String {
format!(
impl std::fmt::Display for DnsAuditRecord {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"[{}] {} query {}, type: {}, elapsed: {:?}, speed: {:?}, result {}",
self.date.format("%Y-%m-%d %H:%M:%S,%3f"),
self.client,
Expand Down Expand Up @@ -213,7 +214,7 @@ fn record_audit_to_file(
} else {
// write as nornmal log format.
for audit in audit_records {
if writeln!(audit_file, "{}", audit.to_string()).is_err() {
if writeln!(audit_file, "{}", audit).is_err() {
warn!("Write audit to file '{:?}' failed", audit_file.path());
}
}
Expand Down
Loading

0 comments on commit d43357b

Please sign in to comment.