diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index cabcad15..620143f4 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -17,7 +17,7 @@ import Status from './dto/Status'; import HiveDriverError from './errors/HiveDriverError'; import { buildUserAgentString, definedOrError } from './utils'; import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication'; -import DatabricksOAuth from './connection/auth/DatabricksOAuth'; +import DatabricksOAuth, { OAuthFlow } from './connection/auth/DatabricksOAuth'; import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger'; import DBSQLLogger from './DBSQLLogger'; import CloseableCollection from './utils/CloseableCollection'; @@ -125,6 +125,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I }); case 'databricks-oauth': return new DatabricksOAuth({ + flow: options.oauthClientSecret === undefined ? OAuthFlow.U2M : OAuthFlow.M2M, host: options.host, persistence: options.persistence, azureTenantId: options.azureTenantId, diff --git a/lib/connection/auth/DatabricksOAuth/OAuthManager.ts b/lib/connection/auth/DatabricksOAuth/OAuthManager.ts index 3615a6a0..c1c41345 100644 --- a/lib/connection/auth/DatabricksOAuth/OAuthManager.ts +++ b/lib/connection/auth/DatabricksOAuth/OAuthManager.ts @@ -4,10 +4,16 @@ import HiveDriverError from '../../../errors/HiveDriverError'; import { LogLevel } from '../../../contracts/IDBSQLLogger'; import OAuthToken from './OAuthToken'; import AuthorizationCode from './AuthorizationCode'; -import { OAuthScope, OAuthScopes } from './OAuthScope'; +import { OAuthScope, OAuthScopes, scopeDelimiter } from './OAuthScope'; import IClientContext from '../../../contracts/IClientContext'; +export enum OAuthFlow { + U2M = 'U2M', + M2M = 'M2M', +} + export interface OAuthManagerOptions { + flow: OAuthFlow; host: string; callbackPorts?: Array; clientId?: string; @@ -47,9 +53,7 @@ export default abstract class OAuthManager { protected abstract getCallbackPorts(): Array; - protected getScopes(requestedScopes: OAuthScopes): OAuthScopes { - return requestedScopes; - } + protected abstract getScopes(requestedScopes: OAuthScopes): OAuthScopes; protected async getClient(): Promise { // Obtain http agent each time when we need an OAuth client @@ -113,17 +117,11 @@ export default abstract class OAuthManager { if (!accessToken || !refreshToken) { throw new Error('Failed to refresh token: invalid response'); } - return new OAuthToken(accessToken, refreshToken); + return new OAuthToken(accessToken, refreshToken, token.scopes); } - private async refreshAccessTokenM2M(): Promise { - const { access_token: accessToken, refresh_token: refreshToken } = await this.getTokenM2M(); - - if (!accessToken) { - throw new Error('Failed to fetch access token'); - } - - return new OAuthToken(accessToken, refreshToken); + private async refreshAccessTokenM2M(token: OAuthToken): Promise { + return this.getTokenM2M(token.scopes ?? []); } public async refreshAccessToken(token: OAuthToken): Promise { @@ -137,10 +135,16 @@ export default abstract class OAuthManager { throw error; } - return this.options.clientSecret === undefined ? this.refreshAccessTokenU2M(token) : this.refreshAccessTokenM2M(); + switch (this.options.flow) { + case OAuthFlow.U2M: + return this.refreshAccessTokenU2M(token); + case OAuthFlow.M2M: + return this.refreshAccessTokenM2M(token); + // no default + } } - private async getTokenU2M(scopes: OAuthScopes) { + private async getTokenU2M(scopes: OAuthScopes): Promise { const client = await this.getClient(); const authCode = new AuthorizationCode({ @@ -153,37 +157,47 @@ export default abstract class OAuthManager { const { code, verifier, redirectUri } = await authCode.fetch(mappedScopes); - return client.grant({ + const { access_token: accessToken, refresh_token: refreshToken } = await client.grant({ grant_type: 'authorization_code', code, code_verifier: verifier, redirect_uri: redirectUri, }); + + if (!accessToken) { + throw new Error('Failed to fetch access token'); + } + return new OAuthToken(accessToken, refreshToken, mappedScopes); } - private async getTokenM2M() { + private async getTokenM2M(scopes: OAuthScopes): Promise { const client = await this.getClient(); + const mappedScopes = this.getScopes(scopes); + // M2M flow doesn't really support token refreshing, and refresh should not be available // in response. Each time access token expires, client can just acquire a new one using // client secret. Here we explicitly return access token only as a sign that we're not going // to use refresh token for M2M flow anywhere later const { access_token: accessToken } = await client.grant({ grant_type: 'client_credentials', - scope: 'all-apis', // this is the only allowed scope for M2M flow + scope: mappedScopes.join(scopeDelimiter), }); - return { access_token: accessToken, refresh_token: undefined }; - } - - public async getToken(scopes: OAuthScopes): Promise { - const { access_token: accessToken, refresh_token: refreshToken } = - this.options.clientSecret === undefined ? await this.getTokenU2M(scopes) : await this.getTokenM2M(); if (!accessToken) { throw new Error('Failed to fetch access token'); } + return new OAuthToken(accessToken, undefined, mappedScopes); + } - return new OAuthToken(accessToken, refreshToken); + public async getToken(scopes: OAuthScopes): Promise { + switch (this.options.flow) { + case OAuthFlow.U2M: + return this.getTokenU2M(scopes); + case OAuthFlow.M2M: + return this.getTokenM2M(scopes); + // no default + } } public static getManager(options: OAuthManagerOptions): OAuthManager { @@ -245,6 +259,14 @@ export class DatabricksOAuthManager extends OAuthManager { protected getCallbackPorts(): Array { return this.options.callbackPorts ?? DatabricksOAuthManager.defaultCallbackPorts; } + + protected getScopes(requestedScopes: OAuthScopes): OAuthScopes { + if (this.options.flow === OAuthFlow.M2M) { + // this is the only allowed scope for M2M flow + return [OAuthScope.allAPIs]; + } + return requestedScopes; + } } export class AzureOAuthManager extends OAuthManager { @@ -273,7 +295,18 @@ export class AzureOAuthManager extends OAuthManager { protected getScopes(requestedScopes: OAuthScopes): OAuthScopes { // There is no corresponding scopes in Azure, instead, access control will be delegated to Databricks const tenantId = this.options.azureTenantId ?? AzureOAuthManager.datatricksAzureApp; - const azureScopes = [`${tenantId}/user_impersonation`]; + + const azureScopes = []; + + switch (this.options.flow) { + case OAuthFlow.U2M: + azureScopes.push(`${tenantId}/user_impersonation`); + break; + case OAuthFlow.M2M: + azureScopes.push(`${tenantId}/.default`); + break; + // no default + } if (requestedScopes.includes(OAuthScope.offlineAccess)) { azureScopes.push(OAuthScope.offlineAccess); diff --git a/lib/connection/auth/DatabricksOAuth/OAuthScope.ts b/lib/connection/auth/DatabricksOAuth/OAuthScope.ts index 3996c9a2..87c50350 100644 --- a/lib/connection/auth/DatabricksOAuth/OAuthScope.ts +++ b/lib/connection/auth/DatabricksOAuth/OAuthScope.ts @@ -1,6 +1,7 @@ export enum OAuthScope { offlineAccess = 'offline_access', SQL = 'sql', + allAPIs = 'all-apis', } export type OAuthScopes = Array; diff --git a/lib/connection/auth/DatabricksOAuth/OAuthToken.ts b/lib/connection/auth/DatabricksOAuth/OAuthToken.ts index e48b0d05..cccf04dc 100644 --- a/lib/connection/auth/DatabricksOAuth/OAuthToken.ts +++ b/lib/connection/auth/DatabricksOAuth/OAuthToken.ts @@ -1,13 +1,18 @@ +import { OAuthScopes } from './OAuthScope'; + export default class OAuthToken { private readonly _accessToken: string; private readonly _refreshToken?: string; + private readonly _scopes?: OAuthScopes; + private _expirationTime?: number; - constructor(accessToken: string, refreshToken?: string) { + constructor(accessToken: string, refreshToken?: string, scopes?: OAuthScopes) { this._accessToken = accessToken; this._refreshToken = refreshToken; + this._scopes = scopes; } get accessToken(): string { @@ -18,6 +23,10 @@ export default class OAuthToken { return this._refreshToken; } + get scopes(): OAuthScopes | undefined { + return this._scopes; + } + get expirationTime(): number { // This token has already been verified, and we are just parsing it. // If it has been tampered with, it will be rejected on the server side. diff --git a/lib/connection/auth/DatabricksOAuth/index.ts b/lib/connection/auth/DatabricksOAuth/index.ts index 37489974..faed4823 100644 --- a/lib/connection/auth/DatabricksOAuth/index.ts +++ b/lib/connection/auth/DatabricksOAuth/index.ts @@ -1,10 +1,12 @@ import { HeadersInit } from 'node-fetch'; import IAuthentication from '../../contracts/IAuthentication'; import OAuthPersistence, { OAuthPersistenceCache } from './OAuthPersistence'; -import OAuthManager, { OAuthManagerOptions } from './OAuthManager'; +import OAuthManager, { OAuthManagerOptions, OAuthFlow } from './OAuthManager'; import { OAuthScopes, defaultOAuthScopes } from './OAuthScope'; import IClientContext from '../../../contracts/IClientContext'; +export { OAuthFlow }; + interface DatabricksOAuthOptions extends OAuthManagerOptions { scopes?: OAuthScopes; persistence?: OAuthPersistence; diff --git a/tests/unit/DBSQLClient.test.js b/tests/unit/DBSQLClient.test.js index 8b530bc9..c12a64bf 100644 --- a/tests/unit/DBSQLClient.test.js +++ b/tests/unit/DBSQLClient.test.js @@ -344,6 +344,7 @@ describe('DBSQLClient.initAuthProvider', () => { authType: 'databricks-oauth', // host is used when creating OAuth manager, so make it look like a real AWS instance host: 'example.dev.databricks.com', + oauthClientSecret: 'test-secret', }); expect(provider).to.be.instanceOf(DatabricksOAuth); diff --git a/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js b/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js index aab82975..e3411cd1 100644 --- a/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js +++ b/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js @@ -5,8 +5,10 @@ const { DBSQLLogger, LogLevel } = require('../../../../../dist'); const { DatabricksOAuthManager, AzureOAuthManager, + OAuthFlow, } = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthManager'); const OAuthToken = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthToken').default; +const { OAuthScope, scopeDelimiter } = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthScope'); const AuthorizationCodeModule = require('../../../../../dist/connection/auth/DatabricksOAuth/AuthorizationCode'); const { createValidAccessToken, createExpiredAccessToken } = require('./utils'); @@ -16,9 +18,13 @@ const logger = new DBSQLLogger({ level: LogLevel.error }); class AuthorizationCodeMock { constructor() { this.fetchResult = undefined; + this.expectedScope = undefined; } - async fetch() { + async fetch(scopes) { + if (this.expectedScope) { + expect(scopes.join(scopeDelimiter)).to.be.equal(this.expectedScope); + } return this.fetchResult; } } @@ -34,6 +40,7 @@ class OAuthClientMock { this.clientOptions = {}; this.expectedClientId = undefined; this.expectedClientSecret = undefined; + this.expectedScope = undefined; this.grantError = undefined; this.refreshError = undefined; @@ -60,6 +67,9 @@ class OAuthClientMock { expect(params.code).to.be.equal(AuthorizationCodeMock.validCode.code); expect(params.code_verifier).to.be.equal(AuthorizationCodeMock.validCode.verifier); expect(params.redirect_uri).to.be.equal(AuthorizationCodeMock.validCode.redirectUri); + if (this.expectedScope) { + expect(params.scope).to.be.equal(this.expectedScope); + } return { access_token: this.accessToken, @@ -75,7 +85,9 @@ class OAuthClientMock { } expect(params.grant_type).to.be.equal('client_credentials'); - expect(params.scope).to.be.equal('all-apis'); + if (this.expectedScope) { + expect(params.scope).to.be.equal(this.expectedScope); + } return { access_token: this.accessToken, @@ -155,10 +167,23 @@ class OAuthClientMock { }); describe('U2M flow', () => { + function getExpectedScope(scopes) { + switch (OAuthManagerClass) { + case DatabricksOAuthManager: + return [...scopes].join(scopeDelimiter); + case AzureOAuthManager: + const tenantId = AzureOAuthManager.datatricksAzureApp; + return [`${tenantId}/user_impersonation`, ...scopes].join(scopeDelimiter); + } + return undefined; + } + it('should get access token', async () => { - const { oauthManager, oauthClient } = prepareTestInstances(); + const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M }); + const requestedScopes = [OAuthScope.offlineAccess]; + authCode.expectedScope = getExpectedScope(requestedScopes); - const token = await oauthManager.getToken(['offline_access']); + const token = await oauthManager.getToken(requestedScopes); expect(oauthClient.grant.called).to.be.true; expect(token).to.be.instanceOf(OAuthToken); expect(token.accessToken).to.be.equal(oauthClient.accessToken); @@ -166,14 +191,16 @@ class OAuthClientMock { }); it('should throw an error if cannot get access token', async () => { - const { oauthManager, oauthClient } = prepareTestInstances(); + const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M }); + const requestedScopes = [OAuthScope.offlineAccess]; + authCode.expectedScope = getExpectedScope(requestedScopes); // Make it return empty tokens oauthClient.accessToken = undefined; oauthClient.refreshToken = undefined; try { - await oauthManager.getToken([]); + await oauthManager.getToken(requestedScopes); expect.fail('It should throw an error'); } catch (error) { if (error instanceof AssertionError) { @@ -185,13 +212,15 @@ class OAuthClientMock { }); it('should re-throw unhandled errors when getting access token', async () => { - const { oauthManager, oauthClient } = prepareTestInstances(); + const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M }); + const requestedScopes = []; + authCode.expectedScope = getExpectedScope(requestedScopes); const testError = new Error('Test'); oauthClient.grantError = testError; try { - await oauthManager.getToken([]); + await oauthManager.getToken(requestedScopes); expect.fail('It should throw an error'); } catch (error) { if (error instanceof AssertionError) { @@ -203,7 +232,8 @@ class OAuthClientMock { }); it('should not refresh valid token', async () => { - const { oauthManager, oauthClient } = prepareTestInstances(); + const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M }); + authCode.expectedScope = getExpectedScope([]); const token = new OAuthToken(createValidAccessToken(), oauthClient.refreshToken); expect(token.hasExpired).to.be.false; @@ -216,7 +246,8 @@ class OAuthClientMock { }); it('should throw an error if no refresh token is available', async () => { - const { oauthManager, oauthClient } = prepareTestInstances(); + const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M }); + authCode.expectedScope = getExpectedScope([]); try { const token = new OAuthToken(createExpiredAccessToken()); @@ -234,7 +265,8 @@ class OAuthClientMock { }); it('should throw an error for invalid token', async () => { - const { oauthManager, oauthClient } = prepareTestInstances(); + const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M }); + authCode.expectedScope = getExpectedScope([]); try { const token = new OAuthToken('invalid_access_token', 'invalid_refresh_token'); @@ -251,10 +283,12 @@ class OAuthClientMock { }); it('should refresh expired token', async () => { - const { oauthManager, oauthClient } = prepareTestInstances(); + const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M }); + const requestedScopes = [OAuthScope.offlineAccess]; + authCode.expectedScope = getExpectedScope(requestedScopes); oauthClient.accessToken = createExpiredAccessToken(); - const token = await oauthManager.getToken([]); + const token = await oauthManager.getToken(requestedScopes); expect(token.hasExpired).to.be.true; const newToken = await oauthManager.refreshAccessToken(token); @@ -265,7 +299,8 @@ class OAuthClientMock { }); it('should throw an error if cannot refresh token', async () => { - const { oauthManager, oauthClient } = prepareTestInstances(); + const { oauthManager, oauthClient, authCode } = prepareTestInstances({ flow: OAuthFlow.U2M }); + authCode.expectedScope = getExpectedScope([]); oauthClient.refresh.restore(); sinon.stub(oauthClient, 'refresh').returns({}); @@ -287,14 +322,27 @@ class OAuthClientMock { }); describe('M2M flow', () => { + function getExpectedScope(scopes) { + switch (OAuthManagerClass) { + case DatabricksOAuthManager: + return [OAuthScope.allAPIs].join(scopeDelimiter); + case AzureOAuthManager: + const tenantId = AzureOAuthManager.datatricksAzureApp; + return [`${tenantId}/.default`, ...scopes].join(scopeDelimiter); + } + return undefined; + } + it('should get access token', async () => { const { oauthManager, oauthClient } = prepareTestInstances({ - // setup for M2M flow + flow: OAuthFlow.M2M, clientId: 'test_client_id', clientSecret: 'test_client_secret', }); + const requestedScopes = [OAuthScope.offlineAccess]; + oauthClient.expectedScope = getExpectedScope(requestedScopes); - const token = await oauthManager.getToken([]); + const token = await oauthManager.getToken(requestedScopes); expect(oauthClient.grant.called).to.be.true; expect(token).to.be.instanceOf(OAuthToken); expect(token.accessToken).to.be.equal(oauthClient.accessToken); @@ -303,17 +351,19 @@ class OAuthClientMock { it('should throw an error if cannot get access token', async () => { const { oauthManager, oauthClient } = prepareTestInstances({ - // setup for M2M flow + flow: OAuthFlow.M2M, clientId: 'test_client_id', clientSecret: 'test_client_secret', }); + const requestedScopes = [OAuthScope.offlineAccess]; + oauthClient.expectedScope = getExpectedScope(requestedScopes); // Make it return empty tokens oauthClient.accessToken = undefined; oauthClient.refreshToken = undefined; try { - await oauthManager.getToken([]); + await oauthManager.getToken(requestedScopes); expect.fail('It should throw an error'); } catch (error) { if (error instanceof AssertionError) { @@ -326,16 +376,18 @@ class OAuthClientMock { it('should re-throw unhandled errors when getting access token', async () => { const { oauthManager, oauthClient } = prepareTestInstances({ - // setup for M2M flow + flow: OAuthFlow.M2M, clientId: 'test_client_id', clientSecret: 'test_client_secret', }); + const requestedScopes = []; + oauthClient.expectedScope = getExpectedScope(requestedScopes); const testError = new Error('Test'); oauthClient.grantError = testError; try { - await oauthManager.getToken([]); + await oauthManager.getToken(requestedScopes); expect.fail('It should throw an error'); } catch (error) { if (error instanceof AssertionError) { @@ -348,10 +400,11 @@ class OAuthClientMock { it('should not refresh valid token', async () => { const { oauthManager, oauthClient } = prepareTestInstances({ - // setup for M2M flow + flow: OAuthFlow.M2M, clientId: 'test_client_id', clientSecret: 'test_client_secret', }); + oauthClient.expectedScope = getExpectedScope([]); const token = new OAuthToken(createValidAccessToken()); expect(token.hasExpired).to.be.false; @@ -366,13 +419,15 @@ class OAuthClientMock { it('should refresh expired token', async () => { const { oauthManager, oauthClient } = prepareTestInstances({ - // setup for M2M flow + flow: OAuthFlow.M2M, clientId: 'test_client_id', clientSecret: 'test_client_secret', }); + const requestedScopes = [OAuthScope.offlineAccess]; + oauthClient.expectedScope = getExpectedScope(requestedScopes); oauthClient.accessToken = createExpiredAccessToken(); - const token = await oauthManager.getToken([]); + const token = await oauthManager.getToken(requestedScopes); expect(token.hasExpired).to.be.true; oauthClient.accessToken = createValidAccessToken(); @@ -386,13 +441,15 @@ class OAuthClientMock { it('should throw an error if cannot refresh token', async () => { const { oauthManager, oauthClient } = prepareTestInstances({ - // setup for M2M flow + flow: OAuthFlow.M2M, clientId: 'test_client_id', clientSecret: 'test_client_secret', }); + const requestedScopes = [OAuthScope.offlineAccess]; + oauthClient.expectedScope = getExpectedScope(requestedScopes); oauthClient.accessToken = createExpiredAccessToken(); - const token = await oauthManager.getToken([]); + const token = await oauthManager.getToken(requestedScopes); expect(token.hasExpired).to.be.true; oauthClient.grant.restore(); diff --git a/tests/unit/connection/auth/DatabricksOAuth/OAuthToken.test.js b/tests/unit/connection/auth/DatabricksOAuth/OAuthToken.test.js index 1ab4f486..3e902051 100644 --- a/tests/unit/connection/auth/DatabricksOAuth/OAuthToken.test.js +++ b/tests/unit/connection/auth/DatabricksOAuth/OAuthToken.test.js @@ -7,6 +7,7 @@ describe('OAuthToken', () => { it('should be properly initialized', () => { const accessToken = 'access'; const refreshToken = 'refresh'; + const scopes = ['test']; const token1 = new OAuthToken(accessToken); expect(token1.accessToken).to.be.equal(accessToken); @@ -14,6 +15,11 @@ describe('OAuthToken', () => { const token2 = new OAuthToken(accessToken, refreshToken); expect(token2.accessToken).to.be.equal(accessToken); expect(token2.refreshToken).to.be.equal(refreshToken); + + const token3 = new OAuthToken(accessToken, refreshToken, scopes); + expect(token3.accessToken).to.be.equal(accessToken); + expect(token3.refreshToken).to.be.equal(refreshToken); + expect(token3.scopes).to.deep.equal(scopes); }); it('should return valid expiration time', () => {