Skip to content

Commit

Permalink
Allow override direct & reverse domain matching
Browse files Browse the repository at this point in the history
  • Loading branch information
sorz committed Apr 13, 2023
1 parent 2504f85 commit 962647f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 54 deletions.
5 changes: 3 additions & 2 deletions conf/policy.rules
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ listen-port 8002 require cap1 or cap2
dst domain netflix.com require streaming
dst domain netflix.com require us

# *.cn will not use any proxy
# `direct` always override `require` actions
# *.cn will not use any proxy, expect *.edu.cn require proxies with "edu"
# more specific match override less specific one
dst domain cn direct
dst domain edu.cn require edu

# *.edu.au will match both rules, and require proxies with BOTH "us" AND "edu"
dst domain au require us
Expand Down
102 changes: 50 additions & 52 deletions src/policy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,35 @@ impl Default for Action {
}
}

impl Action {
fn add_require(&mut self, caps: CapSet) {
match self {
Self::Direct => false,
Self::Require(set) => set.insert(caps),
};
}

fn set_direct(&mut self) {
*self = Self::Direct
impl From<parser::RuleAction> for Action {
fn from(value: parser::RuleAction) -> Self {
match value {
parser::RuleAction::Direct => Self::Direct,
parser::RuleAction::Require(caps) => {
let mut set = HashSet::new();
set.insert(caps);
Self::Require(set)
}
}
}
}

impl Action {
fn len(&self) -> usize {
match self {
Self::Direct => 1,
Self::Require(set) => set.len(),
}
}

fn extend(&mut self, other: &Self) {
fn extend(&mut self, other: Self) {
match other {
Self::Direct => self.set_direct(),
Self::Direct => *self = Self::Direct,
Self::Require(new_caps) => {
if let Self::Require(caps) = self {
caps.extend(new_caps.iter().cloned())
caps.extend(new_caps.into_iter())
} else {
*self = Self::Require(new_caps)
}
}
}
Expand All @@ -67,10 +71,7 @@ impl<K: Eq + Hash> RuleSet<K> {
fn add(&mut self, key: K, action: parser::RuleAction) {
// TODO: warning duplicated rules
let value = self.0.entry(key).or_default();
match action {
parser::RuleAction::Require(caps) => value.add_require(caps),
parser::RuleAction::Direct => value.set_direct(),
}
value.extend(action.into())
}

fn get<'a>(&'a self, key: &'a K) -> impl Iterator<Item = &'a Action> {
Expand All @@ -80,14 +81,15 @@ impl<K: Eq + Hash> RuleSet<K> {

impl DstDomainRuleSet {
fn get_recursive<'a>(&'a self, name: &'a str) -> impl Iterator<Item = &'a Action> {
let mut skip = 0usize;
name.split_terminator('.')
.map(move |part| {
let key = &name[skip..];
skip += part.len() + 1;
key
})
.chain(["."])
let name = name.trim_end_matches('.'); // Add back later
let mut skip = name.len() + 1; // pretend ending with dot
let parts = name.rsplit('.').map(move |part| {
skip -= part.len() + 1; // +1 for the dot
&name[skip..]
});
["."] // add back the dot
.into_iter()
.chain(parts)
.filter_map(|key| self.0.get(key))
}
}
Expand Down Expand Up @@ -123,10 +125,7 @@ impl Policy {
fn add_rule(&mut self, rule: parser::Rule) {
let parser::Rule { filter, action } = rule;
match filter {
parser::RuleFilter::Default => match action {
parser::RuleAction::Require(caps) => self.default_action.add_require(caps),
parser::RuleAction::Direct => self.default_action.set_direct(),
},
parser::RuleFilter::Default => self.default_action.extend(action.into()),
parser::RuleFilter::ListenPort(port) => {
self.listen_port_ruleset.add(port, action);
}
Expand All @@ -149,12 +148,12 @@ impl Policy {
if let Some(port) = listen_port {
self.listen_port_ruleset
.get(&port)
.for_each(|a| action.extend(a))
.for_each(|a| action.extend(a.clone()))
}
if let Some(name) = dst_domain {
self.dst_domain_ruleset
.get_recursive(&name)
.for_each(|a| action.extend(a));
.for_each(|a| action.extend(a.clone()));
}
action
}
Expand Down Expand Up @@ -201,37 +200,36 @@ fn test_policy_get_domain_caps_requirements() {
.as_bytes(),
)
.unwrap();
assert_eq!(
3,
policy
.dst_domain_ruleset
.get_recursive("test.example.com")
.count()
);
assert_eq!(
3,
policy
.dst_domain_ruleset
.get_recursive("example.com")
.count()
);
assert_eq!(2, policy.dst_domain_ruleset.get_recursive("com").count());
assert_eq!(1, policy.dst_domain_ruleset.get_recursive("net").count());
let set = policy.dst_domain_ruleset;
assert_eq!(3, set.get_recursive("test.example.com").count());
assert_eq!(3, set.get_recursive("example.com").count());
assert_eq!(2, set.get_recursive("com").count());
assert_eq!(1, set.get_recursive("net").count());
}

#[test]
fn test_policy_action() {
let rules = "
default require def
listen-port 1 require a
listen-port 1 direct
listen-port 2 direct
dst domain test require c
dst domain d.test direct
";
let policy = Policy::load(rules.as_bytes()).unwrap();
let direct = policy.matches(Some(1), Some("test".into()));
let require1 = policy.matches(Some(2), Some("abcd".into()));
let require2 = policy.matches(None, Some("test".into()));
assert!(matches!(direct, Action::Direct));
// listen-port/direct override default/require
let direct1 = policy.matches(Some(2), Some("abcd".into()));
assert!(matches!(direct1, Action::Direct));
// d.test/direct override others
let direct2 = policy.matches(Some(1), Some("a.d.test".into()));
assert!(matches!(direct2, Action::Direct));
// just default/require
let require1 = policy.matches(Some(3), Some("abcd".into()));
assert!(matches!(require1, Action::Require(a) if a.len() == 1));
// default/require + dst-domain/require
let require2 = policy.matches(None, Some("test".into()));
assert!(matches!(require2, Action::Require(a) if a.len() == 2));
// default/require + dst-domain/require + listen-port/require
let require3 = policy.matches(Some(1), Some("test".into()));
assert!(matches!(require3, Action::Require(a) if a.len() == 3));
}

0 comments on commit 962647f

Please sign in to comment.