diff --git a/casbin/persist/adapters/filtered_file_adapter.py b/casbin/persist/adapters/filtered_file_adapter.py index eeb3a6c..b0e31cb 100644 --- a/casbin/persist/adapters/filtered_file_adapter.py +++ b/casbin/persist/adapters/filtered_file_adapter.py @@ -52,25 +52,28 @@ def load_filtered_policy(self, model, filter): try: filter_value = [filter.__dict__["P"]] + [filter.__dict__["G"]] + is_empty_filter = all(not f for f in filter_value) or all( + all(not x.strip() for x in f) if f else True for f in filter_value + ) + if is_empty_filter: + return self.load_policy(model) except: raise RuntimeError("invalid filter type") self.load_filtered_policy_file(model, filter_value, persist.load_policy_line) self.filtered = True - def load_filtered_policy_file(self, model, filter, hanlder): + def load_filtered_policy_file(self, model, filter, handler): with open(self._file_path, "rb") as file: - while True: - line = file.readline() + for line in file: line = line.decode().strip() - if line == "\n": + if not line or line == "\n": continue - if not line: - break + if filter_line(line, filter): continue - hanlder(line, model) + handler(line, model) # is_filtered returns true if the loaded policy has been filtered. def is_filtered(self): @@ -92,10 +95,13 @@ def filter_line(line, filter): return True filter_slice = [] - if p[0].strip() == "p": - filter_slice = filter[0] - elif p[0].strip() == "g": + if p[0].strip() == "g": + if not filter[1] or all(not x.strip() for x in filter[1]): + return False filter_slice = filter[1] + elif p[0].strip() == "p": + filter_slice = filter[0] + return filter_words(p, filter_slice) @@ -104,7 +110,7 @@ def filter_words(line, filter): return True skip_line = False for i, v in enumerate(filter): - if len(v) > 0 and (v.strip() != line[i + 1].strip()): + if v and v.strip() and (v.strip() != line[i + 1].strip()): skip_line = True break diff --git a/tests/test_filter.py b/tests/test_filter.py index 536505f..0160325 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import casbin +import os from unittest import TestCase +import casbin from tests.test_enforcer import get_examples from casbin.persist.adapters import FilteredFileAdapter +from casbin.persist.adapters.filtered_file_adapter import filter_line, filter_words class Filter: @@ -141,3 +143,175 @@ def test_filtered_adapter_invalid_filepath(self): with self.assertRaises(RuntimeError): e.load_filtered_policy(None) + + def test_empty_filter_array(self): + """Test filter for empty array.""" + adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv")) + e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter) + filter = Filter() + filter.P = [] + filter.G = [] + + e.load_filtered_policy(filter) + self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"])) + self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"])) + + def test_empty_string_filter(self): + """Test the filter for all empty strings.""" + adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv")) + e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter) + filter = Filter() + filter.P = ["", "", ""] + filter.G = ["", "", ""] + + e.load_filtered_policy(filter) + self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"])) + self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"])) + + def test_mixed_empty_filter(self): + """Test the filter for mixed empty and non-empty strings.""" + adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv")) + e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter) + filter = Filter() + filter.P = ["", "domain1", ""] + filter.G = ["", "", "domain1"] + + e.load_filtered_policy(filter) + self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"])) + self.assertFalse(e.has_policy(["admin", "domain2", "data2", "read"])) + + def test_nonexistent_domain_filter(self): + """Testing the filter for a non-existent domain.""" + adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv")) + e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter) + filter = Filter() + filter.P = ["", "domain3"] + filter.G = ["", "", "domain3"] + + e.load_filtered_policy(filter) + self.assertFalse(e.has_policy(["admin", "domain3", "data3", "read"])) + + def test_empty_filter_array(self): + """Test filter for empty array.""" + adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv")) + e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter) + filter = Filter() + filter.P = [] + filter.G = [] + + try: + e.load_filtered_policy(filter) + except: + raise RuntimeError("unexpected error with empty filter arrays") + + self.assertFalse(e.is_filtered(), "Adapter should not be marked as filtered with empty filters") + + self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"])) + self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"])) + + def test_empty_string_filter(self): + """Test the filter for all empty strings.""" + adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv")) + e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter) + filter = Filter() + filter.P = ["", "", ""] + filter.G = ["", "", ""] + + try: + e.load_filtered_policy(filter) + except: + raise RuntimeError("unexpected error with empty string filters") + + self.assertFalse(e.is_filtered(), "Adapter should not be marked as filtered with empty string filters") + + try: + e.save_policy() + except: + raise RuntimeError("unexpected error in SavePolicy with empty string filters") + + self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"])) + self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"])) + + def test_mixed_empty_filter(self): + """Test the filter for mixed empty and non-empty strings.""" + adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv")) + e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter) + filter = Filter() + filter.P = ["", "domain1", ""] + filter.G = ["", "", "domain1"] + + try: + e.load_filtered_policy(filter) + except: + raise RuntimeError("unexpected error with mixed empty filters") + + self.assertTrue(e.is_filtered(), "Adapter should be marked as filtered") + + with self.assertRaises(RuntimeError): + e.save_policy() + + self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"])) + self.assertFalse(e.has_policy(["admin", "domain2", "data2", "read"])) + + def test_whitespace_filter(self): + """Test the filter for all blank characters.""" + adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv")) + e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter) + filter = Filter() + filter.P = [" ", " ", "\t"] + filter.G = ["\n", " ", " "] + + e.load_filtered_policy(filter) + + self.assertFalse(e.is_filtered()) + self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"])) + self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"])) + + def test_filter_line_edge_cases(self): + """Test the boundary cases of the filter_line function.""" + adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv")) + + self.assertFalse(filter_line("", [[""], [""]])) + + self.assertFalse(filter_line("invalid_line", [[""], [""]])) + + self.assertFalse(filter_line("p, admin, domain1, data1, read", None)) + + def test_filter_words_edge_cases(self): + """Test the boundary cases of the filter_words function.""" + self.assertTrue(filter_words(["p"], ["filter1", "filter2"])) + + self.assertFalse(filter_words(["p", "admin", "domain1"], [])) + + line = ["admin", "domain1", "data*", "read"] + filter = ["", "", "data1", ""] + self.assertTrue(filter_words(line, filter)) + + def test_load_filtered_policy_with_comments(self): + """Test loading filtering policies with comments.""" + import tempfile + import shutil + + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: + with open(get_examples("rbac_with_domains_policy.csv"), "r") as source: + shutil.copyfileobj(source, temp_file) + + temp_file.write("\n# This is a comment\np, admin, domain1, data3, read") + temp_file.flush() + + temp_path = temp_file.name + + try: + adapter = FilteredFileAdapter(temp_path) + e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter) + filter = Filter() + filter.P = ["", "domain1"] + filter.G = ["", "", "domain1"] + + e.load_filtered_policy(filter) + self.assertTrue(e.has_policy(["admin", "domain1", "data3", "read"])) + finally: + try: + os.unlink(temp_path) + except OSError: + pass