Skip to content

Commit

Permalink
COde style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ericof committed Dec 24, 2023
1 parent 1be4b37 commit 8b1f710
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 40 deletions.
33 changes: 17 additions & 16 deletions src/pas/plugins/oidc/plugins/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
from plone.memoize import ram
from Products.PlonePAS.interfaces.group import IGroupIntrospection
from Products.PlonePAS.plugins.group import PloneGroup
from Products.PluggableAuthService.interfaces.plugins import IGroupEnumerationPlugin
from Products.PluggableAuthService.interfaces.plugins import IGroupsPlugin
from Products.PluggableAuthService.interfaces.plugins import IRolesPlugin
from Products.PluggableAuthService.interfaces import plugins
from Products.PluggableAuthService.permissions import ManageGroups
from Products.PluggableAuthService.plugins.BasePlugin import BasePlugin
from Products.PluggableAuthService.utils import classImplements
Expand Down Expand Up @@ -179,7 +177,7 @@ def enumerateGroups(
scaling issues for some implementations.
"""
default = ()
if not self.is_plugin_active(IGroupEnumerationPlugin):
if not self.is_plugin_active(plugins.IGroupEnumerationPlugin):
return default
groups = self._groups
if not groups:
Expand Down Expand Up @@ -212,7 +210,7 @@ def enumerateGroups(
@security.private
def getGroupsForPrincipal(self, principal, request=None) -> Tuple[str]:
"""See IGroupsPlugin."""
if not self.is_plugin_active(IGroupsPlugin):
if not self.is_plugin_active(plugins.IGroupsPlugin):
return tuple()
client = self.get_rest_api_client()
try:
Expand All @@ -233,7 +231,7 @@ def getGroupsForPrincipal(self, principal, request=None) -> Tuple[str]:
@security.protected(ManageGroups)
def listGroupIds(self) -> Tuple[str]:
"""-> (group_id_1, ... group_id_n)"""
if not self.is_plugin_active(IGroupsPlugin):
if not self.is_plugin_active(plugins.IGroupsPlugin):
return tuple()
return tuple(group_id for group_id in self._groups.keys())

Expand All @@ -245,7 +243,7 @@ def listGroupInfo(self) -> Tuple[dict]:
- 'id'
"""
if not self.is_plugin_active(IGroupsPlugin):
if not self.is_plugin_active(plugins.IGroupsPlugin):
return tuple()
return tuple(group_info for group_info in self._groups.values())

Expand All @@ -267,34 +265,34 @@ def _get_group_info(self, group_id: str) -> dict:
@security.protected(ManageGroups)
def getGroupInfo(self, group_id: str) -> Optional[dict]:
"""group_id -> dict"""
if not self.is_plugin_active(IGroupsPlugin):
if not self.is_plugin_active(plugins.IGroupsPlugin):
return None
group_info = self._get_group_info(group_id)
return group_info

def getGroupById(self, group_id: str) -> Optional[OIDCGroup]:
"""Return the portal_groupdata object for a group corresponding to this id."""
if not self.is_plugin_active(IGroupsPlugin):
if not self.is_plugin_active(plugins.IGroupsPlugin):
return None
group_info = self.getGroupInfo(group_id)
return self._wrap_group(group_info) if group_info else None

def getGroups(self) -> List[OIDCGroup]:
"""Return an iterator of the available groups."""
if not self.is_plugin_active(IGroupsPlugin):
if not self.is_plugin_active(plugins.IGroupsPlugin):
return []
return [self.getGroupById(group_id) for group_id in self.getGroupIds()]

def getGroupIds(self) -> List[str]:
"""Return a list of the available groups."""
if not self.is_plugin_active(IGroupsPlugin):
if not self.is_plugin_active(plugins.IGroupsPlugin):
return []
return [group_id for group_id in self._groups.keys()]

def getGroupMembers(self, group_id: str) -> Tuple[str]:
"""Return the members of a group with the given group_id."""
default = tuple()
if self.is_plugin_active(IGroupsPlugin) and group_id in self._groups:
if self.is_plugin_active(plugins.IGroupsPlugin) and group_id in self._groups:
client = self.get_rest_api_client()
try:
members = client.get_group_members(group_id=group_id)
Expand All @@ -317,7 +315,10 @@ def getRolesForPrincipal(self, principal, request=None) -> Tuple[str]:
"""
principal_id = principal.getId()
default = tuple()
if self.is_plugin_active(IGroupsPlugin) and principal_id in self._groups:
if (
self.is_plugin_active(plugins.IGroupsPlugin)
and principal_id in self._groups
):
group_info = self._get_group_info(principal_id)
if group_info:
return tuple(group_info.get("_roles", default))
Expand All @@ -330,8 +331,8 @@ def getRolesForPrincipal(self, principal, request=None) -> Tuple[str]:
classImplements(
KeycloakGroupsPlugin,
IKeycloakGroupsPlugin,
IGroupsPlugin,
IGroupIntrospection,
IGroupEnumerationPlugin,
IRolesPlugin,
plugins.IGroupsPlugin,
plugins.IGroupEnumerationPlugin,
plugins.IRolesPlugin,
)
17 changes: 8 additions & 9 deletions src/pas/plugins/oidc/subscribers/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,17 @@ def keycloak_settings_modified(_: IKeycloakSettings, event: RecordModifiedEvent)
"""A setting in the keycloak group was modified."""
field_name = event.record.fieldName
if field_name == "enabled":
# Enable or Disable the plugin
pas = api.portal.get_tool("acl_users")
plugin_id = setup.PLUGIN_GROUP[1]
plugin = getattr(pas, plugin_id, None)
value = event.record.value
if plugin:
interfaces = setup.interfaces_for_plugin(pas, plugin_id)
if value:
# Activate the plugin
move_to_top = setup.PLUGIN_GROUP[-1]
for interface_name in interfaces:
_move_to_top = interface_name in move_to_top
setup.activate_plugin(pas, plugin_id, interface_name, _move_to_top)
else:
for interface_name in interfaces:
setup.deactivate_plugin(pas, plugin_id, interface_name)
move_to_top = setup.PLUGIN_GROUP[-1]
func = setup.activate_plugin if value else setup.deactivate_plugin
for interface_name in interfaces:
args = [pas, plugin_id, interface_name]
if value:
args.append(interface_name in move_to_top)
func(*args)
24 changes: 9 additions & 15 deletions src/pas/plugins/oidc/utils/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
def get_plugin(plugin_id: str, klass: Type) -> Union[KeycloakGroupsPlugin, OIDCPlugin]:
"""Check if plugin has the correct class."""
pas = api.portal.get_tool("acl_users")
plugin = getattr(pas, plugin_id)
if not isinstance(plugin, klass):
plugin = getattr(pas, plugin_id, None)
if not (plugin and isinstance(plugin, klass)):
logger.warning(f"PAS plugin {plugin_id} is not a {klass.__name__}.")
return None
return plugin
Expand Down Expand Up @@ -104,15 +104,12 @@ def add_pas_plugin(
) -> Union[KeycloakGroupsPlugin, OIDCPlugin]:
"""Add a new plugin to acl_users."""
pas = api.portal.get_tool("acl_users")
# Create plugin if it does not exist.
if plugin_id not in pas.objectIds():
plugin = get_plugin(plugin_id, klass)
if not plugin:
plugin = klass(title=title)
plugin.id = plugin_id
pas._setObject(plugin_id, plugin)
logger.info(f"Added {plugin_id} to acl_users.")
plugin = get_plugin(plugin_id, klass)
if not plugin:
raise ValueError(f"{plugin_id} is not an instance of {klass.__name__}")
if should_activate:
activate = interfaces_for_plugin(pas, plugin_id)
for interface_name in activate:
Expand All @@ -125,12 +122,9 @@ def remove_pas_plugin(klass: Type, plugin_id: str) -> bool:
"""Remove pas plugin from acl_users."""
pas = api.portal.get_tool("acl_users")
# Remove plugin if it exists.
if plugin_id not in pas.objectIds():
return False

plugin = get_plugin(plugin_id, klass)
if not plugin:
return False
pas._delObject(plugin_id)
logger.info(f"Removed {klass.__name__} {plugin_id} from acl_users.")
return True
if plugin:
pas._delObject(plugin_id)
logger.info(f"Removed {klass.__name__} {plugin_id} from acl_users.")
return True
return False

0 comments on commit 8b1f710

Please sign in to comment.