Skip to content

Commit

Permalink
🧹 improve the validation of the ms365 credentials during the connecti…
Browse files Browse the repository at this point in the history
…on (#3180)
  • Loading branch information
chris-rock authored Feb 1, 2024
1 parent 74ea320 commit 4bf15da
Show file tree
Hide file tree
Showing 15 changed files with 78 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package resources
package connection

import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/cockroachdb/errors"
"github.com/microsoft/kiota-abstractions-go/authentication"
a "github.com/microsoft/kiota-authentication-azure-go"
msgraphsdkgo "github.com/microsoftgraph/msgraph-sdk-go"
"go.mondoo.com/cnquery/v10/providers/ms365/connection"
)

var DefaultMSGraphScopes = []string{connection.DefaultMSGraphScope}
var DefaultMSGraphScopes = []string{DefaultMSGraphScope}

func newGraphRequestAdapterWithFn(providerFn func() (authentication.AuthenticationProvider, error)) (*msgraphsdkgo.GraphRequestAdapter, error) {
auth, err := providerFn()
Expand All @@ -22,9 +22,7 @@ func newGraphRequestAdapterWithFn(providerFn func() (authentication.Authenticati
return msgraphsdkgo.NewGraphRequestAdapter(auth)
}

func graphClient(conn *connection.Ms365Connection) (*msgraphsdkgo.GraphServiceClient, error) {
token := conn.Token()

func graphClient(token azcore.TokenCredential) (*msgraphsdkgo.GraphServiceClient, error) {
providerFunc := func() (authentication.AuthenticationProvider, error) {
return a.NewAzureIdentityAuthenticationProviderWithScopes(token, DefaultMSGraphScopes)
}
Expand All @@ -35,3 +33,7 @@ func graphClient(conn *connection.Ms365Connection) (*msgraphsdkgo.GraphServiceCl
graphClient := msgraphsdkgo.NewGraphServiceClient(adapter)
return graphClient, nil
}

func (conn *Ms365Connection) GraphClient() (*msgraphsdkgo.GraphServiceClient, error) {
return graphClient(conn.Token())
}
15 changes: 14 additions & 1 deletion providers/ms365/connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
package connection

import (
"context"
"fmt"
"github.com/cockroachdb/errors"
"runtime"
"sync"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/pkg/errors"
msgrapgh_org "github.com/microsoftgraph/msgraph-sdk-go/organization"
"go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory"
"go.mondoo.com/cnquery/v10/providers-sdk/v1/vault"
"go.mondoo.com/cnquery/v10/providers/os/connection/local"
Expand Down Expand Up @@ -59,6 +61,17 @@ func NewMs365Connection(id uint32, asset *inventory.Asset, conf *inventory.Confi
if err != nil {
return nil, errors.Wrap(err, "cannot fetch credentials for microsoft provider")
}

// test connection
client, err := graphClient(token)
if err != nil {
return nil, errors.Wrap(err, "authentication failed")
}
_, err = client.Organization().Get(context.Background(), &msgrapgh_org.OrganizationRequestBuilderGetRequestConfiguration{})
if err != nil {
return nil, errors.Wrap(err, "authentication failed")
}

return &Ms365Connection{
Conf: conf,
id: id,
Expand Down
37 changes: 37 additions & 0 deletions providers/ms365/connection/connection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

//go:build debugtest
// +build debugtest

package connection

import (
"github.com/stretchr/testify/require"
"go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory"
"go.mondoo.com/cnquery/v10/providers-sdk/v1/vault"
"os"
"testing"
)

func TestMs365(t *testing.T) {
cred := &vault.Credential{
Type: vault.CredentialType_pkcs12,
PrivateKeyPath: "/Users/chris/tmph5uvp4s4.pem",
}

data, err := os.ReadFile(cred.PrivateKeyPath)
require.NoError(t, err)
cred.Secret = data

conn, err := NewMs365Connection(0, &inventory.Asset{}, &inventory.Config{
Options: map[string]string{
OptionTenantID: "<tenant_id>",
OptionClientID: "<client_id>",
},
Credentials: []*vault.Credential{cred},
})
require.NoError(t, err)
require.NotNil(t, conn)

}
2 changes: 1 addition & 1 deletion providers/ms365/resources/applications.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (m *mqlMicrosoftApplication) id() (string, error) {

func (a *mqlMicrosoft) applications() ([]interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions providers/ms365/resources/devicemanagement.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (m *mqlMicrosoftDevicemanagementDevicecompliancepolicy) id() (string, error

func (a *mqlMicrosoftDevicemanagement) deviceConfigurations() ([]interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -59,7 +59,7 @@ func (a *mqlMicrosoftDevicemanagement) deviceConfigurations() ([]interface{}, er

func (a *mqlMicrosoftDevicemanagement) deviceCompliancePolicies() ([]interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions providers/ms365/resources/domains.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (m *mqlMicrosoftDomaindnsrecord) id() (string, error) {

func (a *mqlMicrosoft) domains() ([]interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -66,7 +66,7 @@ func (a *mqlMicrosoft) domains() ([]interface{}, error) {

func (a *mqlMicrosoftDomain) serviceConfigurationRecords() ([]interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion providers/ms365/resources/groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (a *mqlMicrosoftGroup) members() ([]interface{}, error) {

func (a *mqlMicrosoft) groups() ([]interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion providers/ms365/resources/microsoft.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

func (a *mqlMicrosoft) tenantDomainName() (string, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return "", err
}
Expand Down
2 changes: 1 addition & 1 deletion providers/ms365/resources/organizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (m *mqlMicrosoftOrganization) id() (string, error) {

func (a *mqlMicrosoft) organizations() ([]interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand Down
8 changes: 4 additions & 4 deletions providers/ms365/resources/policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

func (a *mqlMicrosoftPolicies) authorizationPolicy() (interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand All @@ -28,7 +28,7 @@ func (a *mqlMicrosoftPolicies) authorizationPolicy() (interface{}, error) {

func (a *mqlMicrosoftPolicies) identitySecurityDefaultsEnforcementPolicy() (interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand All @@ -44,7 +44,7 @@ func (a *mqlMicrosoftPolicies) identitySecurityDefaultsEnforcementPolicy() (inte
// https://docs.microsoft.com/en-us/graph/api/adminconsentrequestpolicy-get?view=graph-rest-
func (a *mqlMicrosoftPolicies) adminConsentRequestPolicy() (interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand All @@ -61,7 +61,7 @@ func (a *mqlMicrosoftPolicies) adminConsentRequestPolicy() (interface{}, error)
// https://docs.microsoft.com/en-us/graph/api/permissiongrantpolicy-list?view=graph-rest-1.0&tabs=http
func (a *mqlMicrosoftPolicies) permissionGrantPolicies() ([]interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions providers/ms365/resources/rolemanagement.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (m *mqlMicrosoftRolemanagementRoleassignment) id() (string, error) {

func (a *mqlMicrosoftRolemanagement) roleDefinitions() ([]interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -62,7 +62,7 @@ func (a *mqlMicrosoftRolemanagement) roleDefinitions() ([]interface{}, error) {

func (a *mqlMicrosoftRolemanagementRoledefinition) assignments() ([]interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion providers/ms365/resources/securescores.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (a *mqlMicrosoftSecurity) latestSecureScores() (*mqlMicrosoftSecuritySecuri
// see https://docs.microsoft.com/en-us/graph/api/securescore-get?view=graph-rest-1.0&tabs=http
func (a *mqlMicrosoftSecurity) secureScores() ([]interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion providers/ms365/resources/serviceprincipals.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (a *mqlMicrosoft) serviceprincipals() ([]interface{}, error) {
}

func fetchServicePrincipals(runtime *plugin.Runtime, conn *connection.Ms365Connection, params *serviceprincipals.ServicePrincipalsRequestBuilderGetQueryParameters) ([]interface{}, error) {
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
} // TODO: what if we have more than 1k SPs?
Expand Down
2 changes: 1 addition & 1 deletion providers/ms365/resources/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

func (a *mqlMicrosoft) settings() ([]interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions providers/ms365/resources/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (m *mqlMicrosoftUser) id() (string, error) {

func (a *mqlMicrosoft) users() ([]interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -78,7 +78,7 @@ func (a *mqlMicrosoft) users() ([]interface{}, error) {

func (a *mqlMicrosoftUser) settings() (interface{}, error) {
conn := a.MqlRuntime.Connection.(*connection.Ms365Connection)
graphClient, err := graphClient(conn)
graphClient, err := conn.GraphClient()
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 4bf15da

Please sign in to comment.