Skip to content

Commit

Permalink
Use correct scopes for U2M and M2M flows
Browse files Browse the repository at this point in the history
Signed-off-by: Levko Kravets <[email protected]>
  • Loading branch information
kravets-levko committed Feb 4, 2024
1 parent d587eb2 commit 9f1c540
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 26 deletions.
28 changes: 22 additions & 6 deletions lib/connection/auth/DatabricksOAuth/OAuthManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ export default abstract class OAuthManager {

protected abstract getCallbackPorts(): Array<number>;

protected getScopes(requestedScopes: OAuthScopes): OAuthScopes {
return requestedScopes;
}
protected abstract getScopes(requestedScopes: OAuthScopes): OAuthScopes;

protected async getClient(): Promise<BaseClient> {
// Obtain http agent each time when we need an OAuth client
Expand Down Expand Up @@ -172,11 +170,10 @@ export default abstract class OAuthManager {
return new OAuthToken(accessToken, refreshToken, mappedScopes);
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
private async getTokenM2M(scopes: OAuthScopes): Promise<OAuthToken> {
const client = await this.getClient();

const mappedScopes = [OAuthScope.AllAPIs]; // this is the only allowed scope for M2M flow
const mappedScopes = this.getScopes(scopes);

// M2M flow doesn't really support token refreshing, and refresh should not be available
// in response. Each time access token expires, client can just acquire a new one using
Expand Down Expand Up @@ -262,6 +259,14 @@ export class DatabricksOAuthManager extends OAuthManager {
protected getCallbackPorts(): Array<number> {
return this.options.callbackPorts ?? DatabricksOAuthManager.defaultCallbackPorts;
}

protected getScopes(requestedScopes: OAuthScopes): OAuthScopes {
if (this.options.flow === OAuthFlow.M2M) {
// this is the only allowed scope for M2M flow
return [OAuthScope.allAPIs];
}
return requestedScopes;
}
}

export class AzureOAuthManager extends OAuthManager {
Expand Down Expand Up @@ -290,7 +295,18 @@ export class AzureOAuthManager extends OAuthManager {
protected getScopes(requestedScopes: OAuthScopes): OAuthScopes {
// There is no corresponding scopes in Azure, instead, access control will be delegated to Databricks
const tenantId = this.options.azureTenantId ?? AzureOAuthManager.datatricksAzureApp;
const azureScopes = [`${tenantId}/user_impersonation`];

const azureScopes = [];

switch (this.options.flow) {
case OAuthFlow.U2M:
azureScopes.push(`${tenantId}/user_impersonation`);
break;
case OAuthFlow.M2M:
azureScopes.push(`${tenantId}/.default`);
break;
// no default
}

if (requestedScopes.includes(OAuthScope.offlineAccess)) {
azureScopes.push(OAuthScope.offlineAccess);
Expand Down
2 changes: 1 addition & 1 deletion lib/connection/auth/DatabricksOAuth/OAuthScope.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export enum OAuthScope {
offlineAccess = 'offline_access',
SQL = 'sql',
AllAPIs = 'all-apis',
allAPIs = 'all-apis',
}

export type OAuthScopes = Array<string>;
Expand Down
94 changes: 75 additions & 19 deletions tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ const {
OAuthFlow,
} = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthManager');
const OAuthToken = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthToken').default;
const { OAuthScope, scopeDelimiter } = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthScope');
const AuthorizationCodeModule = require('../../../../../dist/connection/auth/DatabricksOAuth/AuthorizationCode');

const { createValidAccessToken, createExpiredAccessToken } = require('./utils');
Expand All @@ -17,9 +18,13 @@ const logger = new DBSQLLogger({ level: LogLevel.error });
class AuthorizationCodeMock {
constructor() {
this.fetchResult = undefined;
this.expectedScope = undefined;
}

async fetch() {
async fetch(scopes) {
if (this.expectedScope) {
expect(scopes.join(scopeDelimiter)).to.be.equal(this.expectedScope);
}
return this.fetchResult;
}
}
Expand All @@ -35,6 +40,7 @@ class OAuthClientMock {
this.clientOptions = {};
this.expectedClientId = undefined;
this.expectedClientSecret = undefined;
this.expectedScope = undefined;

this.grantError = undefined;
this.refreshError = undefined;
Expand All @@ -61,6 +67,9 @@ class OAuthClientMock {
expect(params.code).to.be.equal(AuthorizationCodeMock.validCode.code);
expect(params.code_verifier).to.be.equal(AuthorizationCodeMock.validCode.verifier);
expect(params.redirect_uri).to.be.equal(AuthorizationCodeMock.validCode.redirectUri);
if (this.expectedScope) {
expect(params.scope).to.be.equal(this.expectedScope);
}

return {
access_token: this.accessToken,
Expand All @@ -76,7 +85,9 @@ class OAuthClientMock {
}

expect(params.grant_type).to.be.equal('client_credentials');
expect(params.scope).to.be.equal('all-apis');
if (this.expectedScope) {
expect(params.scope).to.be.equal(this.expectedScope);
}

return {
access_token: this.accessToken,
Expand Down Expand Up @@ -156,25 +167,40 @@ class OAuthClientMock {
});

describe('U2M flow', () => {
function getExpectedScope(scopes) {
switch (OAuthManagerClass) {
case DatabricksOAuthManager:
return [...scopes].join(scopeDelimiter);
case AzureOAuthManager:
const tenantId = AzureOAuthManager.datatricksAzureApp;
return [`${tenantId}/user_impersonation`, ...scopes].join(scopeDelimiter);
}
return undefined;
}

it('should get access token', async () => {
const { oauthManager, oauthClient } = prepareTestInstances({ flow: OAuthFlow.U2M });
const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M });
const requestedScopes = [OAuthScope.offlineAccess];
authCode.expectedScope = getExpectedScope(requestedScopes);

const token = await oauthManager.getToken(['offline_access']);
const token = await oauthManager.getToken(requestedScopes);
expect(oauthClient.grant.called).to.be.true;
expect(token).to.be.instanceOf(OAuthToken);
expect(token.accessToken).to.be.equal(oauthClient.accessToken);
expect(token.refreshToken).to.be.equal(oauthClient.refreshToken);
});

it('should throw an error if cannot get access token', async () => {
const { oauthManager, oauthClient } = prepareTestInstances({ flow: OAuthFlow.U2M });
const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M });
const requestedScopes = [OAuthScope.offlineAccess];
authCode.expectedScope = getExpectedScope(requestedScopes);

// Make it return empty tokens
oauthClient.accessToken = undefined;
oauthClient.refreshToken = undefined;

try {
await oauthManager.getToken([]);
await oauthManager.getToken(requestedScopes);
expect.fail('It should throw an error');
} catch (error) {
if (error instanceof AssertionError) {
Expand All @@ -186,13 +212,15 @@ class OAuthClientMock {
});

it('should re-throw unhandled errors when getting access token', async () => {
const { oauthManager, oauthClient } = prepareTestInstances({ flow: OAuthFlow.U2M });
const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M });
const requestedScopes = [];
authCode.expectedScope = getExpectedScope(requestedScopes);

const testError = new Error('Test');
oauthClient.grantError = testError;

try {
await oauthManager.getToken([]);
await oauthManager.getToken(requestedScopes);
expect.fail('It should throw an error');
} catch (error) {
if (error instanceof AssertionError) {
Expand All @@ -204,7 +232,8 @@ class OAuthClientMock {
});

it('should not refresh valid token', async () => {
const { oauthManager, oauthClient } = prepareTestInstances({ flow: OAuthFlow.U2M });
const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M });
authCode.expectedScope = getExpectedScope([]);

const token = new OAuthToken(createValidAccessToken(), oauthClient.refreshToken);
expect(token.hasExpired).to.be.false;
Expand All @@ -217,7 +246,8 @@ class OAuthClientMock {
});

it('should throw an error if no refresh token is available', async () => {
const { oauthManager, oauthClient } = prepareTestInstances({ flow: OAuthFlow.U2M });
const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M });
authCode.expectedScope = getExpectedScope([]);

try {
const token = new OAuthToken(createExpiredAccessToken());
Expand All @@ -235,7 +265,8 @@ class OAuthClientMock {
});

it('should throw an error for invalid token', async () => {
const { oauthManager, oauthClient } = prepareTestInstances({ flow: OAuthFlow.U2M });
const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M });
authCode.expectedScope = getExpectedScope([]);

try {
const token = new OAuthToken('invalid_access_token', 'invalid_refresh_token');
Expand All @@ -252,10 +283,12 @@ class OAuthClientMock {
});

it('should refresh expired token', async () => {
const { oauthManager, oauthClient } = prepareTestInstances({ flow: OAuthFlow.U2M });
const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M });
const requestedScopes = [OAuthScope.offlineAccess];
authCode.expectedScope = getExpectedScope(requestedScopes);

oauthClient.accessToken = createExpiredAccessToken();
const token = await oauthManager.getToken([]);
const token = await oauthManager.getToken(requestedScopes);
expect(token.hasExpired).to.be.true;

const newToken = await oauthManager.refreshAccessToken(token);
Expand All @@ -266,7 +299,8 @@ class OAuthClientMock {
});

it('should throw an error if cannot refresh token', async () => {
const { oauthManager, oauthClient } = prepareTestInstances({ flow: OAuthFlow.U2M });
const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M });
authCode.expectedScope = getExpectedScope([]);

oauthClient.refresh.restore();
sinon.stub(oauthClient, 'refresh').returns({});
Expand All @@ -288,14 +322,27 @@ class OAuthClientMock {
});

describe('M2M flow', () => {
function getExpectedScope(scopes) {
switch (OAuthManagerClass) {
case DatabricksOAuthManager:
return [OAuthScope.allAPIs].join(scopeDelimiter);
case AzureOAuthManager:
const tenantId = AzureOAuthManager.datatricksAzureApp;
return [`${tenantId}/.default`, ...scopes].join(scopeDelimiter);
}
return undefined;
}

it('should get access token', async () => {
const { oauthManager, oauthClient } = prepareTestInstances({
flow: OAuthFlow.M2M,
clientId: 'test_client_id',
clientSecret: 'test_client_secret',
});
const requestedScopes = [OAuthScope.offlineAccess];
oauthClient.expectedScope = getExpectedScope(requestedScopes);

const token = await oauthManager.getToken([]);
const token = await oauthManager.getToken(requestedScopes);
expect(oauthClient.grant.called).to.be.true;
expect(token).to.be.instanceOf(OAuthToken);
expect(token.accessToken).to.be.equal(oauthClient.accessToken);
Expand All @@ -308,13 +355,15 @@ class OAuthClientMock {
clientId: 'test_client_id',
clientSecret: 'test_client_secret',
});
const requestedScopes = [OAuthScope.offlineAccess];
oauthClient.expectedScope = getExpectedScope(requestedScopes);

// Make it return empty tokens
oauthClient.accessToken = undefined;
oauthClient.refreshToken = undefined;

try {
await oauthManager.getToken([]);
await oauthManager.getToken(requestedScopes);
expect.fail('It should throw an error');
} catch (error) {
if (error instanceof AssertionError) {
Expand All @@ -331,12 +380,14 @@ class OAuthClientMock {
clientId: 'test_client_id',
clientSecret: 'test_client_secret',
});
const requestedScopes = [];
oauthClient.expectedScope = getExpectedScope(requestedScopes);

const testError = new Error('Test');
oauthClient.grantError = testError;

try {
await oauthManager.getToken([]);
await oauthManager.getToken(requestedScopes);
expect.fail('It should throw an error');
} catch (error) {
if (error instanceof AssertionError) {
Expand All @@ -353,6 +404,7 @@ class OAuthClientMock {
clientId: 'test_client_id',
clientSecret: 'test_client_secret',
});
oauthClient.expectedScope = getExpectedScope([]);

const token = new OAuthToken(createValidAccessToken());
expect(token.hasExpired).to.be.false;
Expand All @@ -371,9 +423,11 @@ class OAuthClientMock {
clientId: 'test_client_id',
clientSecret: 'test_client_secret',
});
const requestedScopes = [OAuthScope.offlineAccess];
oauthClient.expectedScope = getExpectedScope(requestedScopes);

oauthClient.accessToken = createExpiredAccessToken();
const token = await oauthManager.getToken([]);
const token = await oauthManager.getToken(requestedScopes);
expect(token.hasExpired).to.be.true;

oauthClient.accessToken = createValidAccessToken();
Expand All @@ -391,9 +445,11 @@ class OAuthClientMock {
clientId: 'test_client_id',
clientSecret: 'test_client_secret',
});
const requestedScopes = [OAuthScope.offlineAccess];
oauthClient.expectedScope = getExpectedScope(requestedScopes);

oauthClient.accessToken = createExpiredAccessToken();
const token = await oauthManager.getToken([]);
const token = await oauthManager.getToken(requestedScopes);
expect(token.hasExpired).to.be.true;

oauthClient.grant.restore();
Expand Down

0 comments on commit 9f1c540

Please sign in to comment.