Skip to content

Commit

Permalink
Use correct scopes for OAuth U2M and M2M flows (#228)
Browse files Browse the repository at this point in the history
* Refactor OAuthManager: explicitly define which flow to use

Signed-off-by: Levko Kravets <[email protected]>

* Refactoring: when refreshing OAuth token, use same scopes as when getting the one

Signed-off-by: Levko Kravets <[email protected]>

* Use correct scopes for U2M and M2M flows

Signed-off-by: Levko Kravets <[email protected]>

* Tests

Signed-off-by: Levko Kravets <[email protected]>

---------

Signed-off-by: Levko Kravets <[email protected]>
  • Loading branch information
kravets-levko authored Feb 5, 2024
1 parent 3953e5d commit 957791b
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 54 deletions.
3 changes: 2 additions & 1 deletion lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import Status from './dto/Status';
import HiveDriverError from './errors/HiveDriverError';
import { buildUserAgentString, definedOrError } from './utils';
import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication';
import DatabricksOAuth from './connection/auth/DatabricksOAuth';
import DatabricksOAuth, { OAuthFlow } from './connection/auth/DatabricksOAuth';
import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger';
import DBSQLLogger from './DBSQLLogger';
import CloseableCollection from './utils/CloseableCollection';
Expand Down Expand Up @@ -125,6 +125,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
});
case 'databricks-oauth':
return new DatabricksOAuth({
flow: options.oauthClientSecret === undefined ? OAuthFlow.U2M : OAuthFlow.M2M,
host: options.host,
persistence: options.persistence,
azureTenantId: options.azureTenantId,
Expand Down
85 changes: 59 additions & 26 deletions lib/connection/auth/DatabricksOAuth/OAuthManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@ import HiveDriverError from '../../../errors/HiveDriverError';
import { LogLevel } from '../../../contracts/IDBSQLLogger';
import OAuthToken from './OAuthToken';
import AuthorizationCode from './AuthorizationCode';
import { OAuthScope, OAuthScopes } from './OAuthScope';
import { OAuthScope, OAuthScopes, scopeDelimiter } from './OAuthScope';
import IClientContext from '../../../contracts/IClientContext';

export enum OAuthFlow {
U2M = 'U2M',
M2M = 'M2M',
}

export interface OAuthManagerOptions {
flow: OAuthFlow;
host: string;
callbackPorts?: Array<number>;
clientId?: string;
Expand Down Expand Up @@ -47,9 +53,7 @@ export default abstract class OAuthManager {

protected abstract getCallbackPorts(): Array<number>;

protected getScopes(requestedScopes: OAuthScopes): OAuthScopes {
return requestedScopes;
}
protected abstract getScopes(requestedScopes: OAuthScopes): OAuthScopes;

protected async getClient(): Promise<BaseClient> {
// Obtain http agent each time when we need an OAuth client
Expand Down Expand Up @@ -113,17 +117,11 @@ export default abstract class OAuthManager {
if (!accessToken || !refreshToken) {
throw new Error('Failed to refresh token: invalid response');
}
return new OAuthToken(accessToken, refreshToken);
return new OAuthToken(accessToken, refreshToken, token.scopes);
}

private async refreshAccessTokenM2M(): Promise<OAuthToken> {
const { access_token: accessToken, refresh_token: refreshToken } = await this.getTokenM2M();

if (!accessToken) {
throw new Error('Failed to fetch access token');
}

return new OAuthToken(accessToken, refreshToken);
private async refreshAccessTokenM2M(token: OAuthToken): Promise<OAuthToken> {
return this.getTokenM2M(token.scopes ?? []);
}

public async refreshAccessToken(token: OAuthToken): Promise<OAuthToken> {
Expand All @@ -137,10 +135,16 @@ export default abstract class OAuthManager {
throw error;
}

return this.options.clientSecret === undefined ? this.refreshAccessTokenU2M(token) : this.refreshAccessTokenM2M();
switch (this.options.flow) {
case OAuthFlow.U2M:
return this.refreshAccessTokenU2M(token);
case OAuthFlow.M2M:
return this.refreshAccessTokenM2M(token);
// no default
}
}

private async getTokenU2M(scopes: OAuthScopes) {
private async getTokenU2M(scopes: OAuthScopes): Promise<OAuthToken> {
const client = await this.getClient();

const authCode = new AuthorizationCode({
Expand All @@ -153,37 +157,47 @@ export default abstract class OAuthManager {

const { code, verifier, redirectUri } = await authCode.fetch(mappedScopes);

return client.grant({
const { access_token: accessToken, refresh_token: refreshToken } = await client.grant({
grant_type: 'authorization_code',
code,
code_verifier: verifier,
redirect_uri: redirectUri,
});

if (!accessToken) {
throw new Error('Failed to fetch access token');
}
return new OAuthToken(accessToken, refreshToken, mappedScopes);
}

private async getTokenM2M() {
private async getTokenM2M(scopes: OAuthScopes): Promise<OAuthToken> {
const client = await this.getClient();

const mappedScopes = this.getScopes(scopes);

// M2M flow doesn't really support token refreshing, and refresh should not be available
// in response. Each time access token expires, client can just acquire a new one using
// client secret. Here we explicitly return access token only as a sign that we're not going
// to use refresh token for M2M flow anywhere later
const { access_token: accessToken } = await client.grant({
grant_type: 'client_credentials',
scope: 'all-apis', // this is the only allowed scope for M2M flow
scope: mappedScopes.join(scopeDelimiter),
});
return { access_token: accessToken, refresh_token: undefined };
}

public async getToken(scopes: OAuthScopes): Promise<OAuthToken> {
const { access_token: accessToken, refresh_token: refreshToken } =
this.options.clientSecret === undefined ? await this.getTokenU2M(scopes) : await this.getTokenM2M();

if (!accessToken) {
throw new Error('Failed to fetch access token');
}
return new OAuthToken(accessToken, undefined, mappedScopes);
}

return new OAuthToken(accessToken, refreshToken);
public async getToken(scopes: OAuthScopes): Promise<OAuthToken> {
switch (this.options.flow) {
case OAuthFlow.U2M:
return this.getTokenU2M(scopes);
case OAuthFlow.M2M:
return this.getTokenM2M(scopes);
// no default
}
}

public static getManager(options: OAuthManagerOptions): OAuthManager {
Expand Down Expand Up @@ -245,6 +259,14 @@ export class DatabricksOAuthManager extends OAuthManager {
protected getCallbackPorts(): Array<number> {
return this.options.callbackPorts ?? DatabricksOAuthManager.defaultCallbackPorts;
}

protected getScopes(requestedScopes: OAuthScopes): OAuthScopes {
if (this.options.flow === OAuthFlow.M2M) {
// this is the only allowed scope for M2M flow
return [OAuthScope.allAPIs];
}
return requestedScopes;
}
}

export class AzureOAuthManager extends OAuthManager {
Expand Down Expand Up @@ -273,7 +295,18 @@ export class AzureOAuthManager extends OAuthManager {
protected getScopes(requestedScopes: OAuthScopes): OAuthScopes {
// There is no corresponding scopes in Azure, instead, access control will be delegated to Databricks
const tenantId = this.options.azureTenantId ?? AzureOAuthManager.datatricksAzureApp;
const azureScopes = [`${tenantId}/user_impersonation`];

const azureScopes = [];

switch (this.options.flow) {
case OAuthFlow.U2M:
azureScopes.push(`${tenantId}/user_impersonation`);
break;
case OAuthFlow.M2M:
azureScopes.push(`${tenantId}/.default`);
break;
// no default
}

if (requestedScopes.includes(OAuthScope.offlineAccess)) {
azureScopes.push(OAuthScope.offlineAccess);
Expand Down
1 change: 1 addition & 0 deletions lib/connection/auth/DatabricksOAuth/OAuthScope.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
export enum OAuthScope {
offlineAccess = 'offline_access',
SQL = 'sql',
allAPIs = 'all-apis',
}

export type OAuthScopes = Array<string>;
Expand Down
11 changes: 10 additions & 1 deletion lib/connection/auth/DatabricksOAuth/OAuthToken.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import { OAuthScopes } from './OAuthScope';

export default class OAuthToken {
private readonly _accessToken: string;

private readonly _refreshToken?: string;

private readonly _scopes?: OAuthScopes;

private _expirationTime?: number;

constructor(accessToken: string, refreshToken?: string) {
constructor(accessToken: string, refreshToken?: string, scopes?: OAuthScopes) {
this._accessToken = accessToken;
this._refreshToken = refreshToken;
this._scopes = scopes;
}

get accessToken(): string {
Expand All @@ -18,6 +23,10 @@ export default class OAuthToken {
return this._refreshToken;
}

get scopes(): OAuthScopes | undefined {
return this._scopes;
}

get expirationTime(): number {
// This token has already been verified, and we are just parsing it.
// If it has been tampered with, it will be rejected on the server side.
Expand Down
4 changes: 3 additions & 1 deletion lib/connection/auth/DatabricksOAuth/index.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import { HeadersInit } from 'node-fetch';
import IAuthentication from '../../contracts/IAuthentication';
import OAuthPersistence, { OAuthPersistenceCache } from './OAuthPersistence';
import OAuthManager, { OAuthManagerOptions } from './OAuthManager';
import OAuthManager, { OAuthManagerOptions, OAuthFlow } from './OAuthManager';
import { OAuthScopes, defaultOAuthScopes } from './OAuthScope';
import IClientContext from '../../../contracts/IClientContext';

export { OAuthFlow };

interface DatabricksOAuthOptions extends OAuthManagerOptions {
scopes?: OAuthScopes;
persistence?: OAuthPersistence;
Expand Down
1 change: 1 addition & 0 deletions tests/unit/DBSQLClient.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ describe('DBSQLClient.initAuthProvider', () => {
authType: 'databricks-oauth',
// host is used when creating OAuth manager, so make it look like a real AWS instance
host: 'example.dev.databricks.com',
oauthClientSecret: 'test-secret',
});

expect(provider).to.be.instanceOf(DatabricksOAuth);
Expand Down
Loading

0 comments on commit 957791b

Please sign in to comment.