Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enhance FilteredFileAdapter to handle flexible filtering for policies and roles #360

Merged
merged 5 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions casbin/persist/adapters/filtered_file_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,28 @@ def load_filtered_policy(self, model, filter):

try:
filter_value = [filter.__dict__["P"]] + [filter.__dict__["G"]]
is_empty_filter = all(not f for f in filter_value) or all(
all(not x.strip() for x in f) if f else True for f in filter_value
)
if is_empty_filter:
return self.load_policy(model)
except:
raise RuntimeError("invalid filter type")

self.load_filtered_policy_file(model, filter_value, persist.load_policy_line)
self.filtered = True

def load_filtered_policy_file(self, model, filter, hanlder):
def load_filtered_policy_file(self, model, filter, handler):
with open(self._file_path, "rb") as file:
while True:
line = file.readline()
for line in file:
line = line.decode().strip()
if line == "\n":
if not line or line == "\n":
continue
if not line:
break

if filter_line(line, filter):
continue

hanlder(line, model)
handler(line, model)

# is_filtered returns true if the loaded policy has been filtered.
def is_filtered(self):
Expand All @@ -92,10 +95,13 @@ def filter_line(line, filter):
return True
filter_slice = []

if p[0].strip() == "p":
filter_slice = filter[0]
elif p[0].strip() == "g":
if p[0].strip() == "g":
if not filter[1] or all(not x.strip() for x in filter[1]):
return False
filter_slice = filter[1]
elif p[0].strip() == "p":
filter_slice = filter[0]

return filter_words(p, filter_slice)


Expand All @@ -104,7 +110,7 @@ def filter_words(line, filter):
return True
skip_line = False
for i, v in enumerate(filter):
if len(v) > 0 and (v.strip() != line[i + 1].strip()):
if v and v.strip() and (v.strip() != line[i + 1].strip()):
skip_line = True
break

Expand Down
176 changes: 175 additions & 1 deletion tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import casbin
import os
from unittest import TestCase
import casbin
from tests.test_enforcer import get_examples
from casbin.persist.adapters import FilteredFileAdapter
from casbin.persist.adapters.filtered_file_adapter import filter_line, filter_words
HashCookie marked this conversation as resolved.
Show resolved Hide resolved


class Filter:
Expand Down Expand Up @@ -141,3 +143,175 @@ def test_filtered_adapter_invalid_filepath(self):

with self.assertRaises(RuntimeError):
e.load_filtered_policy(None)

def test_empty_filter_array(self):
"""Test filter for empty array."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = []
filter.G = []

e.load_filtered_policy(filter)
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))

def test_empty_string_filter(self):
"""Test the filter for all empty strings."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = ["", "", ""]
filter.G = ["", "", ""]

e.load_filtered_policy(filter)
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))

def test_mixed_empty_filter(self):
"""Test the filter for mixed empty and non-empty strings."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = ["", "domain1", ""]
filter.G = ["", "", "domain1"]

e.load_filtered_policy(filter)
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
self.assertFalse(e.has_policy(["admin", "domain2", "data2", "read"]))

def test_nonexistent_domain_filter(self):
"""Testing the filter for a non-existent domain."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = ["", "domain3"]
filter.G = ["", "", "domain3"]

e.load_filtered_policy(filter)
self.assertFalse(e.has_policy(["admin", "domain3", "data3", "read"]))

def test_empty_filter_array(self):
"""Test filter for empty array."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = []
filter.G = []

try:
e.load_filtered_policy(filter)
except:
raise RuntimeError("unexpected error with empty filter arrays")

self.assertFalse(e.is_filtered(), "Adapter should not be marked as filtered with empty filters")

self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))

def test_empty_string_filter(self):
"""Test the filter for all empty strings."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = ["", "", ""]
filter.G = ["", "", ""]

try:
e.load_filtered_policy(filter)
except:
raise RuntimeError("unexpected error with empty string filters")

self.assertFalse(e.is_filtered(), "Adapter should not be marked as filtered with empty string filters")

try:
e.save_policy()
except:
raise RuntimeError("unexpected error in SavePolicy with empty string filters")

self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))

def test_mixed_empty_filter(self):
"""Test the filter for mixed empty and non-empty strings."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = ["", "domain1", ""]
filter.G = ["", "", "domain1"]

try:
e.load_filtered_policy(filter)
except:
raise RuntimeError("unexpected error with mixed empty filters")

self.assertTrue(e.is_filtered(), "Adapter should be marked as filtered")

with self.assertRaises(RuntimeError):
e.save_policy()

self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
self.assertFalse(e.has_policy(["admin", "domain2", "data2", "read"]))

def test_whitespace_filter(self):
"""Test the filter for all blank characters."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = [" ", " ", "\t"]
filter.G = ["\n", " ", " "]

e.load_filtered_policy(filter)

self.assertFalse(e.is_filtered())
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))

def test_filter_line_edge_cases(self):
"""Test the boundary cases of the filter_line function."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))

self.assertFalse(filter_line("", [[""], [""]]))

self.assertFalse(filter_line("invalid_line", [[""], [""]]))

self.assertFalse(filter_line("p, admin, domain1, data1, read", None))

def test_filter_words_edge_cases(self):
"""Test the boundary cases of the filter_words function."""
self.assertTrue(filter_words(["p"], ["filter1", "filter2"]))

self.assertFalse(filter_words(["p", "admin", "domain1"], []))

line = ["admin", "domain1", "data*", "read"]
filter = ["", "", "data1", ""]
self.assertTrue(filter_words(line, filter))

def test_load_filtered_policy_with_comments(self):
"""Test loading filtering policies with comments."""
import tempfile
import shutil

with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
with open(get_examples("rbac_with_domains_policy.csv"), "r") as source:
shutil.copyfileobj(source, temp_file)

temp_file.write("\n# This is a comment\np, admin, domain1, data3, read")
temp_file.flush()

temp_path = temp_file.name

try:
adapter = FilteredFileAdapter(temp_path)
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = ["", "domain1"]
filter.G = ["", "", "domain1"]

e.load_filtered_policy(filter)
self.assertTrue(e.has_policy(["admin", "domain1", "data3", "read"]))
finally:
try:
os.unlink(temp_path)
except OSError:
pass