Skip to content

Commit

Permalink
Simplify route logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
yuguorui committed Apr 17, 2022
1 parent 37795c3 commit 11d3fcd
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 133 deletions.
189 changes: 106 additions & 83 deletions src/rules.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
use anyhow::Context;
use iprange::IpRange;

use ipnet::Ipv6Net;
use tokio::time::Instant;

use std::fmt::Display;
use std::net::{SocketAddr, ToSocketAddrs};
use std::time::Duration;
use std::{
collections::HashMap,
net::{SocketAddr, ToSocketAddrs},
};
use tokio::{
net::{TcpSocket, TcpStream},
time::timeout,
Expand Down Expand Up @@ -36,7 +34,6 @@ pub struct Condition {
pub maxmind_regions: Vec<String>,
pub dst_ip_range: IpRange<Ipv6Net>,
pub domains: Option<BoomHashSet<String>>,
pub suffix_domains: Option<BoomHashSet<String>>,
}

impl Condition {
Expand All @@ -45,19 +42,62 @@ impl Condition {
maxmind_regions: Vec::new(),
dst_ip_range: IpRange::<Ipv6Net>::new(),
domains: None,
suffix_domains: None,
}
}
}

#[derive(Debug)]
pub struct RouteRule(pub Outbound, pub Option<Condition>);
pub fn match_domain(&self, name: &str) -> bool {
if let Some(ref domains) = self.domains {
if domains.get(&name.to_owned()) {
return true;
}

/* Check the DOMAIN-SUFFIX rules */
let domain = format!(".{}{}", name, RULE_DOMAIN_SUFFIX_TAG);
let indices = domain.match_indices(".");
for index in indices {
if domains.get(&domain[index.0 + 1..].to_string()) {
return true;
}
}
}

false
}

pub fn match_sockaddr(
&self,
dst_sock: &SocketAddr,
ip_db: Option<&maxminddb::Reader<Vec<u8>>>,
) -> bool {
let dst_ip = dst_sock.ip();
if self
.dst_ip_range
.contains(dst_ip.to_ipv6_net().as_ref().unwrap())
{
return true;
}

pub type OutboundName = String;
/* Check the maxmind */
if let Some(reader) = ip_db {
let region: Result<maxminddb::geoip2::Country, maxminddb::MaxMindDBError> =
reader.lookup(dst_ip);
if let Ok(region) = region {
if self
.maxmind_regions
.contains(&region.country.unwrap().iso_code.unwrap().to_string())
{
return true;
}
}
}
return false;
}
}

pub struct RouteTable {
pub default: Outbound,
pub outbound_dict: HashMap<OutboundName, RouteRule>,
default_outbound: Option<u8>,
pub rules: Vec<Condition>,
pub outbounds: Vec<Outbound>,
pub ip_db: Option<maxminddb::Reader<Vec<u8>>>,
}

Expand Down Expand Up @@ -110,105 +150,88 @@ impl Display for RouteContext {
}
}

fn match_rule_with_sockaddr(
dst_sock: &SocketAddr,
rule: &Condition,
ip_db: Option<&maxminddb::Reader<Vec<u8>>>,
) -> bool {
let dst_ip = dst_sock.ip();
if rule
.dst_ip_range
.contains(dst_ip.to_ipv6_net().as_ref().unwrap())
{
return true;
impl RouteTable {
pub fn new() -> Self {
Self {
default_outbound: None,
rules: Vec::new(),
outbounds: Vec::new(),
ip_db: None,
}
}

pub fn set_default_route(&mut self, index: u8) {
self.default_outbound = Some(index);
}

/* Check the maxmind */
if let Some(reader) = ip_db {
let region: Result<maxminddb::geoip2::Country, maxminddb::MaxMindDBError> =
reader.lookup(dst_ip);
if let Ok(region) = region {
if rule
.maxmind_regions
.contains(&region.country.unwrap().iso_code.unwrap().to_string())
{
return true;
pub fn get_outbound_by_name(&self, name: &str) -> Option<&Outbound> {
for outbound in &self.outbounds {
if outbound.name == name {
return Some(outbound);
}
}
return None;
}
return false;
}

impl RouteTable {
pub fn add<S: ToString>(
pub fn get_outbound_index_by_name(&self, name: &str) -> Option<u8> {
for (index, outbound) in self.outbounds.iter().enumerate() {
if outbound.name == name {
return Some(index as u8);
}
}
return None;
}

pub fn add_empty_rule<S: ToString>(
&mut self,
name: S,
url: Option<Url>,
bind_range: Option<IpRange<Ipv6Net>>,
) {
self.outbound_dict.insert(
name.to_string(),
RouteRule(
Outbound {
name: name.to_string(),
url,
bind_range,
},
None,
),
);
self.outbounds.push(Outbound {
name: name.to_string(),
url,
bind_range,
});

self.rules.push(Condition::default());
}

pub fn match_route(&self, context: &RouteContext) -> (&str, &Outbound) {
for (name, outbound_rule) in self.outbound_dict.iter() {
if let Some(rule) = outbound_rule.1.as_ref() {
match &context.target_addr {
TargetAddr::Ip(dst_sock) => {
if match_rule_with_sockaddr(dst_sock, rule, self.ip_db.as_ref()) {
return (name, &outbound_rule.0);
}
fn match_route(&self, context: &RouteContext) -> u8 {
for (index, cond) in self.rules.iter().enumerate() {
match &context.target_addr {
TargetAddr::Ip(dst_sock) => {
if cond.match_sockaddr(dst_sock, self.ip_db.as_ref()) {
return index as u8;
}
TargetAddr::Domain(domain, _port, dst_sock) => {
if let Some(phf) = &rule.domains {
/* Check the DOMAIN rules. */
if phf.get(&domain) {
return (name, &outbound_rule.0);
}

/* Check the DOMAIN-SUFFIX rules */
let domain = format!(".{}{}", domain, RULE_DOMAIN_SUFFIX_TAG);
let indices = domain.match_indices(".");
for index in indices {
if phf.get(&domain[index.0 + 1..].to_string()) {
return (name, &outbound_rule.0);
}
}
}
}

if let Some(dst_sock) = dst_sock {
if match_rule_with_sockaddr(
dst_sock,
rule,
self.ip_db.as_ref(),
) {
return (name, &outbound_rule.0);
}
TargetAddr::Domain(domain, _, dst_sock) => {
if cond.match_domain(domain) {
return index as u8;
}

if let Some(dst_sock) = dst_sock {
if cond.match_sockaddr(dst_sock, self.ip_db.as_ref()) {
return index as u8;
}
}
}
}
}
(&self.default.name, &self.default)

return self.default_outbound.context("no default route").unwrap();
}

pub async fn get_tcp_sock(&self, context: &RouteContext) -> tokio::io::Result<TcpStream> {
let start = Instant::now();
let (name, outbound) = self.match_route(context);
let outbound_index = self.match_route(context);
let outbound = &self.outbounds[outbound_index as usize];
let duration = start.elapsed();
println!(
"{} -> Outbound({}){}",
context,
name,
outbound.name,
if SETTINGS.read().await.debug {
format!(", time: {}us", duration.as_micros())
} else {
Expand Down
Loading

0 comments on commit 11d3fcd

Please sign in to comment.