Skip to content

Commit

Permalink
[PECO-909] Automatically renew oauth token when refresh token is available (databricks#156)
Browse files Browse the repository at this point in the history

* HiveDriver: obtain a thrift client before each request (allows re-creating the client if needed)

Signed-off-by: Levko Kravets <[email protected]>

* Move auth logic to DBSQLClient

Signed-off-by: Levko Kravets <[email protected]>

* Remove redundant HttpTransport class

Signed-off-by: Levko Kravets <[email protected]>

* Cache OAuth tokens in memory by default to avoid re-running OAuth flow on every request

Signed-off-by: Levko Kravets <[email protected]>

* Re-create thrift client when auth credentials (e.g. oauth token) change

Signed-off-by: Levko Kravets <[email protected]>

* Update tests

Signed-off-by: Levko Kravets <[email protected]>

---------

Signed-off-by: Levko Kravets <[email protected]>
Signed-off-by: nithinkdb <[email protected]>
  • Loading branch information
kravets-levko authored and nithinkdb committed Aug 21, 2023
1 parent 637b3bf commit efd3c65
Show file tree
Hide file tree
Showing 17 changed files with 554 additions and 564 deletions.
124 changes: 64 additions & 60 deletions lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import thrift from 'thrift';
import thrift, { HttpHeaders } from 'thrift';

import { EventEmitter } from 'events';
import TCLIService from '../thrift/TCLIService';
import { TProtocolVersion } from '../thrift/TCLIService_types';
import IDBSQLClient, { ConnectionOptions, OpenSessionRequest, ClientOptions } from './contracts/IDBSQLClient';
import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } from './contracts/IDBSQLClient';
import HiveDriver from './hive/HiveDriver';
import { Int64 } from './hive/Types';
import DBSQLSession from './DBSQLSession';
import IDBSQLSession from './contracts/IDBSQLSession';
import IThriftConnection from './connection/contracts/IThriftConnection';
import IConnectionProvider from './connection/contracts/IConnectionProvider';
import IAuthentication from './connection/contracts/IAuthentication';
import HttpConnection from './connection/connections/HttpConnection';
import IConnectionOptions from './connection/contracts/IConnectionOptions';
import Status from './dto/Status';
import HiveDriverError from './errors/HiveDriverError';
import { buildUserAgentString, definedOrError } from './utils';
import { areHeadersEqual, buildUserAgentString, definedOrError } from './utils';
import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication';
import DatabricksOAuth from './connection/auth/DatabricksOAuth';
import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger';
Expand All @@ -42,26 +40,25 @@ function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) {
}

export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
private client: TCLIService.Client | null;
private client: TCLIService.Client | null = null;

private connection: IThriftConnection | null;
private authProvider: IAuthentication | null = null;

private connectionProvider: IConnectionProvider;
private connectionOptions: ConnectionOptions | null = null;

private additionalHeaders: HttpHeaders = {};

private readonly logger: IDBSQLLogger;

private readonly thrift = thrift;

constructor(options?: ClientOptions) {
super();
this.connectionProvider = new HttpConnection();
this.logger = options?.logger || new DBSQLLogger();
this.client = null;
this.connection = null;
this.logger.log(LogLevel.info, 'Created DBSQLClient');
}

private getConnectionOptions(options: ConnectionOptions): IConnectionOptions {
private getConnectionOptions(options: ConnectionOptions, headers: HttpHeaders): IConnectionOptions {
const {
host,
port,
Expand All @@ -85,6 +82,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
https: true,
...otherOptions,
headers: {
...headers,
'User-Agent': buildUserAgentString(options.clientId),
},
},
Expand Down Expand Up @@ -126,39 +124,8 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
* const session = client.connect({host, path, token});
*/
public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise<IDBSQLClient> {
authProvider = this.getAuthProvider(options, authProvider);

this.connection = await this.connectionProvider.connect(this.getConnectionOptions(options), authProvider);

this.client = this.thrift.createClient(TCLIService, this.connection.getConnection());

this.connection.getConnection().on('error', (error: Error) => {
// Error.stack already contains error type and message, so log stack if available,
// otherwise fall back to just error type + message
this.logger.log(LogLevel.error, error.stack || `${error.name}: ${error.message}`);
try {
this.emit('error', error);
} catch (e) {
// EventEmitter will throw unhandled error when emitting 'error' event.
// Since we already logged it few lines above, just suppress this behaviour
}
});

this.connection.getConnection().on('reconnecting', (params: { delay: number; attempt: number }) => {
this.logger.log(LogLevel.debug, `Reconnecting, params: ${JSON.stringify(params)}`);
this.emit('reconnecting', params);
});

this.connection.getConnection().on('close', () => {
this.logger.log(LogLevel.debug, 'Closing connection.');
this.emit('close');
});

this.connection.getConnection().on('timeout', () => {
this.logger.log(LogLevel.debug, 'Connection timed out.');
this.emit('timeout');
});

this.authProvider = this.getAuthProvider(options, authProvider);
this.connectionOptions = options;
return this;
}

Expand All @@ -172,11 +139,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
* const session = await client.openSession();
*/
public async openSession(request: OpenSessionRequest = {}): Promise<IDBSQLSession> {
if (!this.connection?.isConnected()) {
throw new HiveDriverError('DBSQLClient: connection is lost');
}

const driver = new HiveDriver(this.getClient());
const driver = new HiveDriver(() => this.getClient());

const response = await driver.openSession({
client_protocol_i64: new Int64(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6),
Expand All @@ -187,23 +150,64 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
return new DBSQLSession(driver, definedOrError(response.sessionHandle), this.logger);
}

public getClient() {
if (!this.client) {
throw new HiveDriverError('DBSQLClient: client is not initialized');
/**
 * Returns a thrift client whose connection was created with the current auth headers.
 *
 * The auth provider is asked for headers on every call. A new connection and client are
 * created when no client exists yet, or when the returned headers differ from the ones
 * the current client was created with.
 *
 * @returns the cached or newly created thrift client
 * @throws {HiveDriverError} when connection options or the auth provider are not set
 *   (i.e. before `connect()` or after `close()`)
 */
private async getClient() {
  if (!this.connectionOptions || !this.authProvider) {
    throw new HiveDriverError('DBSQLClient: not connected');
  }

  const authHeaders = await this.authProvider.authenticate();
  // When auth headers change - recreate client. Thrift library does not provide API for updating
  // changed options, therefore we have to recreate both connection and client to apply new headers
  if (!this.client || !areHeadersEqual(this.additionalHeaders, authHeaders)) {
    this.logger.log(LogLevel.info, 'DBSQLClient: initializing thrift client');
    const connectionOptions = this.getConnectionOptions(this.connectionOptions, authHeaders);

    const connection = await this.createConnection(connectionOptions);
    const client = this.thrift.createClient(TCLIService, connection.getConnection());

    // Commit the new state only after both the connection and the client were created.
    // Otherwise a failure here would leave the new headers recorded next to the old client,
    // and the next call would treat that stale client as up to date.
    this.additionalHeaders = authHeaders;
    this.client = client;
  }

  return this.client;
}

public async close(): Promise<void> {
if (this.connection) {
const thriftConnection = this.connection.getConnection();
private async createConnection(options: IConnectionOptions) {
const connectionProvider = new HttpConnection();
const connection = await connectionProvider.connect(options);
const thriftConnection = connection.getConnection();

if (typeof thriftConnection.end === 'function') {
this.connection.getConnection().end();
thriftConnection.on('error', (error: Error) => {
// Error.stack already contains error type and message, so log stack if available,
// otherwise fall back to just error type + message
this.logger.log(LogLevel.error, error.stack || `${error.name}: ${error.message}`);
try {
this.emit('error', error);
} catch (e) {
// EventEmitter will throw unhandled error when emitting 'error' event.
// Since we already logged it few lines above, just suppress this behaviour
}
});

this.connection = null;
}
thriftConnection.on('reconnecting', (params: { delay: number; attempt: number }) => {
this.logger.log(LogLevel.debug, `Reconnecting, params: ${JSON.stringify(params)}`);
this.emit('reconnecting', params);
});

thriftConnection.on('close', () => {
this.logger.log(LogLevel.debug, 'Closing connection.');
this.emit('close');
});

thriftConnection.on('timeout', () => {
this.logger.log(LogLevel.debug, 'Connection timed out.');
this.emit('timeout');
});

return connection;
}

/**
 * Resets this client to its disconnected state.
 *
 * Drops the cached thrift client, the auth provider and the connection options, so that
 * `getClient()` fails with "not connected" until `connect()` is called again.
 *
 * NOTE(review): the thrift connection created for the current client is not stored on this
 * instance, so it is not explicitly ended here - confirm it is released elsewhere.
 */
public async close(): Promise<void> {
  this.client = null;
  this.authProvider = null;
  this.connectionOptions = null;
  // Also forget the auth headers used for the last client, so credentials
  // (e.g. a bearer token) do not stay cached on a closed client
  this.additionalHeaders = {};
}
}
12 changes: 12 additions & 0 deletions lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,15 @@ export default interface OAuthPersistence {

read(host: string): Promise<OAuthToken | undefined>;
}

/**
 * In-memory implementation of `OAuthPersistence`.
 *
 * Tokens are kept per host for the lifetime of this instance only; nothing is written to disk.
 */
export class OAuthPersistenceCache implements OAuthPersistence {
  // A Map (rather than a plain object) is used because hosts are arbitrary strings: keys like
  // `constructor`, `toString` or `__proto__` must behave like any other host
  private readonly tokens = new Map<string, OAuthToken>();

  /**
   * Stores the token for the given host, replacing any previously stored one.
   *
   * @param host - key under which the token is stored
   * @param token - token to remember
   */
  async persist(host: string, token: OAuthToken): Promise<void> {
    this.tokens.set(host, token);
  }

  /**
   * Looks up the token stored for the given host.
   *
   * @param host - key the token was stored under
   * @returns the stored token, or `undefined` when nothing was stored for this host
   */
  async read(host: string): Promise<OAuthToken | undefined> {
    return this.tokens.get(host);
  }
}
19 changes: 11 additions & 8 deletions lib/connection/auth/DatabricksOAuth/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import { HttpHeaders } from 'thrift';
import IAuthentication from '../../contracts/IAuthentication';
import HttpTransport from '../../transports/HttpTransport';
import IDBSQLLogger from '../../../contracts/IDBSQLLogger';
import OAuthPersistence from './OAuthPersistence';
import OAuthPersistence, { OAuthPersistenceCache } from './OAuthPersistence';
import OAuthManager, { OAuthManagerOptions } from './OAuthManager';
import { OAuthScopes, defaultOAuthScopes } from './OAuthScope';

Expand All @@ -20,26 +19,30 @@ export default class DatabricksOAuth implements IAuthentication {

private readonly manager: OAuthManager;

private readonly defaultPersistence = new OAuthPersistenceCache();

constructor(options: DatabricksOAuthOptions) {
this.options = options;
this.logger = options.logger;
this.manager = OAuthManager.getManager(this.options);
}

public async authenticate(transport: HttpTransport): Promise<void> {
const { host, scopes, headers, persistence } = this.options;
public async authenticate(): Promise<HttpHeaders> {
const { host, scopes, headers } = this.options;

const persistence = this.options.persistence ?? this.defaultPersistence;

let token = await persistence?.read(host);
let token = await persistence.read(host);
if (!token) {
token = await this.manager.getToken(scopes ?? defaultOAuthScopes);
}

token = await this.manager.refreshAccessToken(token);
await persistence?.persist(host, token);
await persistence.persist(host, token);

transport.updateHeaders({
return {
...headers,
Authorization: `Bearer ${token.accessToken}`,
});
};
}
}
7 changes: 3 additions & 4 deletions lib/connection/auth/PlainHttpAuthentication.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { HttpHeaders } from 'thrift';
import IAuthentication from '../contracts/IAuthentication';
import HttpTransport from '../transports/HttpTransport';

interface PlainHttpAuthenticationOptions {
username?: string;
Expand All @@ -21,10 +20,10 @@ export default class PlainHttpAuthentication implements IAuthentication {
this.headers = options?.headers || {};
}

public async authenticate(transport: HttpTransport): Promise<void> {
transport.updateHeaders({
public async authenticate(): Promise<HttpHeaders> {
return {
...this.headers,
Authorization: `Bearer ${this.password}`,
});
};
}
}
17 changes: 6 additions & 11 deletions lib/connection/connections/HttpConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import http, { IncomingMessage } from 'http';
import IThriftConnection from '../contracts/IThriftConnection';
import IConnectionProvider from '../contracts/IConnectionProvider';
import IConnectionOptions, { Options } from '../contracts/IConnectionOptions';
import IAuthentication from '../contracts/IAuthentication';
import HttpTransport from '../transports/HttpTransport';
import globalConfig from '../../globalConfig';

type NodeOptions = {
Expand All @@ -21,7 +19,7 @@ export default class HttpConnection implements IConnectionProvider, IThriftConne

private connection: any;

connect(options: IConnectionOptions, authProvider: IAuthentication): Promise<IThriftConnection> {
async connect(options: IConnectionOptions): Promise<IThriftConnection> {
const agentOptions: http.AgentOptions = {
keepAlive: true,
maxSockets: 5,
Expand All @@ -33,7 +31,7 @@ export default class HttpConnection implements IConnectionProvider, IThriftConne
? new https.Agent({ ...agentOptions, minVersion: 'TLSv1.2' })
: new http.Agent(agentOptions);

const httpTransport = new HttpTransport({
const thriftOptions = {
transport: thrift.TBufferedTransport,
protocol: thrift.TBinaryProtocol,
...options.options,
Expand All @@ -43,15 +41,12 @@ export default class HttpConnection implements IConnectionProvider, IThriftConne
...(options.options?.nodeOptions || {}),
timeout: options.options?.socketTimeout ?? globalConfig.socketTimeout,
},
});

return authProvider.authenticate(httpTransport).then(() => {
this.connection = this.thrift.createHttpConnection(options.host, options.port, httpTransport.getOptions());
};

this.addCookieHandler();
this.connection = this.thrift.createHttpConnection(options.host, options.port, thriftOptions);
this.addCookieHandler();

return this;
});
return this;
}

getConnection() {
Expand Down
4 changes: 2 additions & 2 deletions lib/connection/contracts/IAuthentication.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import HttpTransport from '../transports/HttpTransport';
import { HttpHeaders } from 'thrift';

export default interface IAuthentication {
authenticate(transport: HttpTransport): Promise<void>;
authenticate(): Promise<HttpHeaders>;
}
3 changes: 1 addition & 2 deletions lib/connection/contracts/IConnectionProvider.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import IConnectionOptions from './IConnectionOptions';
import IAuthentication from './IAuthentication';
import IThriftConnection from './IThriftConnection';

export default interface IConnectionProvider {
connect(options: IConnectionOptions, authProvider: IAuthentication): Promise<IThriftConnection>;
connect(options: IConnectionOptions): Promise<IThriftConnection>;
}
Loading

0 comments on commit efd3c65

Please sign in to comment.