Skip to content

Commit

Permalink
Merge pull request #6094 from ShanChathusanda93/role-permission-impr-…
Browse files Browse the repository at this point in the history
…branch

Improve role permission extraction when roles are shared and not shared
  • Loading branch information
ShanChathusanda93 authored Nov 28, 2024
2 parents 8c6b1d7 + 84bfd43 commit 6289d27
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,7 @@
import static org.wso2.carbon.identity.role.v2.mgt.core.dao.SQLQueries.GET_ROLE_SCOPE_SQL;
import static org.wso2.carbon.identity.role.v2.mgt.core.dao.SQLQueries.GET_ROLE_TENANT_DOMAIN_BY_ID;
import static org.wso2.carbon.identity.role.v2.mgt.core.dao.SQLQueries.GET_ROLE_UM_ID_BY_UUID;
import static org.wso2.carbon.identity.role.v2.mgt.core.dao.SQLQueries.GET_SCOPE_BY_ROLES_SQL;
import static org.wso2.carbon.identity.role.v2.mgt.core.dao.SQLQueries.GET_SHARED_HYBRID_ROLE_WITH_MAIN_ROLE_SQL;
import static org.wso2.carbon.identity.role.v2.mgt.core.dao.SQLQueries.GET_SHARED_ROLES_MAIN_ROLE_IDS_SQL;
import static org.wso2.carbon.identity.role.v2.mgt.core.dao.SQLQueries.GET_SHARED_ROLES_SQL;
import static org.wso2.carbon.identity.role.v2.mgt.core.dao.SQLQueries.GET_SHARED_ROLE_MAIN_ROLE_ID_SQL;
import static org.wso2.carbon.identity.role.v2.mgt.core.dao.SQLQueries.INSERT_MAIN_TO_SHARED_ROLE_RELATIONSHIP;
Expand Down Expand Up @@ -531,35 +529,19 @@ public List<Permission> getPermissionListOfRole(String roleId, String tenantDoma
public List<String> getPermissionListOfRoles(List<String> roleIds, String tenantDomain)
throws IdentityRoleManagementException {

if (isOrganization(tenantDomain)) {
return getPermissionsOfSharedRoles(roleIds, tenantDomain);
} else {
return getPermissionListOfRolesByIds(roleIds, tenantDomain);
}
}

private List<String> getPermissionListOfRolesByIds(List<String> roleIds, String tenantDomain)
throws IdentityRoleManagementException {

List<String> permissions = new ArrayList<>();
String query = GET_SCOPE_BY_ROLES_SQL + String.join(", ",
Collections.nCopies(roleIds.size(), "?")) + ")";
try (Connection connection = IdentityDatabaseUtil.getDBConnection(false);
NamedPreparedStatement statement = new NamedPreparedStatement(connection, query)) {

for (int i = 0; i < roleIds.size(); i++) {
statement.setString(i + 1, roleIds.get(i));
}
try (ResultSet resultSet = statement.executeQuery()) {
while (resultSet.next()) {
permissions.add(resultSet.getString(1));
}
List<Permission> permissionList = new ArrayList<>();
for (String roleId : roleIds) {
if (isOrganization(tenantDomain) && isSharedRole(roleId, tenantDomain)) {
permissionList.addAll(getPermissionsOfSharedRole(roleId, tenantDomain));
} else {
permissionList.addAll(getPermissions(roleId, tenantDomain));
}
} catch (SQLException e) {
String errorMessage =
"Error while retrieving permissions for role ids: " + StringUtils.join(roleIds, ", ")
+ " and tenantDomain : " + tenantDomain;
throw new IdentityRoleManagementServerException(UNEXPECTED_SERVER_ERROR.getCode(), errorMessage, e);
}

List<Permission> distinctPermissions = permissionList.stream().distinct().collect(Collectors.toList());
for (Permission permission : distinctPermissions) {
permissions.add(permission.getName());
}
return permissions;
}
Expand Down Expand Up @@ -1655,50 +1637,6 @@ private boolean isValidSubOrgPermission(String permission) {
(!permission.startsWith(INTERNAL_SCOPE_PREFIX) && !permission.startsWith(CONSOLE_SCOPE_PREFIX));
}

/**
* Get permission of shared roles.
*
* @param roleIds Role IDs.
* @param tenantDomain Tenant domain.
* @throws IdentityRoleManagementException IdentityRoleManagementException.
*/
private List<String> getPermissionsOfSharedRoles(List<String> roleIds, String tenantDomain)
throws IdentityRoleManagementException {

int tenantId = IdentityTenantUtil.getTenantId(tenantDomain);
List<String> mainRoleIds = new ArrayList<>();
int mainTenantId = -1;
String query = GET_SHARED_ROLES_MAIN_ROLE_IDS_SQL + String.join(", ",
Collections.nCopies(roleIds.size(), "?")) + ")";
try (Connection connection = IdentityDatabaseUtil.getUserDBConnection(false);
NamedPreparedStatement statement = new NamedPreparedStatement(connection, query)) {

statement.setInt(RoleConstants.RoleTableColumns.UM_TENANT_ID, tenantId);
for (int i = 0; i < roleIds.size(); i++) {
statement.setString(i + 2, roleIds.get(i));
}
try (ResultSet resultSet = statement.executeQuery()) {
while (resultSet.next()) {
mainRoleIds.add(resultSet.getString(RoleConstants.RoleTableColumns.UM_UUID));
if (mainTenantId == -1) {
mainTenantId = resultSet.getInt(RoleConstants.RoleTableColumns.UM_TENANT_ID);
}
}
}
if (!mainRoleIds.isEmpty() && mainTenantId != -1) {
String mainTenantDomain = IdentityTenantUtil.getTenantDomain(mainTenantId);
if (StringUtils.isNotEmpty(mainTenantDomain)) {
return getPermissionListOfRolesByIds(mainRoleIds, mainTenantDomain);
}
}
} catch (SQLException | IdentityRoleManagementException e) {
String errorMessage = "Error while retrieving permissions for role ids : "
+ StringUtils.join(roleIds, ",") + "in the tenantDomain: " + tenantDomain;
throw new IdentityRoleManagementServerException(errorMessage, e);
}
return null;
}

/**
* Delete application role association.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -102,6 +103,7 @@ public class RoleDAOTest {
private static final int SAMPLE_TENANT_ID = 1;
private static final String SAMPLE_TENANT_DOMAIN = "wso2.com";
private static final String SAMPLE_ORG_ID = "test-org-id";
private static final String SAMPLE_SUB_ORG_TENANT_DOMAIN = "wso2123.com";
private static final String SAMPLE_APP_ID = "test-app-id";
private static final String DB_NAME = "ROLE_DB";
private static final String ORGANIZATION_AUD = "organization";
Expand Down Expand Up @@ -376,6 +378,62 @@ public void testGetPermissionListOfRole() throws Exception {
Assert.assertEquals(getPermissionNameList(rolePermissions), getPermissionNameList(permissions));
}

@Test
public void testGetPermissionListOfRoles() throws Exception {

RoleDAOImpl roleDAO = spy(new RoleDAOImpl());
mockCacheClearing(roleDAO);
identityDatabaseUtil.when(() -> IdentityDatabaseUtil.getUserDBConnection(anyBoolean()))
.thenAnswer(invocation -> getConnection());
identityDatabaseUtil.when(() -> IdentityDatabaseUtil.getDBConnection(anyBoolean()))
.thenAnswer(invocation -> getConnection());
identityUtil.when(IdentityUtil::getPrimaryDomainName).thenReturn(USER_DOMAIN_PRIMARY);
identityUtil.when(() -> IdentityUtil.extractDomainFromName(anyString())).thenCallRealMethod();
identityTenantUtil.when(() -> IdentityTenantUtil.getTenantId(anyString())).thenReturn(SAMPLE_TENANT_ID);
userCoreUtil.when(() -> UserCoreUtil.isEveryoneRole(anyString(), any(RealmConfiguration.class)))
.thenReturn(false);
userCoreUtil.when(() -> UserCoreUtil.removeDomainFromName(anyString())).thenCallRealMethod();
RoleBasicInfo role = addRole(roleNamesList.get(0), APPLICATION_AUD, SAMPLE_APP_ID, roleDAO);
List<String> roleIds = Arrays.asList(role.getId());
List<String> rolePermissions = roleDAO.getPermissionListOfRoles(roleIds, SAMPLE_TENANT_DOMAIN);
Assert.assertEquals(rolePermissions, getPermissionNameList(permissions));
}

@Test
public void testGetPermissionListOfSharedRolesInSubOrganization() throws Exception {

RoleDAOImpl roleDAO = spy(new RoleDAOImpl());
mockCacheClearing(roleDAO);
identityDatabaseUtil.when(() -> IdentityDatabaseUtil.getUserDBConnection(anyBoolean())).
thenAnswer(invocation -> getConnection());
identityDatabaseUtil.when(() -> IdentityDatabaseUtil.getDBConnection(anyBoolean())).
thenAnswer(invocation -> getConnection());
identityUtil.when(IdentityUtil::getPrimaryDomainName).thenReturn(USER_DOMAIN_PRIMARY);
identityUtil.when(() -> IdentityUtil.extractDomainFromName(anyString())).thenCallRealMethod();
identityTenantUtil.when(() -> IdentityTenantUtil.getTenantId(anyString())).thenReturn(SAMPLE_TENANT_ID);
userCoreUtil.when(() -> UserCoreUtil.isEveryoneRole(anyString(), any(RealmConfiguration.class))).
thenReturn(false);
userCoreUtil.when(() -> UserCoreUtil.removeDomainFromName(anyString())).thenCallRealMethod();

// Constructing a shared role scenario
RoleBasicInfo roleBasicInfo = addRole("sharing-org-role-with-permission-001", APPLICATION_AUD,
SAMPLE_APP_ID, roleDAO);
identityTenantUtil.when(() -> IdentityTenantUtil.getTenantId(SAMPLE_SUB_ORG_TENANT_DOMAIN)).thenReturn(2);
identityTenantUtil.when(() -> IdentityTenantUtil.getTenantDomain(2)).thenReturn(SAMPLE_SUB_ORG_TENANT_DOMAIN);
OrganizationManager organizationManager = mock(OrganizationManager.class);
lenient().when(organizationManager.resolveOrganizationId(anyString())).thenReturn(SAMPLE_ORG_ID);
RoleBasicInfo sharedRoleBasicInfo = addRole("sharing-org-role-with-permission-001", APPLICATION_AUD,
"test-app-id-3", SAMPLE_SUB_ORG_TENANT_DOMAIN, new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
new ArrayList<>(), new HashMap<>(), roleDAO, true);
roleDAO.addMainRoleToSharedRoleRelationship(roleBasicInfo.getId(), sharedRoleBasicInfo.getId(),
SAMPLE_TENANT_DOMAIN, SAMPLE_SUB_ORG_TENANT_DOMAIN);
identityTenantUtil.when(() -> IdentityTenantUtil.getTenantDomain(1)).thenReturn(SAMPLE_TENANT_DOMAIN);

List<String> roleIds = Arrays.asList(sharedRoleBasicInfo.getId());
List<String> rolePermissions = roleDAO.getPermissionListOfRoles(roleIds, SAMPLE_SUB_ORG_TENANT_DOMAIN);
Assert.assertEquals(rolePermissions, getPermissionNameList(permissions));
}

@Test
public void testUpdatePermissionListOfRole() throws Exception {

Expand Down Expand Up @@ -822,6 +880,40 @@ private RoleBasicInfo addRole(String roleName, String audience, String audienceI
}
}

private RoleBasicInfo addRole(String roleName, String audience, String audienceId, String tenantDomain,
List<Permission> permissions, List<String> userIDsList, List<String> userNamesList,
List<String> groupIDsList, Map<String, String> groupNamesMap, RoleDAOImpl roleDAO,
boolean isOrganization) throws Exception {

OrganizationManager organizationManager = mock(OrganizationManager.class);
RoleManagementServiceComponentHolder.getInstance().setOrganizationManager(organizationManager);
lenient().when(organizationManager.getOrganizationNameById(anyString())).thenReturn("test-org");
lenient().when(organizationManager.resolveOrganizationId(anyString())).thenReturn(tenantDomain);
organizationManagementUtil.when(() -> OrganizationManagementUtil.isOrganization(anyString())).
thenReturn(isOrganization);
UserIDResolver userIDResolver = mock(UserIDResolver.class);
setPrivateFinalField(RoleDAOImpl.class, "userIDResolver", roleDAO, userIDResolver);
when(userIDResolver.getNamesByIDs(anyList(), anyString())).thenReturn(userNamesList);
if (!userIDsList.isEmpty()) {
lenient().when(userIDResolver.getNameByID(eq(userIDsList.get(0)), anyString()))
.thenReturn(userNamesList.get(0));
lenient().when(userIDResolver.getNameByID(eq(userIDsList.get(1)), anyString()))
.thenReturn(userNamesList.get(1));
}

GroupIDResolver groupIDResolver = mock(GroupIDResolver.class);
setPrivateFinalField(RoleDAOImpl.class, "groupIDResolver", roleDAO, groupIDResolver);
when(groupIDResolver.getNamesByIDs(anyList(), anyString())).thenReturn(groupNamesMap);

if ("everyone".equals(roleName)) {
return roleDAO.addRole(roleName, new ArrayList<>(), new ArrayList<>(), new ArrayList<>(), audience,
audienceId, tenantDomain);
} else {
return roleDAO.addRole(roleName, userIDsList, groupIDsList, permissions, audience, audienceId,
tenantDomain);
}
}

private void mockCacheClearing(RoleDAOImpl roleDAO) throws Exception {

UserRolesCache mockUserRolesCache = mock(UserRolesCache.class);
Expand Down Expand Up @@ -912,7 +1004,8 @@ private void populateData() throws Exception {
", (4,1,'update','update',1,'update')";
String spAppSQL = "INSERT INTO SP_APP (ID, TENANT_ID, APP_NAME, USER_STORE, USERNAME, AUTH_TYPE, UUID) " +
"VALUES (1, 1, 'TEST_APP_NAME','TEST_USER_STORE', 'TEST_USERNAME', 'TEST_AUTH_TYPE', 'test-app-id'), " +
"(2, 1, 'TEST_APP_NAME2','TEST_USER_STORE', 'TEST_USERNAME', 'TEST_AUTH_TYPE', 'test-app-id-2')";
"(2, 1, 'TEST_APP_NAME2','TEST_USER_STORE', 'TEST_USERNAME', 'TEST_AUTH_TYPE', 'test-app-id-2'), " +
"(3, 2, 'TEST_APP_NAME','TEST_USER_STORE', 'TEST_USERNAME', 'TEST_AUTH_TYPE', 'test-app-id-3')";
String idpSQL = "INSERT INTO IDP (ID, TENANT_ID, NAME, UUID) VALUES (1, 1, 'TEST_IDP_NAME', 'test-idp-id');";
String idpGroupSQL = "INSERT INTO IDP_GROUP (ID, IDP_ID, TENANT_ID, GROUP_NAME, UUID) VALUES " +
"(1, 1, 1, 'group1', 'test-group1-id'), (2, 1, 1, 'group2', 'test-group2-id');";
Expand Down

0 comments on commit 6289d27

Please sign in to comment.