Skip to content

Commit

Permalink
tmp commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jaytaph committed Nov 15, 2023
1 parent 33b025c commit df8e9f9
Show file tree
Hide file tree
Showing 5 changed files with 357 additions and 73 deletions.
106 changes: 67 additions & 39 deletions src/dns.rs
Original file line number Diff line number Diff line change
@@ -1,58 +1,86 @@
mod local;
mod default;
mod cache;

use crate::types::Result;
use std::net::IpAddr;
use crate::dns::local::LocalOverrideTable;
use crate::types;

/// A single IP entry in the DNS cache
struct IpEntry {
/// Actual IP address (either ipv4 or ipv6)
ip: IpAddr,
/// Epoch when ip address expires
expired: u64,
/// Epoch when this entry was last used
last_used: u64,
}

/// A DNS entry is a simple domain name to IP address mapping
struct DNSEntry {
// domain name
domain_name: String,
ip: IpAddr,
// List of IPv4 addresses for this domain name
ipv4: Vec<IpEntry>,
// List of IPv6 addresses for this domain name
ipv6: Vec<IpEntry>,
// Round robin pointer in case we need to round robin
rr_ipv4_ptr: usize,
// Round robin pointer in case we need to round robin
rr_ipv6_ptr: usize,
}

/// The main DNS resolver struct
struct DnsResolver {
/// Local override table enabled
local_override_enabled: bool,
/// Local override table
local_override: LocalOverrideTable,
/// DNS over HTTPS enabled
doh_enabled: bool,
/// DNS over TLS enabled
dot_enabled: bool,
/// TTL override enabled
ttl_override_enabled: bool,
/// TTL override in seconds
ttl_override: u32,
/// List of DNS resolvers that are used for resolving domains
resolvers: Vec<IpAddr>,
impl Default for DNSEntry {
fn default() -> DNSEntry {
DNSEntry {
domain_name: String::new(),
ipv4: Vec::new(),
ipv6: Vec::new(),
rr_ipv4_ptr: 0,
rr_ipv6_ptr: 0,
}
}
}

impl DnsResolver {
pub fn new() -> Self {
DnsResolver {
local_override_enabled: false,
local_override: LocalOverrideTable::default(),
doh_enabled: false,
dot_enabled: false,
ttl_override_enabled: false,
ttl_override: 0,
resolvers: Vec::new(),
}
impl DNSEntry {
/// Returns true when this entry has a ipv4 address
fn has_ipv4(&self) -> bool {
!self.ipv4.is_empty()
}

pub fn resolve(&mut self, domain: &str) -> Result<IpAddr> {
if self.local_override_enabled {
if let Some(ip) = self.local_override.resolve(domain) {
return Ok(ip);
}
}
/// Returns true when this entry has a ipv6 address
fn has_ipv6(&self) -> bool {
!self.ipv6.is_empty()
}

/// Retrieves the next ipv4 address in the round robin list
fn next_ipv4(&mut self) -> IpAddr {
let entry = self.ipv4.get(self.rr_ipv4_ptr).expect("Invalid round robin pointer");
self.rr_ipv4_ptr = (self.rr_ipv4_ptr + 1) % self.ipv4.len();

Err(types::Error::Dns("Could not resolve domain".to_string()))
entry.ip
}

pub fn flush_dns_cache(&self) -> Result<()> {
Ok(())
/// Retrieves the next ipv6 address in the round robin list
fn next_ipv6(&mut self) -> IpAddr {
let entry = self.ipv6.get(self.rr_ipv6_ptr).expect("Invalid round robin pointer");
self.rr_ipv6_ptr = (self.rr_ipv6_ptr + 1) % self.ipv6.len();

entry.ip
}
}

/// Type of DNS resolution
enum ResolveType {
/// Only resolve IPV4 addresses (A)
Ipv4,
/// Only resolve IPV6 addresses (AAAA)
Ipv6,
/// Resolve both IPV4 and IPV6 addresses
Both,
}

trait DnsResolver {
fn resolve(&mut self, domain: &str, resolve_type: ResolveType) -> Result<IpAddr>;
fn flush_cache(&self);
fn flush_entry(&self, domain_name: &str);
}
69 changes: 69 additions & 0 deletions src/dns/cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::collections::HashMap;
use crate::dns::DNSEntry;

struct Cache {
values: HashMap<String, DNSEntry>,
max_entries: usize,
lru: Vec<String>,
}

impl Cache {
fn new(max_entries: usize) -> Cache {
Cache {
values: HashMap::with_capacity(max_entries),
max_entries,
lru: Vec::with_capacity(max_entries),
}
}

fn get(&mut self, domain: &str) -> Option<&DNSEntry> {
self.lru.retain(|x| x != domain);
self.lru.push(domain.to_string());

self.values.get(domain)
}

fn insert(&mut self, domain: &str, entry: DNSEntry) {
self.lru.retain(|x| x != domain);
self.lru.push(domain.to_string());

self.values.insert(domain.to_string(), entry);

if self.values.len() > self.max_entries {
let key = self.lru.remove(0);
self.values.remove(&key);
}
}
}

#[cfg(test)]
mod test {
use crate::dns::DNSEntry;
use super::*;

#[test]
fn test_cache() {
let mut cache = Cache::new(2);

cache.insert("example.com", DNSEntry::default());
cache.insert("example.org", DNSEntry::default());

assert_eq!(cache.values.len(), 2);
assert_eq!(cache.lru.len(), 2);
assert_eq!(cache.lru[0], "example.com");
assert_eq!(cache.lru[1], "example.org");

cache.get("example.org");
assert_eq!(cache.values.len(), 2);
assert_eq!(cache.lru.len(), 2);
assert_eq!(cache.lru[0], "example.org");
assert_eq!(cache.lru[1], "example.com");

cache.insert("example.net", DNSEntry::default());

assert_eq!(cache.values.len(), 2);
assert_eq!(cache.lru.len(), 2);
assert_eq!(cache.lru[0], "example.net");
assert_eq!(cache.lru[1], "example.org");
}
}
134 changes: 134 additions & 0 deletions src/dns/default.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use std::collections::HashMap;
use std::net::IpAddr;
use wildmatch::WildMatch;
use crate::dns::DNSEntry;

/// Local override table that can be used instead of using /etc/hosts or similar 3rd party dns system.
pub struct LocalOverrideTable {
/// Entries in the local override table. First iteration of the resolver: a simple list that
/// will be queried O(n). Later iterations will use a more efficient data structure like a tree
/// per domain part (e.g. com, org, net, etc.)
entries: Vec<DNSEntry>,
round_robin_ptrs: HashMap<String, usize>,
}

impl Default for LocalOverrideTable {
fn default() -> LocalOverrideTable {
LocalOverrideTable {
entries: Vec::new(),
round_robin_ptrs: HashMap::new(),
}
}
}

impl LocalOverrideTable {

pub fn resolve(&mut self, domain: &str) -> Option<IpAddr> {
let mut ips = Vec::new();
let mut wildcard_ips = Vec::new();

// For now it's ok to iterate the entries list, but later we should use a more efficient
// data structure.
for entry in &self.entries {
// We separate wildcard matches and specific matches.
if WildMatch::new(&entry.domain_name).matches(domain) {
wildcard_ips.push(entry.ip);
}
if entry.domain_name == domain {
ips.push(entry.ip);
}
}

// if we haven't found a specific domain match, we use the wildcard domain, otherwise
// the specific matches take precedence.
if ips.len() == 0 {
ips = wildcard_ips
}

// If there is only one ip for a domain, return it
if ips.len() == 1 {
return Some(*ips.get(0).unwrap());
}

// If there are multiple ips for a domain, use round robin to fetch the next one
if ips.len() > 1 {
let mut rr_ptr = *self.round_robin_ptrs.get(domain).unwrap_or(&0);
if rr_ptr >= ips.len() {
rr_ptr = 0;
}

let ip = ips.get(rr_ptr).expect("Invalid round robin pointer");
self.round_robin_ptrs.insert(domain.to_string(), rr_ptr + 1 % ips.len());

return Some(*ip);
}

None
}
}


#[cfg(test)]
mod test {
use core::str::FromStr;
use crate::dns::DnsResolver;
use super::*;

#[test]
fn test_local_override() {
let mut resolver = DnsResolver::new();
resolver.local_override_enabled = true;
resolver.local_override.entries.push(DNSEntry {
domain_name: "example.com".to_string(),
ip: IpAddr::from_str("1.2.3.4").unwrap(),
});

resolver.local_override.entries.push(DNSEntry {
domain_name: "foo.example.com".to_string(),
ip: IpAddr::from_str("2.3.4.5").unwrap(),
});

resolver.local_override.entries.push(DNSEntry {
domain_name: "*.wilcard.com".to_string(),
ip: IpAddr::from_str("6.6.6.6").unwrap(),
});

resolver.local_override.entries.push(DNSEntry {
domain_name: "specific.wilcard.com".to_string(),
ip: IpAddr::from_str("8.8.8.8").unwrap(),
});

resolver.local_override.entries.push(DNSEntry {
domain_name: "ipv6.com".to_string(),
ip: IpAddr::from_str("2002::1").unwrap(),
});
resolver.local_override.entries.push(DNSEntry {
domain_name: "ipv6.com".to_string(),
ip: IpAddr::from_str("2002::2").unwrap(),
});
resolver.local_override.entries.push(DNSEntry {
domain_name: "ipv6.com".to_string(),
ip: IpAddr::from_str("200.200.200.200").unwrap(),
});

// Simple resolve
assert_eq!(IpAddr::from_str("1.2.3.4").unwrap(), resolver.resolve("example.com").unwrap());
assert!(resolver.resolve("xample.com").is_err());
assert!(resolver.resolve("com").is_err());
assert!(resolver.resolve("example").is_err());

// Wildcard
assert_eq!(IpAddr::from_str("8.8.8.8").unwrap(), resolver.resolve("specific.wilcard.com").unwrap());
assert_eq!(IpAddr::from_str("6.6.6.6").unwrap(), resolver.resolve("something.wilcard.com").unwrap());
assert_eq!(IpAddr::from_str("6.6.6.6").unwrap(), resolver.resolve("foobar.wilcard.com").unwrap());
assert!(resolver.resolve("wilcard.com").is_err());

// round robin
assert_eq!(IpAddr::from_str("2002::1").unwrap(), resolver.resolve("ipv6.com").unwrap());
assert_eq!(IpAddr::from_str("2002::2").unwrap(), resolver.resolve("ipv6.com").unwrap());
assert_eq!(IpAddr::from_str("200.200.200.200").unwrap(), resolver.resolve("ipv6.com").unwrap());
assert_eq!(IpAddr::from_str("2002::1").unwrap(), resolver.resolve("ipv6.com").unwrap());
assert_eq!(IpAddr::from_str("2002::2").unwrap(), resolver.resolve("ipv6.com").unwrap());
assert_eq!(IpAddr::from_str("200.200.200.200").unwrap(), resolver.resolve("ipv6.com").unwrap());
}
}
Loading

0 comments on commit df8e9f9

Please sign in to comment.