Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PECO-909] Automatically renew oauth token when refresh token is available #156

Merged
merged 6 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 64 additions & 60 deletions lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import thrift from 'thrift';
import thrift, { HttpHeaders } from 'thrift';

import { EventEmitter } from 'events';
import TCLIService from '../thrift/TCLIService';
import { TProtocolVersion } from '../thrift/TCLIService_types';
import IDBSQLClient, { ConnectionOptions, OpenSessionRequest, ClientOptions } from './contracts/IDBSQLClient';
import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } from './contracts/IDBSQLClient';
import HiveDriver from './hive/HiveDriver';
import { Int64 } from './hive/Types';
import DBSQLSession from './DBSQLSession';
import IDBSQLSession from './contracts/IDBSQLSession';
import IThriftConnection from './connection/contracts/IThriftConnection';
import IConnectionProvider from './connection/contracts/IConnectionProvider';
import IAuthentication from './connection/contracts/IAuthentication';
import HttpConnection from './connection/connections/HttpConnection';
import IConnectionOptions from './connection/contracts/IConnectionOptions';
import Status from './dto/Status';
import HiveDriverError from './errors/HiveDriverError';
import { buildUserAgentString, definedOrError } from './utils';
import { areHeadersEqual, buildUserAgentString, definedOrError } from './utils';
import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication';
import DatabricksOAuth from './connection/auth/DatabricksOAuth';
import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger';
Expand All @@ -42,26 +40,25 @@ function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) {
}

export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
private client: TCLIService.Client | null;
private client: TCLIService.Client | null = null;

private connection: IThriftConnection | null;
private authProvider: IAuthentication | null = null;

private connectionProvider: IConnectionProvider;
private connectionOptions: ConnectionOptions | null = null;

private additionalHeaders: HttpHeaders = {};

private readonly logger: IDBSQLLogger;

private readonly thrift = thrift;

constructor(options?: ClientOptions) {
super();
this.connectionProvider = new HttpConnection();
this.logger = options?.logger || new DBSQLLogger();
this.client = null;
this.connection = null;
this.logger.log(LogLevel.info, 'Created DBSQLClient');
}

private getConnectionOptions(options: ConnectionOptions): IConnectionOptions {
private getConnectionOptions(options: ConnectionOptions, headers: HttpHeaders): IConnectionOptions {
const {
host,
port,
Expand All @@ -85,6 +82,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
https: true,
...otherOptions,
headers: {
...headers,
'User-Agent': buildUserAgentString(options.clientId),
},
},
Expand Down Expand Up @@ -126,39 +124,8 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
* const session = client.connect({host, path, token});
*/
public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise<IDBSQLClient> {
authProvider = this.getAuthProvider(options, authProvider);

this.connection = await this.connectionProvider.connect(this.getConnectionOptions(options), authProvider);

this.client = this.thrift.createClient(TCLIService, this.connection.getConnection());

this.connection.getConnection().on('error', (error: Error) => {
// Error.stack already contains error type and message, so log stack if available,
// otherwise fall back to just error type + message
this.logger.log(LogLevel.error, error.stack || `${error.name}: ${error.message}`);
try {
this.emit('error', error);
} catch (e) {
// EventEmitter will throw unhandled error when emitting 'error' event.
// Since we already logged it few lines above, just suppress this behaviour
}
});

this.connection.getConnection().on('reconnecting', (params: { delay: number; attempt: number }) => {
this.logger.log(LogLevel.debug, `Reconnecting, params: ${JSON.stringify(params)}`);
this.emit('reconnecting', params);
});

this.connection.getConnection().on('close', () => {
this.logger.log(LogLevel.debug, 'Closing connection.');
this.emit('close');
});

this.connection.getConnection().on('timeout', () => {
this.logger.log(LogLevel.debug, 'Connection timed out.');
this.emit('timeout');
});

this.authProvider = this.getAuthProvider(options, authProvider);
this.connectionOptions = options;
return this;
}

Expand All @@ -172,11 +139,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
* const session = await client.openSession();
*/
public async openSession(request: OpenSessionRequest = {}): Promise<IDBSQLSession> {
if (!this.connection?.isConnected()) {
throw new HiveDriverError('DBSQLClient: connection is lost');
}

const driver = new HiveDriver(this.getClient());
const driver = new HiveDriver(() => this.getClient());

const response = await driver.openSession({
client_protocol_i64: new Int64(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6),
Expand All @@ -187,23 +150,64 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
return new DBSQLSession(driver, definedOrError(response.sessionHandle), this.logger);
}

public getClient() {
if (!this.client) {
throw new HiveDriverError('DBSQLClient: client is not initialized');
private async getClient() {
if (!this.connectionOptions || !this.authProvider) {
throw new HiveDriverError('DBSQLClient: not connected');
}

const authHeaders = await this.authProvider.authenticate();
// When auth headers change - recreate client. Thrift library does not provide API for updating
// changed options, therefore we have to recreate both connection and client to apply new headers
if (!this.client || !areHeadersEqual(this.additionalHeaders, authHeaders)) {
this.logger.log(LogLevel.info, 'DBSQLClient: initializing thrift client');
this.additionalHeaders = authHeaders;
const connectionOptions = this.getConnectionOptions(this.connectionOptions, this.additionalHeaders);

const connection = await this.createConnection(connectionOptions);
this.client = this.thrift.createClient(TCLIService, connection.getConnection());
}

return this.client;
}

public async close(): Promise<void> {
if (this.connection) {
const thriftConnection = this.connection.getConnection();
private async createConnection(options: IConnectionOptions) {
const connectionProvider = new HttpConnection();
const connection = await connectionProvider.connect(options);
const thriftConnection = connection.getConnection();

if (typeof thriftConnection.end === 'function') {
this.connection.getConnection().end();
thriftConnection.on('error', (error: Error) => {
// Error.stack already contains error type and message, so log stack if available,
// otherwise fall back to just error type + message
this.logger.log(LogLevel.error, error.stack || `${error.name}: ${error.message}`);
try {
this.emit('error', error);
} catch (e) {
// EventEmitter will throw unhandled error when emitting 'error' event.
// Since we already logged it few lines above, just suppress this behaviour
}
});

this.connection = null;
}
thriftConnection.on('reconnecting', (params: { delay: number; attempt: number }) => {
this.logger.log(LogLevel.debug, `Reconnecting, params: ${JSON.stringify(params)}`);
this.emit('reconnecting', params);
});

thriftConnection.on('close', () => {
this.logger.log(LogLevel.debug, 'Closing connection.');
this.emit('close');
});

thriftConnection.on('timeout', () => {
this.logger.log(LogLevel.debug, 'Connection timed out.');
this.emit('timeout');
});

return connection;
}

public async close(): Promise<void> {
this.client = null;
this.authProvider = null;
this.connectionOptions = null;
}
}
12 changes: 12 additions & 0 deletions lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,15 @@ export default interface OAuthPersistence {

read(host: string): Promise<OAuthToken | undefined>;
}

export class OAuthPersistenceCache implements OAuthPersistence {
private tokens: Record<string, OAuthToken | undefined> = {};

async persist(host: string, token: OAuthToken) {
this.tokens[host] = token;
}

async read(host: string) {
return this.tokens[host];
}
}
19 changes: 11 additions & 8 deletions lib/connection/auth/DatabricksOAuth/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import { HttpHeaders } from 'thrift';
import IAuthentication from '../../contracts/IAuthentication';
import HttpTransport from '../../transports/HttpTransport';
import IDBSQLLogger from '../../../contracts/IDBSQLLogger';
import OAuthPersistence from './OAuthPersistence';
import OAuthPersistence, { OAuthPersistenceCache } from './OAuthPersistence';
import OAuthManager, { OAuthManagerOptions } from './OAuthManager';
import { OAuthScopes, defaultOAuthScopes } from './OAuthScope';

Expand All @@ -20,26 +19,30 @@ export default class DatabricksOAuth implements IAuthentication {

private readonly manager: OAuthManager;

private readonly defaultPersistence = new OAuthPersistenceCache();

constructor(options: DatabricksOAuthOptions) {
this.options = options;
this.logger = options.logger;
this.manager = OAuthManager.getManager(this.options);
}

public async authenticate(transport: HttpTransport): Promise<void> {
const { host, scopes, headers, persistence } = this.options;
public async authenticate(): Promise<HttpHeaders> {
const { host, scopes, headers } = this.options;

const persistence = this.options.persistence ?? this.defaultPersistence;

let token = await persistence?.read(host);
let token = await persistence.read(host);
if (!token) {
token = await this.manager.getToken(scopes ?? defaultOAuthScopes);
}

token = await this.manager.refreshAccessToken(token);
await persistence?.persist(host, token);
await persistence.persist(host, token);

transport.updateHeaders({
return {
...headers,
Authorization: `Bearer ${token.accessToken}`,
});
};
}
}
7 changes: 3 additions & 4 deletions lib/connection/auth/PlainHttpAuthentication.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { HttpHeaders } from 'thrift';
import IAuthentication from '../contracts/IAuthentication';
import HttpTransport from '../transports/HttpTransport';

interface PlainHttpAuthenticationOptions {
username?: string;
Expand All @@ -21,10 +20,10 @@ export default class PlainHttpAuthentication implements IAuthentication {
this.headers = options?.headers || {};
}

public async authenticate(transport: HttpTransport): Promise<void> {
transport.updateHeaders({
public async authenticate(): Promise<HttpHeaders> {
return {
...this.headers,
Authorization: `Bearer ${this.password}`,
});
};
}
}
17 changes: 6 additions & 11 deletions lib/connection/connections/HttpConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import http, { IncomingMessage } from 'http';
import IThriftConnection from '../contracts/IThriftConnection';
import IConnectionProvider from '../contracts/IConnectionProvider';
import IConnectionOptions, { Options } from '../contracts/IConnectionOptions';
import IAuthentication from '../contracts/IAuthentication';
import HttpTransport from '../transports/HttpTransport';
import globalConfig from '../../globalConfig';

type NodeOptions = {
Expand All @@ -21,7 +19,7 @@ export default class HttpConnection implements IConnectionProvider, IThriftConne

private connection: any;

connect(options: IConnectionOptions, authProvider: IAuthentication): Promise<IThriftConnection> {
async connect(options: IConnectionOptions): Promise<IThriftConnection> {
const agentOptions: http.AgentOptions = {
keepAlive: true,
maxSockets: 5,
Expand All @@ -33,7 +31,7 @@ export default class HttpConnection implements IConnectionProvider, IThriftConne
? new https.Agent({ ...agentOptions, minVersion: 'TLSv1.2' })
: new http.Agent(agentOptions);

const httpTransport = new HttpTransport({
const thriftOptions = {
transport: thrift.TBufferedTransport,
protocol: thrift.TBinaryProtocol,
...options.options,
Expand All @@ -43,15 +41,12 @@ export default class HttpConnection implements IConnectionProvider, IThriftConne
...(options.options?.nodeOptions || {}),
timeout: options.options?.socketTimeout ?? globalConfig.socketTimeout,
},
});

return authProvider.authenticate(httpTransport).then(() => {
this.connection = this.thrift.createHttpConnection(options.host, options.port, httpTransport.getOptions());
};

this.addCookieHandler();
this.connection = this.thrift.createHttpConnection(options.host, options.port, thriftOptions);
this.addCookieHandler();

return this;
});
return this;
}

getConnection() {
Expand Down
4 changes: 2 additions & 2 deletions lib/connection/contracts/IAuthentication.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import HttpTransport from '../transports/HttpTransport';
import { HttpHeaders } from 'thrift';

export default interface IAuthentication {
authenticate(transport: HttpTransport): Promise<void>;
authenticate(): Promise<HttpHeaders>;
}
3 changes: 1 addition & 2 deletions lib/connection/contracts/IConnectionProvider.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import IConnectionOptions from './IConnectionOptions';
import IAuthentication from './IAuthentication';
import IThriftConnection from './IThriftConnection';

export default interface IConnectionProvider {
connect(options: IConnectionOptions, authProvider: IAuthentication): Promise<IThriftConnection>;
connect(options: IConnectionOptions): Promise<IThriftConnection>;
}
Loading
Loading