From 8b1f710f0ca33068f778b348f1acfe366bcdbbd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89rico=20Andrei?= Date: Sun, 24 Dec 2023 13:10:14 -0300 Subject: [PATCH] COde style fixes --- src/pas/plugins/oidc/plugins/group.py | 33 ++++++++++---------- src/pas/plugins/oidc/subscribers/settings.py | 17 +++++----- src/pas/plugins/oidc/utils/setup.py | 24 ++++++-------- 3 files changed, 34 insertions(+), 40 deletions(-) diff --git a/src/pas/plugins/oidc/plugins/group.py b/src/pas/plugins/oidc/plugins/group.py index e72ff54..a5d2d2a 100644 --- a/src/pas/plugins/oidc/plugins/group.py +++ b/src/pas/plugins/oidc/plugins/group.py @@ -8,9 +8,7 @@ from plone.memoize import ram from Products.PlonePAS.interfaces.group import IGroupIntrospection from Products.PlonePAS.plugins.group import PloneGroup -from Products.PluggableAuthService.interfaces.plugins import IGroupEnumerationPlugin -from Products.PluggableAuthService.interfaces.plugins import IGroupsPlugin -from Products.PluggableAuthService.interfaces.plugins import IRolesPlugin +from Products.PluggableAuthService.interfaces import plugins from Products.PluggableAuthService.permissions import ManageGroups from Products.PluggableAuthService.plugins.BasePlugin import BasePlugin from Products.PluggableAuthService.utils import classImplements @@ -179,7 +177,7 @@ def enumerateGroups( scaling issues for some implementations. """ default = () - if not self.is_plugin_active(IGroupEnumerationPlugin): + if not self.is_plugin_active(plugins.IGroupEnumerationPlugin): return default groups = self._groups if not groups: @@ -212,7 +210,7 @@ def enumerateGroups( @security.private def getGroupsForPrincipal(self, principal, request=None) -> Tuple[str]: """See IGroupsPlugin.""" - if not self.is_plugin_active(IGroupsPlugin): + if not self.is_plugin_active(plugins.IGroupsPlugin): return tuple() client = self.get_rest_api_client() try: @@ -233,7 +231,7 @@ def getGroupsForPrincipal(self, principal, request=None) -> Tuple[str]: @security.protected(ManageGroups) def listGroupIds(self) -> Tuple[str]: """-> (group_id_1, ... group_id_n)""" - if not self.is_plugin_active(IGroupsPlugin): + if not self.is_plugin_active(plugins.IGroupsPlugin): return tuple() return tuple(group_id for group_id in self._groups.keys()) @@ -245,7 +243,7 @@ def listGroupInfo(self) -> Tuple[dict]: - 'id' """ - if not self.is_plugin_active(IGroupsPlugin): + if not self.is_plugin_active(plugins.IGroupsPlugin): return tuple() return tuple(group_info for group_info in self._groups.values()) @@ -267,34 +265,34 @@ def _get_group_info(self, group_id: str) -> dict: @security.protected(ManageGroups) def getGroupInfo(self, group_id: str) -> Optional[dict]: """group_id -> dict""" - if not self.is_plugin_active(IGroupsPlugin): + if not self.is_plugin_active(plugins.IGroupsPlugin): return None group_info = self._get_group_info(group_id) return group_info def getGroupById(self, group_id: str) -> Optional[OIDCGroup]: """Return the portal_groupdata object for a group corresponding to this id.""" - if not self.is_plugin_active(IGroupsPlugin): + if not self.is_plugin_active(plugins.IGroupsPlugin): return None group_info = self.getGroupInfo(group_id) return self._wrap_group(group_info) if group_info else None def getGroups(self) -> List[OIDCGroup]: """Return an iterator of the available groups.""" - if not self.is_plugin_active(IGroupsPlugin): + if not self.is_plugin_active(plugins.IGroupsPlugin): return [] return [self.getGroupById(group_id) for group_id in self.getGroupIds()] def getGroupIds(self) -> List[str]: """Return a list of the available groups.""" - if not self.is_plugin_active(IGroupsPlugin): + if not self.is_plugin_active(plugins.IGroupsPlugin): return [] return [group_id for group_id in self._groups.keys()] def getGroupMembers(self, group_id: str) -> Tuple[str]: """Return the members of a group with the given group_id.""" default = tuple() - if self.is_plugin_active(IGroupsPlugin) and group_id in self._groups: + if self.is_plugin_active(plugins.IGroupsPlugin) and group_id in self._groups: client = self.get_rest_api_client() try: members = client.get_group_members(group_id=group_id) @@ -317,7 +315,10 @@ def getRolesForPrincipal(self, principal, request=None) -> Tuple[str]: """ principal_id = principal.getId() default = tuple() - if self.is_plugin_active(IGroupsPlugin) and principal_id in self._groups: + if ( + self.is_plugin_active(plugins.IGroupsPlugin) + and principal_id in self._groups + ): group_info = self._get_group_info(principal_id) if group_info: return tuple(group_info.get("_roles", default)) @@ -330,8 +331,8 @@ def getRolesForPrincipal(self, principal, request=None) -> Tuple[str]: classImplements( KeycloakGroupsPlugin, IKeycloakGroupsPlugin, - IGroupsPlugin, IGroupIntrospection, - IGroupEnumerationPlugin, - IRolesPlugin, + plugins.IGroupsPlugin, + plugins.IGroupEnumerationPlugin, + plugins.IRolesPlugin, ) diff --git a/src/pas/plugins/oidc/subscribers/settings.py b/src/pas/plugins/oidc/subscribers/settings.py index 49f7321..8109011 100644 --- a/src/pas/plugins/oidc/subscribers/settings.py +++ b/src/pas/plugins/oidc/subscribers/settings.py @@ -8,18 +8,17 @@ def keycloak_settings_modified(_: IKeycloakSettings, event: RecordModifiedEvent) """A setting in the keycloak group was modified.""" field_name = event.record.fieldName if field_name == "enabled": + # Enable or Disable the plugin pas = api.portal.get_tool("acl_users") plugin_id = setup.PLUGIN_GROUP[1] plugin = getattr(pas, plugin_id, None) value = event.record.value if plugin: interfaces = setup.interfaces_for_plugin(pas, plugin_id) - if value: - # Activate the plugin - move_to_top = setup.PLUGIN_GROUP[-1] - for interface_name in interfaces: - _move_to_top = interface_name in move_to_top - setup.activate_plugin(pas, plugin_id, interface_name, _move_to_top) - else: - for interface_name in interfaces: - setup.deactivate_plugin(pas, plugin_id, interface_name) + move_to_top = setup.PLUGIN_GROUP[-1] + func = setup.activate_plugin if value else setup.deactivate_plugin + for interface_name in interfaces: + args = [pas, plugin_id, interface_name] + if value: + args.append(interface_name in move_to_top) + func(*args) diff --git a/src/pas/plugins/oidc/utils/setup.py b/src/pas/plugins/oidc/utils/setup.py index a9a1bf1..06589b1 100644 --- a/src/pas/plugins/oidc/utils/setup.py +++ b/src/pas/plugins/oidc/utils/setup.py @@ -39,8 +39,8 @@ def get_plugin(plugin_id: str, klass: Type) -> Union[KeycloakGroupsPlugin, OIDCPlugin]: """Check if plugin has the correct class.""" pas = api.portal.get_tool("acl_users") - plugin = getattr(pas, plugin_id) - if not isinstance(plugin, klass): + plugin = getattr(pas, plugin_id, None) + if not (plugin and isinstance(plugin, klass)): logger.warning(f"PAS plugin {plugin_id} is not a {klass.__name__}.") return None return plugin @@ -104,15 +104,12 @@ def add_pas_plugin( ) -> Union[KeycloakGroupsPlugin, OIDCPlugin]: """Add a new plugin to acl_users.""" pas = api.portal.get_tool("acl_users") - # Create plugin if it does not exist. - if plugin_id not in pas.objectIds(): + plugin = get_plugin(plugin_id, klass) + if not plugin: plugin = klass(title=title) plugin.id = plugin_id pas._setObject(plugin_id, plugin) logger.info(f"Added {plugin_id} to acl_users.") - plugin = get_plugin(plugin_id, klass) - if not plugin: - raise ValueError(f"{plugin_id} is not an instance of {klass.__name__}") if should_activate: activate = interfaces_for_plugin(pas, plugin_id) for interface_name in activate: @@ -125,12 +122,9 @@ def remove_pas_plugin(klass: Type, plugin_id: str) -> bool: """Remove pas plugin from acl_users.""" pas = api.portal.get_tool("acl_users") # Remove plugin if it exists. - if plugin_id not in pas.objectIds(): - return False - plugin = get_plugin(plugin_id, klass) - if not plugin: - return False - pas._delObject(plugin_id) - logger.info(f"Removed {klass.__name__} {plugin_id} from acl_users.") - return True + if plugin: + pas._delObject(plugin_id) + logger.info(f"Removed {klass.__name__} {plugin_id} from acl_users.") + return True + return False