diff --git a/src/config/parser/domain_rule.rs b/src/config/parser/domain_rule.rs index 1be2bc66..44652476 100644 --- a/src/config/parser/domain_rule.rs +++ b/src/config/parser/domain_rule.rs @@ -50,6 +50,18 @@ impl NomParser for DomainRule { map(options::parse_no_value(tag_no_case("no-cache")), |v| { rule.no_cache = Some(v); }), + map( + options::parse_value(tag_no_case("rr-ttl-min"), NomParser::parse), + |v| { + rule.rr_ttl_min = Some(v); + }, + ), + map( + options::parse_value(tag_no_case("rr-ttl-max"), NomParser::parse), + |v| { + rule.rr_ttl_max = Some(v); + }, + ), map(options::unkown_options, |(n, v)| { log::warn!("domain rule: unkown options {}={:?}", n, v) }), @@ -146,5 +158,15 @@ mod tests { } )) ); + assert_eq!( + DomainRule::parse("-rr-ttl-min 60"), + Ok(( + "", + DomainRule { + rr_ttl_min: Some(60), + ..Default::default() + } + )) + ); } } diff --git a/src/dns.rs b/src/dns.rs index 9cc482b6..48483133 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -11,6 +11,8 @@ use crate::dns_rule::DomainRuleTreeNode; use crate::config::ServerOpts; use crate::dns_conf::RuntimeConfig; +pub use crate::dns_rule::DomainRuleGetter; + pub use crate::libdns::proto::{ error::ProtoErrorKind, op, @@ -40,10 +42,7 @@ impl DnsContext { pub fn new(name: &Name, cfg: Arc, server_opts: ServerOpts) -> Self { let domain_rule = cfg.find_domain_rule(name); - let no_cache = domain_rule - .as_ref() - .and_then(|r| r.get(|n| n.no_cache)) - .unwrap_or_default(); + let no_cache = domain_rule.get(|n| n.no_cache).unwrap_or_default(); DnsContext { cfg, diff --git a/src/dns_conf.rs b/src/dns_conf.rs index 815a5c9e..9cb8c979 100644 --- a/src/dns_conf.rs +++ b/src/dns_conf.rs @@ -9,6 +9,7 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; pub use crate::config::*; +use crate::dns::DomainRuleGetter; use crate::log; use crate::{ dns_rule::{DomainRuleMap, DomainRuleTreeNode}, @@ -926,7 +927,7 @@ fn resolve_filepath>(filepath: P, base_file: Option<&PathBuf>) -> #[cfg(test)] mod tests { - use crate::libdns::Protocol; + use crate::{dns::DomainRuleGetter, libdns::Protocol}; use byte_unit::Byte; use crate::config::{HttpsListenerConfig, ListenerAddress, ServerOpts, SslConfig}; @@ -1261,13 +1262,13 @@ mod tests { assert_eq!( cfg.find_domain_rule(&"cloudflare.com".parse().unwrap()) - .and_then(|r| r.get(|n| n.address)), + .get(|n| n.address), Some(DomainAddress::SOA) ); assert_eq!( cfg.find_domain_rule(&"google.com".parse().unwrap()) - .and_then(|r| r.get(|n| n.address)), + .get(|n| n.address), Some(DomainAddress::IGN) ); } @@ -1279,13 +1280,13 @@ mod tests { .build(); assert_eq!( cfg.find_domain_rule(&"example.com".parse().unwrap()) - .and_then(|r| r.get(|n| n.address)), + .get(|n| n.address), Some(DomainAddress::SOA) ); assert_eq!( cfg.find_domain_rule(&"aa.example.com".parse().unwrap()) - .and_then(|r| r.get(|n| n.address)), + .get(|n| n.address), None ); } @@ -1295,13 +1296,13 @@ mod tests { let cfg = RuntimeConfig::builder().with("address /*/#").build(); assert_eq!( cfg.find_domain_rule(&"localhost".parse().unwrap()) - .and_then(|r| r.get(|n| n.address)), + .get(|n| n.address), Some(DomainAddress::SOA) ); assert_eq!( cfg.find_domain_rule(&"aa.example.com".parse().unwrap()) - .and_then(|r| r.get(|n| n.address)), + .get(|n| n.address), None ); } @@ -1311,13 +1312,13 @@ mod tests { let cfg = RuntimeConfig::builder().with("address /+/#").build(); assert_eq!( cfg.find_domain_rule(&"localhost".parse().unwrap()) - .and_then(|r| r.get(|n| n.address)), + .get(|n| n.address), Some(DomainAddress::SOA) ); assert_eq!( cfg.find_domain_rule(&"aa.example.com".parse().unwrap()) - .and_then(|r| r.get(|n| n.address)), + .get(|n| n.address), Some(DomainAddress::SOA) ); } @@ -1327,13 +1328,13 @@ mod tests { let cfg = RuntimeConfig::builder().with("address /./#").build(); assert_eq!( cfg.find_domain_rule(&"localhost".parse().unwrap()) - .and_then(|r| r.get(|n| n.address)), + .get(|n| n.address), Some(DomainAddress::SOA) ); assert_eq!( cfg.find_domain_rule(&"aa.example.com".parse().unwrap()) - .and_then(|r| r.get(|n| n.address)), + .get(|n| n.address), Some(DomainAddress::SOA) ); } diff --git a/src/dns_mw_addr.rs b/src/dns_mw_addr.rs index 2addf756..7c11e2db 100644 --- a/src/dns_mw_addr.rs +++ b/src/dns_mw_addr.rs @@ -73,8 +73,19 @@ impl Middleware for AddressMiddle } } - let rr_ttl_min = ctx.cfg().rr_ttl_min().map(|i| i as u32); - let rr_ttl_max = ctx.cfg().rr_ttl_max().map(|i| i as u32); + let rr_ttl_min = ctx + .domain_rule + .get_ref(|r| r.rr_ttl_min.as_ref()) + .cloned() + .or_else(|| ctx.cfg().rr_ttl_min()) + .map(|i| i as u32); + + let rr_ttl_max = ctx + .domain_rule + .get_ref(|r| r.rr_ttl_max.as_ref()) + .cloned() + .or_else(|| ctx.cfg().rr_ttl_max()) + .map(|i| i as u32); let rr_ttl_reply_max = ctx.cfg().rr_ttl_reply_max().map(|i| i as u32); if rr_ttl_min.is_some() || rr_ttl_max.is_some() || rr_ttl_reply_max.is_some() { diff --git a/src/dns_mw_dualstack.rs b/src/dns_mw_dualstack.rs index c372d020..314d8aec 100644 --- a/src/dns_mw_dualstack.rs +++ b/src/dns_mw_dualstack.rs @@ -64,8 +64,7 @@ impl Middleware let speed_check_mode = ctx .domain_rule - .as_ref() - .and_then(|r| r.speed_check_mode.as_ref()) + .get_ref(|r| r.speed_check_mode.as_ref()) .cloned() .unwrap_or_default(); diff --git a/src/dns_mw_ns.rs b/src/dns_mw_ns.rs index 038f673d..9b776e7b 100644 --- a/src/dns_mw_ns.rs +++ b/src/dns_mw_ns.rs @@ -76,7 +76,7 @@ impl Middleware for NameServerMid _ => None, }) }) - .or_else(|| ctx.domain_rule.as_ref().and_then(|r| r.subnet)), + .or_else(|| ctx.domain_rule.get_ref(|r| r.subnet.as_ref()).cloned()), }; // skip nameserver rule diff --git a/src/dns_mw_zone.rs b/src/dns_mw_zone.rs index a4eea581..4b8b4858 100644 --- a/src/dns_mw_zone.rs +++ b/src/dns_mw_zone.rs @@ -102,11 +102,7 @@ impl Middleware for DnsZoneMiddle } } RecordType::SRV => { - if let Some(srv) = ctx - .domain_rule - .as_ref() - .and_then(|r| r.get_ref(|r| r.srv.as_ref())) - { + if let Some(srv) = ctx.domain_rule.get_ref(|r| r.srv.as_ref()) { return Ok(DnsResponse::from_rdata( req.query().original().to_owned(), RData::SRV(srv.clone()), @@ -114,11 +110,7 @@ impl Middleware for DnsZoneMiddle } } RecordType::HTTPS => { - if let Some(https_rule) = ctx - .domain_rule - .as_ref() - .and_then(|r| r.get_ref(|r| r.https.as_ref())) - { + if let Some(https_rule) = ctx.domain_rule.get_ref(|r| r.https.as_ref()) { match https_rule { HttpsRecordRule::Ignore => (), HttpsRecordRule::SOA => { diff --git a/src/dns_rule.rs b/src/dns_rule.rs index 1d167cea..53f2e502 100644 --- a/src/dns_rule.rs +++ b/src/dns_rule.rs @@ -133,16 +133,44 @@ impl DomainRuleTreeNode { pub fn zone(&self) -> Option<&Arc> { self.zone.as_ref() } +} + +pub trait DomainRuleGetter { + fn get(&self, f: impl Fn(&DomainRuleTreeNode) -> Option) -> Option; + + fn get_ref(&self, f: impl Fn(&DomainRuleTreeNode) -> Option<&T>) -> Option<&T>; +} - pub fn get(&self, f: impl Fn(&Self) -> Option) -> Option { +impl DomainRuleGetter for DomainRuleTreeNode { + fn get(&self, f: impl Fn(&DomainRuleTreeNode) -> Option) -> Option { f(self).or_else(|| self.zone().and_then(|z| f(z))) } - pub fn get_ref(&self, f: impl Fn(&Self) -> Option<&T>) -> Option<&T> { + fn get_ref(&self, f: impl Fn(&DomainRuleTreeNode) -> Option<&T>) -> Option<&T> { f(self).or_else(|| self.zone().and_then(|z| f(z))) } } +impl> DomainRuleGetter for N { + fn get(&self, f: impl Fn(&DomainRuleTreeNode) -> Option) -> Option { + self.as_ref().get(f) + } + + fn get_ref(&self, f: impl Fn(&DomainRuleTreeNode) -> Option<&T>) -> Option<&T> { + self.as_ref().get_ref(f) + } +} + +impl DomainRuleGetter for Option> { + fn get(&self, f: impl Fn(&DomainRuleTreeNode) -> Option) -> Option { + self.as_deref().and_then(f) + } + + fn get_ref(&self, f: impl Fn(&DomainRuleTreeNode) -> Option<&T>) -> Option<&T> { + self.as_deref().and_then(f) + } +} + impl Deref for DomainRuleTreeNode { type Target = DomainRule;