diff --git a/casbin/model/policy.py b/casbin/model/policy.py index 5daeea3..0927846 100644 --- a/casbin/model/policy.py +++ b/casbin/model/policy.py @@ -75,7 +75,8 @@ def print_policy(self): continue for key, ast in self[sec].items(): - self.logger.info("{} : {} : {}".format(key, ast.value, ast.policy)) + self.logger.info("{} : {} : {}".format( + key, ast.value, ast.policy)) def clear_policy(self): """clears all current policy.""" @@ -98,7 +99,8 @@ def get_filtered_policy(self, sec, ptype, field_index, *field_values): rule for rule in self[sec][ptype].policy if all( - (callable(value) and value(rule[field_index + i])) or (value == "" or rule[field_index + i] == value) + (callable(value) and value( + rule[field_index + i])) or (value == "" or rule[field_index + i] == value) for i, value in enumerate(field_values) ) ] @@ -127,7 +129,8 @@ def add_policy(self, sec, ptype, rule): i = len(assertion.policy) - 1 for i in range(i, 0, -1): try: - idx = int(assertion.policy[i - 1][assertion.priority_index]) + idx = int( + assertion.policy[i - 1][assertion.priority_index]) except Exception as e: print(e) @@ -143,20 +146,28 @@ def add_policy(self, sec, ptype, rule): except Exception as e: print(e) - assertion.policy_map[DEFAULT_SEP.join(rule)] = len(assertion.policy) - 1 + assertion.policy_map[DEFAULT_SEP.join( + rule)] = len(assertion.policy) - 1 return True def add_policies(self, sec, ptype, rules): """adds policy rules to the model.""" + if ptype not in self[sec]: + return False - for rule in rules: - if self.has_policy(sec, ptype, rule): - return False + policy_map = {tuple(policy): idx for idx, + policy in enumerate(self[sec][ptype].policy)} + added = False for rule in rules: + rule_tuple = tuple(rule) + if rule_tuple in policy_map: + continue + self[sec][ptype].policy.append(rule) + added = True - return True + return added def update_policy(self, sec, ptype, old_rule, new_rule): """update a policy rule from the model.""" @@ -178,43 +189,32 @@ def update_policy(self, sec, ptype, old_rule, new_rule): if old_rule[priority_index] == new_rule[priority_index]: ast.policy[rule_index] = new_rule else: - raise Exception("New rule should have the same priority with old rule.") + raise Exception( + "New rule should have the same priority with old rule.") else: ast.policy[rule_index] = new_rule return True def update_policies(self, sec, ptype, old_rules, new_rules): - """update policy rules from the model.""" - - if sec not in self.keys(): - return False + """updates policy rules in the model.""" if ptype not in self[sec]: return False - if len(old_rules) != len(new_rules): - return False - ast = self[sec][ptype] - old_rules_index = [] + policy_map = {tuple(policy): idx for idx, + policy in enumerate(self[sec][ptype].policy)} + updated = False - for old_rule in old_rules: - if old_rule in ast.policy: - old_rules_index.append(ast.policy.index(old_rule)) - else: - return False + for old_rule, new_rule in zip(old_rules, new_rules): + old_rule_tuple = tuple(old_rule) + if old_rule_tuple not in policy_map: + continue - if "p_priority" in ast.tokens: - priority_index = ast.tokens.index("p_priority") - for idx, old_rule, new_rule in zip(old_rules_index, old_rules, new_rules): - if old_rule[priority_index] == new_rule[priority_index]: - ast.policy[idx] = new_rule - else: - raise Exception("New rule should have the same priority with old rule.") - else: - for idx, old_rule, new_rule in zip(old_rules_index, old_rules, new_rules): - ast.policy[idx] = new_rule + idx = policy_map[old_rule_tuple] + self[sec][ptype].policy[idx] = new_rule + updated = True - return True + return updated def remove_policy(self, sec, ptype, rule): """removes a policy rule from the model.""" @@ -226,16 +226,23 @@ def remove_policy(self, sec, ptype, rule): return rule not in self[sec][ptype].policy def remove_policies(self, sec, ptype, rules): - """RemovePolicies removes policy rules from the model.""" + """removes policy rules from the model.""" + if ptype not in self[sec]: + return False + + policy_map = {tuple(policy): idx for idx, + policy in enumerate(self[sec][ptype].policy)} + removed = False for rule in rules: - if not self.has_policy(sec, ptype, rule): - return False - self[sec][ptype].policy.remove(rule) - if rule in self[sec][ptype].policy: - return False + rule_tuple = tuple(rule) + if rule_tuple not in policy_map: + continue - return True + self[sec][ptype].policy.remove(list(rule)) + removed = True + + return removed def remove_policies_with_effected(self, sec, ptype, rules): effected = []