Skip to content

Commit

Permalink
🐛 Fix domain rules minimum TTL not work
Browse files Browse the repository at this point in the history
  • Loading branch information
mokeyish committed Oct 10, 2024
1 parent c9ad11d commit efdcf94
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 32 deletions.
22 changes: 22 additions & 0 deletions src/config/parser/domain_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ impl NomParser for DomainRule {
map(options::parse_no_value(tag_no_case("no-cache")), |v| {
rule.no_cache = Some(v);
}),
map(
options::parse_value(tag_no_case("rr-ttl-min"), NomParser::parse),
|v| {
rule.rr_ttl_min = Some(v);
},
),
map(
options::parse_value(tag_no_case("rr-ttl-max"), NomParser::parse),
|v| {
rule.rr_ttl_max = Some(v);
},
),
map(options::unkown_options, |(n, v)| {
log::warn!("domain rule: unkown options {}={:?}", n, v)
}),
Expand Down Expand Up @@ -146,5 +158,15 @@ mod tests {
}
))
);
assert_eq!(
DomainRule::parse("-rr-ttl-min 60"),
Ok((
"",
DomainRule {
rr_ttl_min: Some(60),
..Default::default()
}
))
);
}
}
7 changes: 3 additions & 4 deletions src/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use crate::dns_rule::DomainRuleTreeNode;
use crate::config::ServerOpts;
use crate::dns_conf::RuntimeConfig;

pub use crate::dns_rule::DomainRuleGetter;

pub use crate::libdns::proto::{
error::ProtoErrorKind,
op,
Expand Down Expand Up @@ -40,10 +42,7 @@ impl DnsContext {
pub fn new(name: &Name, cfg: Arc<RuntimeConfig>, server_opts: ServerOpts) -> Self {
let domain_rule = cfg.find_domain_rule(name);

let no_cache = domain_rule
.as_ref()
.and_then(|r| r.get(|n| n.no_cache))
.unwrap_or_default();
let no_cache = domain_rule.get(|n| n.no_cache).unwrap_or_default();

DnsContext {
cfg,
Expand Down
23 changes: 12 additions & 11 deletions src/dns_conf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;

pub use crate::config::*;
use crate::dns::DomainRuleGetter;
use crate::log;
use crate::{
dns_rule::{DomainRuleMap, DomainRuleTreeNode},
Expand Down Expand Up @@ -926,7 +927,7 @@ fn resolve_filepath<P: AsRef<Path>>(filepath: P, base_file: Option<&PathBuf>) ->

#[cfg(test)]
mod tests {
use crate::libdns::Protocol;
use crate::{dns::DomainRuleGetter, libdns::Protocol};
use byte_unit::Byte;

use crate::config::{HttpsListenerConfig, ListenerAddress, ServerOpts, SslConfig};
Expand Down Expand Up @@ -1261,13 +1262,13 @@ mod tests {

assert_eq!(
cfg.find_domain_rule(&"cloudflare.com".parse().unwrap())
.and_then(|r| r.get(|n| n.address)),
.get(|n| n.address),
Some(DomainAddress::SOA)
);

assert_eq!(
cfg.find_domain_rule(&"google.com".parse().unwrap())
.and_then(|r| r.get(|n| n.address)),
.get(|n| n.address),
Some(DomainAddress::IGN)
);
}
Expand All @@ -1279,13 +1280,13 @@ mod tests {
.build();
assert_eq!(
cfg.find_domain_rule(&"example.com".parse().unwrap())
.and_then(|r| r.get(|n| n.address)),
.get(|n| n.address),
Some(DomainAddress::SOA)
);

assert_eq!(
cfg.find_domain_rule(&"aa.example.com".parse().unwrap())
.and_then(|r| r.get(|n| n.address)),
.get(|n| n.address),
None
);
}
Expand All @@ -1295,13 +1296,13 @@ mod tests {
let cfg = RuntimeConfig::builder().with("address /*/#").build();
assert_eq!(
cfg.find_domain_rule(&"localhost".parse().unwrap())
.and_then(|r| r.get(|n| n.address)),
.get(|n| n.address),
Some(DomainAddress::SOA)
);

assert_eq!(
cfg.find_domain_rule(&"aa.example.com".parse().unwrap())
.and_then(|r| r.get(|n| n.address)),
.get(|n| n.address),
None
);
}
Expand All @@ -1311,13 +1312,13 @@ mod tests {
let cfg = RuntimeConfig::builder().with("address /+/#").build();
assert_eq!(
cfg.find_domain_rule(&"localhost".parse().unwrap())
.and_then(|r| r.get(|n| n.address)),
.get(|n| n.address),
Some(DomainAddress::SOA)
);

assert_eq!(
cfg.find_domain_rule(&"aa.example.com".parse().unwrap())
.and_then(|r| r.get(|n| n.address)),
.get(|n| n.address),
Some(DomainAddress::SOA)
);
}
Expand All @@ -1327,13 +1328,13 @@ mod tests {
let cfg = RuntimeConfig::builder().with("address /./#").build();
assert_eq!(
cfg.find_domain_rule(&"localhost".parse().unwrap())
.and_then(|r| r.get(|n| n.address)),
.get(|n| n.address),
Some(DomainAddress::SOA)
);

assert_eq!(
cfg.find_domain_rule(&"aa.example.com".parse().unwrap())
.and_then(|r| r.get(|n| n.address)),
.get(|n| n.address),
Some(DomainAddress::SOA)
);
}
Expand Down
15 changes: 13 additions & 2 deletions src/dns_mw_addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,19 @@ impl Middleware<DnsContext, DnsRequest, DnsResponse, DnsError> for AddressMiddle
}
}

let rr_ttl_min = ctx.cfg().rr_ttl_min().map(|i| i as u32);
let rr_ttl_max = ctx.cfg().rr_ttl_max().map(|i| i as u32);
let rr_ttl_min = ctx
.domain_rule
.get_ref(|r| r.rr_ttl_min.as_ref())
.cloned()
.or_else(|| ctx.cfg().rr_ttl_min())
.map(|i| i as u32);

let rr_ttl_max = ctx
.domain_rule
.get_ref(|r| r.rr_ttl_max.as_ref())
.cloned()
.or_else(|| ctx.cfg().rr_ttl_max())
.map(|i| i as u32);
let rr_ttl_reply_max = ctx.cfg().rr_ttl_reply_max().map(|i| i as u32);

if rr_ttl_min.is_some() || rr_ttl_max.is_some() || rr_ttl_reply_max.is_some() {
Expand Down
3 changes: 1 addition & 2 deletions src/dns_mw_dualstack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ impl Middleware<DnsContext, DnsRequest, DnsResponse, DnsError>

let speed_check_mode = ctx
.domain_rule
.as_ref()
.and_then(|r| r.speed_check_mode.as_ref())
.get_ref(|r| r.speed_check_mode.as_ref())
.cloned()
.unwrap_or_default();

Expand Down
2 changes: 1 addition & 1 deletion src/dns_mw_ns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ impl Middleware<DnsContext, DnsRequest, DnsResponse, DnsError> for NameServerMid
_ => None,
})
})
.or_else(|| ctx.domain_rule.as_ref().and_then(|r| r.subnet)),
.or_else(|| ctx.domain_rule.get_ref(|r| r.subnet.as_ref()).cloned()),
};

// skip nameserver rule
Expand Down
12 changes: 2 additions & 10 deletions src/dns_mw_zone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,23 +102,15 @@ impl Middleware<DnsContext, DnsRequest, DnsResponse, DnsError> for DnsZoneMiddle
}
}
RecordType::SRV => {
if let Some(srv) = ctx
.domain_rule
.as_ref()
.and_then(|r| r.get_ref(|r| r.srv.as_ref()))
{
if let Some(srv) = ctx.domain_rule.get_ref(|r| r.srv.as_ref()) {
return Ok(DnsResponse::from_rdata(
req.query().original().to_owned(),
RData::SRV(srv.clone()),
));
}
}
RecordType::HTTPS => {
if let Some(https_rule) = ctx
.domain_rule
.as_ref()
.and_then(|r| r.get_ref(|r| r.https.as_ref()))
{
if let Some(https_rule) = ctx.domain_rule.get_ref(|r| r.https.as_ref()) {
match https_rule {
HttpsRecordRule::Ignore => (),
HttpsRecordRule::SOA => {
Expand Down
32 changes: 30 additions & 2 deletions src/dns_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,44 @@ impl DomainRuleTreeNode {
pub fn zone(&self) -> Option<&Arc<DomainRuleTreeNode>> {
self.zone.as_ref()
}
}

pub trait DomainRuleGetter {
fn get<T>(&self, f: impl Fn(&DomainRuleTreeNode) -> Option<T>) -> Option<T>;

fn get_ref<T>(&self, f: impl Fn(&DomainRuleTreeNode) -> Option<&T>) -> Option<&T>;
}

pub fn get<T>(&self, f: impl Fn(&Self) -> Option<T>) -> Option<T> {
impl DomainRuleGetter for DomainRuleTreeNode {
fn get<T>(&self, f: impl Fn(&DomainRuleTreeNode) -> Option<T>) -> Option<T> {
f(self).or_else(|| self.zone().and_then(|z| f(z)))
}

pub fn get_ref<T>(&self, f: impl Fn(&Self) -> Option<&T>) -> Option<&T> {
fn get_ref<T>(&self, f: impl Fn(&DomainRuleTreeNode) -> Option<&T>) -> Option<&T> {
f(self).or_else(|| self.zone().and_then(|z| f(z)))
}
}

impl<N: AsRef<DomainRuleTreeNode>> DomainRuleGetter for N {
fn get<T>(&self, f: impl Fn(&DomainRuleTreeNode) -> Option<T>) -> Option<T> {
self.as_ref().get(f)
}

fn get_ref<T>(&self, f: impl Fn(&DomainRuleTreeNode) -> Option<&T>) -> Option<&T> {
self.as_ref().get_ref(f)
}
}

impl DomainRuleGetter for Option<Arc<DomainRuleTreeNode>> {
fn get<T>(&self, f: impl Fn(&DomainRuleTreeNode) -> Option<T>) -> Option<T> {
self.as_deref().and_then(f)
}

fn get_ref<T>(&self, f: impl Fn(&DomainRuleTreeNode) -> Option<&T>) -> Option<&T> {
self.as_deref().and_then(f)
}
}

impl Deref for DomainRuleTreeNode {
type Target = DomainRule;

Expand Down

0 comments on commit efdcf94

Please sign in to comment.