diff --git a/src/dns.rs b/src/dns.rs index 1d2b1754..2334ab23 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -1,58 +1,86 @@ mod local; +mod default; +mod cache; use crate::types::Result; use std::net::IpAddr; -use crate::dns::local::LocalOverrideTable; -use crate::types; + +/// A single IP entry in the DNS cache +struct IpEntry { + /// Actual IP address (either ipv4 or ipv6) + ip: IpAddr, + /// Epoch when ip address expires + expired: u64, + /// Epoch when this entry was last used + last_used: u64, +} /// A DNS entry is a simple domain name to IP address mapping struct DNSEntry { + // domain name domain_name: String, - ip: IpAddr, + // List of IPv4 addresses for this domain name + ipv4: Vec, + // List of IPv6 addresses for this domain name + ipv6: Vec, + // Round robin pointer in case we need to round robin + rr_ipv4_ptr: usize, + // Round robin pointer in case we need to round robin + rr_ipv6_ptr: usize, } -/// The main DNS resolver struct -struct DnsResolver { - /// Local override table enabled - local_override_enabled: bool, - /// Local override table - local_override: LocalOverrideTable, - /// DNS over HTTPS enabled - doh_enabled: bool, - /// DNS over TLS enabled - dot_enabled: bool, - /// TTL override enabled - ttl_override_enabled: bool, - /// TTL override in seconds - ttl_override: u32, - /// List of DNS resolvers that are used for resolving domains - resolvers: Vec, +impl Default for DNSEntry { + fn default() -> DNSEntry { + DNSEntry { + domain_name: String::new(), + ipv4: Vec::new(), + ipv6: Vec::new(), + rr_ipv4_ptr: 0, + rr_ipv6_ptr: 0, + } + } } -impl DnsResolver { - pub fn new() -> Self { - DnsResolver { - local_override_enabled: false, - local_override: LocalOverrideTable::default(), - doh_enabled: false, - dot_enabled: false, - ttl_override_enabled: false, - ttl_override: 0, - resolvers: Vec::new(), - } +impl DNSEntry { + /// Returns true when this entry has a ipv4 address + fn has_ipv4(&self) -> bool { + !self.ipv4.is_empty() } - pub fn resolve(&mut self, domain: &str) -> Result { - if self.local_override_enabled { - if let Some(ip) = self.local_override.resolve(domain) { - return Ok(ip); - } - } + /// Returns true when this entry has a ipv6 address + fn has_ipv6(&self) -> bool { + !self.ipv6.is_empty() + } + + /// Retrieves the next ipv4 address in the round robin list + fn next_ipv4(&mut self) -> IpAddr { + let entry = self.ipv4.get(self.rr_ipv4_ptr).expect("Invalid round robin pointer"); + self.rr_ipv4_ptr = (self.rr_ipv4_ptr + 1) % self.ipv4.len(); - Err(types::Error::Dns("Could not resolve domain".to_string())) + entry.ip } - pub fn flush_dns_cache(&self) -> Result<()> { - Ok(()) + /// Retrieves the next ipv6 address in the round robin list + fn next_ipv6(&mut self) -> IpAddr { + let entry = self.ipv6.get(self.rr_ipv6_ptr).expect("Invalid round robin pointer"); + self.rr_ipv6_ptr = (self.rr_ipv6_ptr + 1) % self.ipv6.len(); + + entry.ip } +} + +/// Type of DNS resolution +enum ResolveType { + /// Only resolve IPV4 addresses (A) + Ipv4, + /// Only resolve IPV6 addresses (AAAA) + Ipv6, + /// Resolve both IPV4 and IPV6 addresses + Both, +} + +trait DnsResolver { + fn resolve(&mut self, domain: &str, resolve_type: ResolveType) -> Result; + fn flush_cache(&self); + fn flush_entry(&self, domain_name: &str); } \ No newline at end of file diff --git a/src/dns/cache.rs b/src/dns/cache.rs new file mode 100644 index 00000000..34bd4967 --- /dev/null +++ b/src/dns/cache.rs @@ -0,0 +1,69 @@ +use std::collections::HashMap; +use crate::dns::DNSEntry; + +struct Cache { + values: HashMap, + max_entries: usize, + lru: Vec, +} + +impl Cache { + fn new(max_entries: usize) -> Cache { + Cache { + values: HashMap::with_capacity(max_entries), + max_entries, + lru: Vec::with_capacity(max_entries), + } + } + + fn get(&mut self, domain: &str) -> Option<&DNSEntry> { + self.lru.retain(|x| x != domain); + self.lru.push(domain.to_string()); + + self.values.get(domain) + } + + fn insert(&mut self, domain: &str, entry: DNSEntry) { + self.lru.retain(|x| x != domain); + self.lru.push(domain.to_string()); + + self.values.insert(domain.to_string(), entry); + + if self.values.len() > self.max_entries { + let key = self.lru.remove(0); + self.values.remove(&key); + } + } +} + +#[cfg(test)] +mod test { + use crate::dns::DNSEntry; + use super::*; + + #[test] + fn test_cache() { + let mut cache = Cache::new(2); + + cache.insert("example.com", DNSEntry::default()); + cache.insert("example.org", DNSEntry::default()); + + assert_eq!(cache.values.len(), 2); + assert_eq!(cache.lru.len(), 2); + assert_eq!(cache.lru[0], "example.com"); + assert_eq!(cache.lru[1], "example.org"); + + cache.get("example.org"); + assert_eq!(cache.values.len(), 2); + assert_eq!(cache.lru.len(), 2); + assert_eq!(cache.lru[0], "example.org"); + assert_eq!(cache.lru[1], "example.com"); + + cache.insert("example.net", DNSEntry::default()); + + assert_eq!(cache.values.len(), 2); + assert_eq!(cache.lru.len(), 2); + assert_eq!(cache.lru[0], "example.net"); + assert_eq!(cache.lru[1], "example.org"); + } +} \ No newline at end of file diff --git a/src/dns/default.rs b/src/dns/default.rs new file mode 100644 index 00000000..739c5bf1 --- /dev/null +++ b/src/dns/default.rs @@ -0,0 +1,134 @@ +use std::collections::HashMap; +use std::net::IpAddr; +use wildmatch::WildMatch; +use crate::dns::DNSEntry; + +/// Local override table that can be used instead of using /etc/hosts or similar 3rd party dns system. +pub struct LocalOverrideTable { + /// Entries in the local override table. First iteration of the resolver: a simple list that + /// will be queried O(n). Later iterations will use a more efficient data structure like a tree + /// per domain part (e.g. com, org, net, etc.) + entries: Vec, + round_robin_ptrs: HashMap, +} + +impl Default for LocalOverrideTable { + fn default() -> LocalOverrideTable { + LocalOverrideTable { + entries: Vec::new(), + round_robin_ptrs: HashMap::new(), + } + } +} + +impl LocalOverrideTable { + + pub fn resolve(&mut self, domain: &str) -> Option { + let mut ips = Vec::new(); + let mut wildcard_ips = Vec::new(); + + // For now it's ok to iterate the entries list, but later we should use a more efficient + // data structure. + for entry in &self.entries { + // We separate wildcard matches and specific matches. + if WildMatch::new(&entry.domain_name).matches(domain) { + wildcard_ips.push(entry.ip); + } + if entry.domain_name == domain { + ips.push(entry.ip); + } + } + + // if we haven't found a specific domain match, we use the wildcard domain, otherwise + // the specific matches take precedence. + if ips.len() == 0 { + ips = wildcard_ips + } + + // If there is only one ip for a domain, return it + if ips.len() == 1 { + return Some(*ips.get(0).unwrap()); + } + + // If there are multiple ips for a domain, use round robin to fetch the next one + if ips.len() > 1 { + let mut rr_ptr = *self.round_robin_ptrs.get(domain).unwrap_or(&0); + if rr_ptr >= ips.len() { + rr_ptr = 0; + } + + let ip = ips.get(rr_ptr).expect("Invalid round robin pointer"); + self.round_robin_ptrs.insert(domain.to_string(), rr_ptr + 1 % ips.len()); + + return Some(*ip); + } + + None + } +} + + +#[cfg(test)] +mod test { + use core::str::FromStr; + use crate::dns::DnsResolver; + use super::*; + + #[test] + fn test_local_override() { + let mut resolver = DnsResolver::new(); + resolver.local_override_enabled = true; + resolver.local_override.entries.push(DNSEntry { + domain_name: "example.com".to_string(), + ip: IpAddr::from_str("1.2.3.4").unwrap(), + }); + + resolver.local_override.entries.push(DNSEntry { + domain_name: "foo.example.com".to_string(), + ip: IpAddr::from_str("2.3.4.5").unwrap(), + }); + + resolver.local_override.entries.push(DNSEntry { + domain_name: "*.wilcard.com".to_string(), + ip: IpAddr::from_str("6.6.6.6").unwrap(), + }); + + resolver.local_override.entries.push(DNSEntry { + domain_name: "specific.wilcard.com".to_string(), + ip: IpAddr::from_str("8.8.8.8").unwrap(), + }); + + resolver.local_override.entries.push(DNSEntry { + domain_name: "ipv6.com".to_string(), + ip: IpAddr::from_str("2002::1").unwrap(), + }); + resolver.local_override.entries.push(DNSEntry { + domain_name: "ipv6.com".to_string(), + ip: IpAddr::from_str("2002::2").unwrap(), + }); + resolver.local_override.entries.push(DNSEntry { + domain_name: "ipv6.com".to_string(), + ip: IpAddr::from_str("200.200.200.200").unwrap(), + }); + + // Simple resolve + assert_eq!(IpAddr::from_str("1.2.3.4").unwrap(), resolver.resolve("example.com").unwrap()); + assert!(resolver.resolve("xample.com").is_err()); + assert!(resolver.resolve("com").is_err()); + assert!(resolver.resolve("example").is_err()); + + // Wildcard + assert_eq!(IpAddr::from_str("8.8.8.8").unwrap(), resolver.resolve("specific.wilcard.com").unwrap()); + assert_eq!(IpAddr::from_str("6.6.6.6").unwrap(), resolver.resolve("something.wilcard.com").unwrap()); + assert_eq!(IpAddr::from_str("6.6.6.6").unwrap(), resolver.resolve("foobar.wilcard.com").unwrap()); + assert!(resolver.resolve("wilcard.com").is_err()); + + // round robin + assert_eq!(IpAddr::from_str("2002::1").unwrap(), resolver.resolve("ipv6.com").unwrap()); + assert_eq!(IpAddr::from_str("2002::2").unwrap(), resolver.resolve("ipv6.com").unwrap()); + assert_eq!(IpAddr::from_str("200.200.200.200").unwrap(), resolver.resolve("ipv6.com").unwrap()); + assert_eq!(IpAddr::from_str("2002::1").unwrap(), resolver.resolve("ipv6.com").unwrap()); + assert_eq!(IpAddr::from_str("2002::2").unwrap(), resolver.resolve("ipv6.com").unwrap()); + assert_eq!(IpAddr::from_str("200.200.200.200").unwrap(), resolver.resolve("ipv6.com").unwrap()); + } +} \ No newline at end of file diff --git a/src/dns/local.rs b/src/dns/local.rs index 739c5bf1..671479cb 100644 --- a/src/dns/local.rs +++ b/src/dns/local.rs @@ -1,69 +1,103 @@ -use std::collections::HashMap; use std::net::IpAddr; use wildmatch::WildMatch; -use crate::dns::DNSEntry; +use crate::dns::{DNSEntry, ResolveType}; +use crate::types; /// Local override table that can be used instead of using /etc/hosts or similar 3rd party dns system. +#[derive(Default)] pub struct LocalOverrideTable { /// Entries in the local override table. First iteration of the resolver: a simple list that /// will be queried O(n). Later iterations will use a more efficient data structure like a tree /// per domain part (e.g. com, org, net, etc.) entries: Vec, - round_robin_ptrs: HashMap, } -impl Default for LocalOverrideTable { - fn default() -> LocalOverrideTable { - LocalOverrideTable { +impl LocalOverrideTable { + + pub fn new() -> LocalOverrideTable { + let mut table = LocalOverrideTable { entries: Vec::new(), - round_robin_ptrs: HashMap::new(), - } + }; + + table.reload_table_entries(); + table } -} -impl LocalOverrideTable { + /// Regenerates the new entries table + pub fn reload_table_entries(&mut self) -> Result<(), types::Error> { + // @todo: this should reload all table entries from the configuration into the self.entries list + } - pub fn resolve(&mut self, domain: &str) -> Option { - let mut ips = Vec::new(); - let mut wildcard_ips = Vec::new(); + pub fn resolve(&mut self, domain: &str, resolve_type: ResolveType) -> Result { + let mut matched_entry; + let mut matched_wildcard_entries = Vec::new(); // For now it's ok to iterate the entries list, but later we should use a more efficient - // data structure. + // data structure in the future for entry in &self.entries { - // We separate wildcard matches and specific matches. + // We separate wildcard matches and specific matches (e.g. *.example.com vs example.com) if WildMatch::new(&entry.domain_name).matches(domain) { - wildcard_ips.push(entry.ip); + matched_wildcard_entries.push(entry); } if entry.domain_name == domain { - ips.push(entry.ip); + // First match is ok to return + matched_entry = Some(entry); + break; } } // if we haven't found a specific domain match, we use the wildcard domain, otherwise // the specific matches take precedence. - if ips.len() == 0 { - ips = wildcard_ips + if matched_entry.is_none() { + // Entry most specific match + matched_entry = Some(self.find_most_specific_match(matched_wildcard_entries)) } + let mut matched_entry = matched_entry.expect("No matched entries found"); - // If there is only one ip for a domain, return it - if ips.len() == 1 { - return Some(*ips.get(0).unwrap()); - } + // Returned type is the actual returned type (in case resolve_type was BOTH) + let mut returned_type = resolve_type; - // If there are multiple ips for a domain, use round robin to fetch the next one - if ips.len() > 1 { - let mut rr_ptr = *self.round_robin_ptrs.get(domain).unwrap_or(&0); - if rr_ptr >= ips.len() { - rr_ptr = 0; - } + // Update round robin pointers (if needed) + if returned_type == ResolveType::Ipv4 { + let rr_ptr = matched_entry.rr_ipv4_ptr; + matched_entry.rr_ipv4_ptr = (rr_ptr + 1) % matched_entry.ipv4.len(); + } - let ip = ips.get(rr_ptr).expect("Invalid round robin pointer"); - self.round_robin_ptrs.insert(domain.to_string(), rr_ptr + 1 % ips.len()); + if returned_type == ResolveType::Ipv6 { + let rr_ptr = matched_entry.rr_ipv6_ptr; + matched_entry.rr_ipv6_ptr = (rr_ptr + 1) % matched_entry.ipv6.len(); + } - return Some(*ip); + // If there is only one ip for a domain, return it + if matched_entries.len() == 1 { + match resolve_type { + ResolveType::Ipv4 if matched_entries[0].has_ipv4() => { + if matched_entries[0].has_ipv4() { + Ok(matched_entries[0].next_ipv4()) + } else { + Err(types::Error::NoIpv4Found) + } + + } + ResolveType::Ipv6 => { + if matched_entries[0].has_ipv6() { + Ok(matched_entries[0].next_ipv6()) + } else { + Err(types::Error::NoIpv6Found) + } + } + ResolveType::Both => { + // IPv6 takes precedence over IPv4 + if matched_entries[0].has_ipv6() { + Ok(matched_entries[0].next_ipv6()) + } else { + Ok(matched_entries[0].next_ipv4()) + } + } + } } - None + Err(types::Error::NoIpAddressFound) } } @@ -76,6 +110,17 @@ mod test { #[test] fn test_local_override() { + let mut table = LocalOverrideTable::new(); + table.entries.push(DNSEntry { + domain_name: "example.com".to_string(), + ipv4: vec![IpAddr::from_str("1.2.3.4").unwrap()], + ipv6: Vec::new(), + rr_ipv4_ptr: 0, + rr_ipv6_ptr: 0, + }); + + + let mut resolver = DnsResolver::new(); resolver.local_override_enabled = true; resolver.local_override.entries.push(DNSEntry { diff --git a/src/types.rs b/src/types.rs index 802d1df8..8971e4ec 100644 --- a/src/types.rs +++ b/src/types.rs @@ -47,8 +47,16 @@ pub enum Error { #[error("query error: {0}")] Query(String), - #[error("dns error: {0}")] + #[error("dns: error: {0}")] Dns(String), + + #[error("dns: no ipv6 address found: {0}")] + NoIpv6Found, + + #[error("dns: no ipv4 address found: {0}")] + NoIpv4Found, + + NoIpAddressFound, } /// Result that can be returned which holds either T or an Error