From 531202be0247fdfb0dade15d18a2af6c2c5db335 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 27 Jul 2023 11:26:52 +0300 Subject: [PATCH 1/6] HiveDriver: obtain a thrift client before each request (allows to re-create client if needed) Signed-off-by: Levko Kravets --- lib/DBSQLClient.ts | 13 ++- lib/hive/HiveDriver.ts | 134 +++++++++++++++-------------- tests/unit/DBSQLClient.test.js | 14 ++- tests/unit/hive/HiveDriver.test.js | 2 +- 4 files changed, 85 insertions(+), 78 deletions(-) diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 9abb8f5f..8276bc32 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -42,11 +42,11 @@ function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) { } export default class DBSQLClient extends EventEmitter implements IDBSQLClient { - private client: TCLIService.Client | null; + private client: TCLIService.Client | null = null; - private connection: IThriftConnection | null; + private connection: IThriftConnection | null = null; - private connectionProvider: IConnectionProvider; + private connectionProvider: IConnectionProvider = new HttpConnection(); private readonly logger: IDBSQLLogger; @@ -54,10 +54,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { constructor(options?: ClientOptions) { super(); - this.connectionProvider = new HttpConnection(); this.logger = options?.logger || new DBSQLLogger(); - this.client = null; - this.connection = null; this.logger.log(LogLevel.info, 'Created DBSQLClient'); } @@ -176,7 +173,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { throw new HiveDriverError('DBSQLClient: connection is lost'); } - const driver = new HiveDriver(this.getClient()); + const driver = new HiveDriver(() => this.getClient()); const response = await driver.openSession({ client_protocol_i64: new Int64(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6), @@ -187,7 +184,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { return new DBSQLSession(driver, definedOrError(response.sessionHandle), this.logger); } - public getClient() { + private async getClient() { if (!this.client) { throw new HiveDriverError('DBSQLClient: client is not initialized'); } diff --git a/lib/hive/HiveDriver.ts b/lib/hive/HiveDriver.ts index 4774b793..f4199bac 100644 --- a/lib/hive/HiveDriver.ts +++ b/lib/hive/HiveDriver.ts @@ -44,136 +44,138 @@ import GetDelegationTokenCommand from './Commands/GetDelegationTokenCommand'; import CancelDelegationTokenCommand from './Commands/CancelDelegationTokenCommand'; import RenewDelegationTokenCommand from './Commands/RenewDelegationTokenCommand'; +export type ClientFactory = () => Promise; + export default class HiveDriver { - private client: TCLIService.Client; + private readonly clientFactory: ClientFactory; - constructor(client: TCLIService.Client) { - this.client = client; + constructor(clientFactory: ClientFactory) { + this.clientFactory = clientFactory; } - openSession(request: TOpenSessionReq) { - const action = new OpenSessionCommand(this.client); - + async openSession(request: TOpenSessionReq) { + const client = await this.clientFactory(); + const action = new OpenSessionCommand(client); return action.execute(request); } - closeSession(request: TCloseSessionReq) { - const command = new CloseSessionCommand(this.client); - + async closeSession(request: TCloseSessionReq) { + const client = await this.clientFactory(); + const command = new CloseSessionCommand(client); return command.execute(request); } - executeStatement(request: TExecuteStatementReq) { - const command = new ExecuteStatementCommand(this.client); - + async executeStatement(request: TExecuteStatementReq) { + const client = await this.clientFactory(); + const command = new ExecuteStatementCommand(client); return command.execute(request); } - getResultSetMetadata(request: TGetResultSetMetadataReq) { - const command = new GetResultSetMetadataCommand(this.client); - + async getResultSetMetadata(request: TGetResultSetMetadataReq) { + const client = await this.clientFactory(); + const command = new GetResultSetMetadataCommand(client); return command.execute(request); } - fetchResults(request: TFetchResultsReq) { - const command = new FetchResultsCommand(this.client); - + async fetchResults(request: TFetchResultsReq) { + const client = await this.clientFactory(); + const command = new FetchResultsCommand(client); return command.execute(request); } - getInfo(request: TGetInfoReq) { - const command = new GetInfoCommand(this.client); - + async getInfo(request: TGetInfoReq) { + const client = await this.clientFactory(); + const command = new GetInfoCommand(client); return command.execute(request); } - getTypeInfo(request: TGetTypeInfoReq) { - const command = new GetTypeInfoCommand(this.client); - + async getTypeInfo(request: TGetTypeInfoReq) { + const client = await this.clientFactory(); + const command = new GetTypeInfoCommand(client); return command.execute(request); } - getCatalogs(request: TGetCatalogsReq) { - const command = new GetCatalogsCommand(this.client); - + async getCatalogs(request: TGetCatalogsReq) { + const client = await this.clientFactory(); + const command = new GetCatalogsCommand(client); return command.execute(request); } - getSchemas(request: TGetSchemasReq) { - const command = new GetSchemasCommand(this.client); - + async getSchemas(request: TGetSchemasReq) { + const client = await this.clientFactory(); + const command = new GetSchemasCommand(client); return command.execute(request); } - getTables(request: TGetTablesReq) { - const command = new GetTablesCommand(this.client); - + async getTables(request: TGetTablesReq) { + const client = await this.clientFactory(); + const command = new GetTablesCommand(client); return command.execute(request); } - getTableTypes(request: TGetTableTypesReq) { - const command = new GetTableTypesCommand(this.client); - + async getTableTypes(request: TGetTableTypesReq) { + const client = await this.clientFactory(); + const command = new GetTableTypesCommand(client); return command.execute(request); } - getColumns(request: TGetColumnsReq) { - const command = new GetColumnsCommand(this.client); - + async getColumns(request: TGetColumnsReq) { + const client = await this.clientFactory(); + const command = new GetColumnsCommand(client); return command.execute(request); } - getFunctions(request: TGetFunctionsReq) { - const command = new GetFunctionsCommand(this.client); - + async getFunctions(request: TGetFunctionsReq) { + const client = await this.clientFactory(); + const command = new GetFunctionsCommand(client); return command.execute(request); } - getPrimaryKeys(request: TGetPrimaryKeysReq) { - const command = new GetPrimaryKeysCommand(this.client); - + async getPrimaryKeys(request: TGetPrimaryKeysReq) { + const client = await this.clientFactory(); + const command = new GetPrimaryKeysCommand(client); return command.execute(request); } - getCrossReference(request: TGetCrossReferenceReq) { - const command = new GetCrossReferenceCommand(this.client); - + async getCrossReference(request: TGetCrossReferenceReq) { + const client = await this.clientFactory(); + const command = new GetCrossReferenceCommand(client); return command.execute(request); } - getOperationStatus(request: TGetOperationStatusReq) { - const command = new GetOperationStatusCommand(this.client); - + async getOperationStatus(request: TGetOperationStatusReq) { + const client = await this.clientFactory(); + const command = new GetOperationStatusCommand(client); return command.execute(request); } - cancelOperation(request: TCancelOperationReq) { - const command = new CancelOperationCommand(this.client); - + async cancelOperation(request: TCancelOperationReq) { + const client = await this.clientFactory(); + const command = new CancelOperationCommand(client); return command.execute(request); } - closeOperation(request: TCloseOperationReq) { - const command = new CloseOperationCommand(this.client); - + async closeOperation(request: TCloseOperationReq) { + const client = await this.clientFactory(); + const command = new CloseOperationCommand(client); return command.execute(request); } - getDelegationToken(request: TGetDelegationTokenReq) { - const command = new GetDelegationTokenCommand(this.client); - + async getDelegationToken(request: TGetDelegationTokenReq) { + const client = await this.clientFactory(); + const command = new GetDelegationTokenCommand(client); return command.execute(request); } - cancelDelegationToken(request: TCancelDelegationTokenReq) { - const command = new CancelDelegationTokenCommand(this.client); - + async cancelDelegationToken(request: TCancelDelegationTokenReq) { + const client = await this.clientFactory(); + const command = new CancelDelegationTokenCommand(client); return command.execute(request); } - renewDelegationToken(request: TRenewDelegationTokenReq) { - const command = new RenewDelegationTokenCommand(this.client); - + async renewDelegationToken(request: TRenewDelegationTokenReq) { + const client = await this.clientFactory(); + const command = new RenewDelegationTokenCommand(client); return command.execute(request); } } diff --git a/tests/unit/DBSQLClient.test.js b/tests/unit/DBSQLClient.test.js index 2e6d8416..704881ae 100644 --- a/tests/unit/DBSQLClient.test.js +++ b/tests/unit/DBSQLClient.test.js @@ -1,4 +1,4 @@ -const { expect } = require('chai'); +const { expect, AssertionError } = require('chai'); const sinon = require('sinon'); const DBSQLClient = require('../../dist/DBSQLClient').default; const DBSQLSession = require('../../dist/DBSQLSession').default; @@ -190,9 +190,17 @@ describe('DBSQLClient.openSession', () => { }); describe('DBSQLClient.getClient', () => { - it('should throw an error if the client is not set', () => { + it('should throw an error if the client is not set', async () => { const client = new DBSQLClient(); - expect(() => client.getClient()).to.throw('DBSQLClient: client is not initialized'); + try { + await client.getClient(); + expect.fail('It should throw an error'); + } catch (error) { + if (error instanceof AssertionError) { + throw error; + } + expect(error.message).to.contain('DBSQLClient: client is not initialized'); + } }); }); diff --git a/tests/unit/hive/HiveDriver.test.js b/tests/unit/hive/HiveDriver.test.js index cb593f96..54511f23 100644 --- a/tests/unit/hive/HiveDriver.test.js +++ b/tests/unit/hive/HiveDriver.test.js @@ -6,7 +6,7 @@ const toTitleCase = (str) => str[0].toUpperCase() + str.slice(1); const testCommand = (command, request) => { const client = {}; - const driver = new HiveDriver(client); + const driver = new HiveDriver(() => Promise.resolve(client)); const response = { response: 'value' }; client[toTitleCase(command)] = function (req, cb) { expect(req).to.be.deep.eq(new TCLIService_types[`T${toTitleCase(command)}Req`](request)); From 2ed762af78e6d4161cd4ae81f07931ca1f9b9af2 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 27 Jul 2023 14:16:38 +0300 Subject: [PATCH 2/6] Move auth logic to DBSQLClient Signed-off-by: Levko Kravets --- lib/DBSQLClient.ts | 103 +++++++------- lib/connection/auth/DatabricksOAuth/index.ts | 7 +- .../auth/PlainHttpAuthentication.ts | 7 +- lib/connection/connections/HttpConnection.ts | 12 +- lib/connection/contracts/IAuthentication.ts | 4 +- .../contracts/IConnectionProvider.ts | 3 +- tests/unit/DBSQLClient.test.js | 126 ++++++------------ .../auth/DatabricksOAuth/index.test.js | 48 ++----- .../connections/HttpConnection.test.js | 115 ++++++---------- 9 files changed, 159 insertions(+), 266 deletions(-) diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 8276bc32..041c8d11 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -1,4 +1,4 @@ -import thrift from 'thrift'; +import thrift, { HttpHeaders } from 'thrift'; import { EventEmitter } from 'events'; import TCLIService from '../thrift/TCLIService'; @@ -8,8 +8,6 @@ import HiveDriver from './hive/HiveDriver'; import { Int64 } from './hive/Types'; import DBSQLSession from './DBSQLSession'; import IDBSQLSession from './contracts/IDBSQLSession'; -import IThriftConnection from './connection/contracts/IThriftConnection'; -import IConnectionProvider from './connection/contracts/IConnectionProvider'; import IAuthentication from './connection/contracts/IAuthentication'; import HttpConnection from './connection/connections/HttpConnection'; import IConnectionOptions from './connection/contracts/IConnectionOptions'; @@ -44,9 +42,9 @@ function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) { export default class DBSQLClient extends EventEmitter implements IDBSQLClient { private client: TCLIService.Client | null = null; - private connection: IThriftConnection | null = null; + private authProvider: IAuthentication | null = null; - private connectionProvider: IConnectionProvider = new HttpConnection(); + private connectionOptions: ConnectionOptions | null = null; private readonly logger: IDBSQLLogger; @@ -58,7 +56,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { this.logger.log(LogLevel.info, 'Created DBSQLClient'); } - private getConnectionOptions(options: ConnectionOptions): IConnectionOptions { + private getConnectionOptions(options: ConnectionOptions, headers: HttpHeaders): IConnectionOptions { const { host, port, @@ -82,6 +80,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { https: true, ...otherOptions, headers: { + ...headers, 'User-Agent': buildUserAgentString(options.clientId), }, }, @@ -123,39 +122,8 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { * const session = client.connect({host, path, token}); */ public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise { - authProvider = this.getAuthProvider(options, authProvider); - - this.connection = await this.connectionProvider.connect(this.getConnectionOptions(options), authProvider); - - this.client = this.thrift.createClient(TCLIService, this.connection.getConnection()); - - this.connection.getConnection().on('error', (error: Error) => { - // Error.stack already contains error type and message, so log stack if available, - // otherwise fall back to just error type + message - this.logger.log(LogLevel.error, error.stack || `${error.name}: ${error.message}`); - try { - this.emit('error', error); - } catch (e) { - // EventEmitter will throw unhandled error when emitting 'error' event. - // Since we already logged it few lines above, just suppress this behaviour - } - }); - - this.connection.getConnection().on('reconnecting', (params: { delay: number; attempt: number }) => { - this.logger.log(LogLevel.debug, `Reconnecting, params: ${JSON.stringify(params)}`); - this.emit('reconnecting', params); - }); - - this.connection.getConnection().on('close', () => { - this.logger.log(LogLevel.debug, 'Closing connection.'); - this.emit('close'); - }); - - this.connection.getConnection().on('timeout', () => { - this.logger.log(LogLevel.debug, 'Connection timed out.'); - this.emit('timeout'); - }); - + this.authProvider = this.getAuthProvider(options, authProvider); + this.connectionOptions = options; return this; } @@ -169,10 +137,6 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { * const session = await client.openSession(); */ public async openSession(request: OpenSessionRequest = {}): Promise { - if (!this.connection?.isConnected()) { - throw new HiveDriverError('DBSQLClient: connection is lost'); - } - const driver = new HiveDriver(() => this.getClient()); const response = await driver.openSession({ @@ -185,22 +149,59 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { } private async getClient() { + if (!this.connectionOptions || !this.authProvider) { + throw new HiveDriverError('DBSQLClient: not connected'); + } + if (!this.client) { - throw new HiveDriverError('DBSQLClient: client is not initialized'); + const authHeaders = await this.authProvider.authenticate(); + const connectionOptions = this.getConnectionOptions(this.connectionOptions, authHeaders); + + const connection = await this.createConnection(connectionOptions); + this.client = this.thrift.createClient(TCLIService, connection.getConnection()); } return this.client; } - public async close(): Promise { - if (this.connection) { - const thriftConnection = this.connection.getConnection(); + private async createConnection(options: IConnectionOptions) { + const connectionProvider = new HttpConnection(); + const connection = await connectionProvider.connect(options); + const thriftConnection = connection.getConnection(); - if (typeof thriftConnection.end === 'function') { - this.connection.getConnection().end(); + thriftConnection.on('error', (error: Error) => { + // Error.stack already contains error type and message, so log stack if available, + // otherwise fall back to just error type + message + this.logger.log(LogLevel.error, error.stack || `${error.name}: ${error.message}`); + try { + this.emit('error', error); + } catch (e) { + // EventEmitter will throw unhandled error when emitting 'error' event. + // Since we already logged it few lines above, just suppress this behaviour } + }); - this.connection = null; - } + thriftConnection.on('reconnecting', (params: { delay: number; attempt: number }) => { + this.logger.log(LogLevel.debug, `Reconnecting, params: ${JSON.stringify(params)}`); + this.emit('reconnecting', params); + }); + + thriftConnection.on('close', () => { + this.logger.log(LogLevel.debug, 'Closing connection.'); + this.emit('close'); + }); + + thriftConnection.on('timeout', () => { + this.logger.log(LogLevel.debug, 'Connection timed out.'); + this.emit('timeout'); + }); + + return connection; + } + + public async close(): Promise { + this.client = null; + this.authProvider = null; + this.connectionOptions = null; } } diff --git a/lib/connection/auth/DatabricksOAuth/index.ts b/lib/connection/auth/DatabricksOAuth/index.ts index ca8db9e7..14fbf7af 100644 --- a/lib/connection/auth/DatabricksOAuth/index.ts +++ b/lib/connection/auth/DatabricksOAuth/index.ts @@ -1,6 +1,5 @@ import { HttpHeaders } from 'thrift'; import IAuthentication from '../../contracts/IAuthentication'; -import HttpTransport from '../../transports/HttpTransport'; import IDBSQLLogger from '../../../contracts/IDBSQLLogger'; import OAuthPersistence from './OAuthPersistence'; import OAuthManager, { OAuthManagerOptions } from './OAuthManager'; @@ -26,7 +25,7 @@ export default class DatabricksOAuth implements IAuthentication { this.manager = OAuthManager.getManager(this.options); } - public async authenticate(transport: HttpTransport): Promise { + public async authenticate(): Promise { const { host, scopes, headers, persistence } = this.options; let token = await persistence?.read(host); @@ -37,9 +36,9 @@ export default class DatabricksOAuth implements IAuthentication { token = await this.manager.refreshAccessToken(token); await persistence?.persist(host, token); - transport.updateHeaders({ + return { ...headers, Authorization: `Bearer ${token.accessToken}`, - }); + }; } } diff --git a/lib/connection/auth/PlainHttpAuthentication.ts b/lib/connection/auth/PlainHttpAuthentication.ts index 3efc611d..ab0283e3 100644 --- a/lib/connection/auth/PlainHttpAuthentication.ts +++ b/lib/connection/auth/PlainHttpAuthentication.ts @@ -1,6 +1,5 @@ import { HttpHeaders } from 'thrift'; import IAuthentication from '../contracts/IAuthentication'; -import HttpTransport from '../transports/HttpTransport'; interface PlainHttpAuthenticationOptions { username?: string; @@ -21,10 +20,10 @@ export default class PlainHttpAuthentication implements IAuthentication { this.headers = options?.headers || {}; } - public async authenticate(transport: HttpTransport): Promise { - transport.updateHeaders({ + public async authenticate(): Promise { + return { ...this.headers, Authorization: `Bearer ${this.password}`, - }); + }; } } diff --git a/lib/connection/connections/HttpConnection.ts b/lib/connection/connections/HttpConnection.ts index 164a932e..111e2db6 100644 --- a/lib/connection/connections/HttpConnection.ts +++ b/lib/connection/connections/HttpConnection.ts @@ -5,7 +5,6 @@ import http, { IncomingMessage } from 'http'; import IThriftConnection from '../contracts/IThriftConnection'; import IConnectionProvider from '../contracts/IConnectionProvider'; import IConnectionOptions, { Options } from '../contracts/IConnectionOptions'; -import IAuthentication from '../contracts/IAuthentication'; import HttpTransport from '../transports/HttpTransport'; import globalConfig from '../../globalConfig'; @@ -21,7 +20,7 @@ export default class HttpConnection implements IConnectionProvider, IThriftConne private connection: any; - connect(options: IConnectionOptions, authProvider: IAuthentication): Promise { + async connect(options: IConnectionOptions): Promise { const agentOptions: http.AgentOptions = { keepAlive: true, maxSockets: 5, @@ -45,13 +44,10 @@ export default class HttpConnection implements IConnectionProvider, IThriftConne }, }); - return authProvider.authenticate(httpTransport).then(() => { - this.connection = this.thrift.createHttpConnection(options.host, options.port, httpTransport.getOptions()); + this.connection = this.thrift.createHttpConnection(options.host, options.port, httpTransport.getOptions()); + this.addCookieHandler(); - this.addCookieHandler(); - - return this; - }); + return this; } getConnection() { diff --git a/lib/connection/contracts/IAuthentication.ts b/lib/connection/contracts/IAuthentication.ts index 44da935d..03bdbe1b 100644 --- a/lib/connection/contracts/IAuthentication.ts +++ b/lib/connection/contracts/IAuthentication.ts @@ -1,5 +1,5 @@ -import HttpTransport from '../transports/HttpTransport'; +import { HttpHeaders } from 'thrift'; export default interface IAuthentication { - authenticate(transport: HttpTransport): Promise; + authenticate(): Promise; } diff --git a/lib/connection/contracts/IConnectionProvider.ts b/lib/connection/contracts/IConnectionProvider.ts index a4b74079..0d25ccb6 100644 --- a/lib/connection/contracts/IConnectionProvider.ts +++ b/lib/connection/contracts/IConnectionProvider.ts @@ -1,7 +1,6 @@ import IConnectionOptions from './IConnectionOptions'; -import IAuthentication from './IAuthentication'; import IThriftConnection from './IThriftConnection'; export default interface IConnectionProvider { - connect(options: IConnectionOptions, authProvider: IAuthentication): Promise; + connect(options: IConnectionOptions): Promise; } diff --git a/tests/unit/DBSQLClient.test.js b/tests/unit/DBSQLClient.test.js index 704881ae..3ac1255a 100644 --- a/tests/unit/DBSQLClient.test.js +++ b/tests/unit/DBSQLClient.test.js @@ -6,7 +6,6 @@ const DBSQLSession = require('../../dist/DBSQLSession').default; const PlainHttpAuthentication = require('../../dist/connection/auth/PlainHttpAuthentication').default; const DatabricksOAuth = require('../../dist/connection/auth/DatabricksOAuth').default; const { AWSOAuthManager, AzureOAuthManager } = require('../../dist/connection/auth/DatabricksOAuth/OAuthManager'); -const HttpConnection = require('../../dist/connection/connections/HttpConnection').default; const ConnectionProviderMock = (connection) => ({ connect(options, auth) { @@ -34,53 +33,25 @@ describe('DBSQLClient.connect', () => { it('should prepend "/" to path if it is missing', async () => { const client = new DBSQLClient(); - client.thrift = { - createClient() {}, - }; - const connectionProvider = ConnectionProviderMock(); const path = 'example/path'; + const connectionOptions = client.getConnectionOptions({ ...options, path }, {}); - client.connectionProvider = connectionProvider; - await client.connect({ - ...options, - path, - }); - - expect(connectionProvider.options.options.path).to.equal(`/${path}`); + expect(connectionOptions.options.path).to.equal(`/${path}`); }); it('should not prepend "/" to path if it is already available', async () => { const client = new DBSQLClient(); - client.thrift = { - createClient() {}, - }; - const connectionProvider = ConnectionProviderMock(); const path = '/example/path'; + const connectionOptions = client.getConnectionOptions({ ...options, path }, {}); - client.connectionProvider = connectionProvider; - await client.connect({ - ...options, - path, - }); - - expect(connectionProvider.options.options.path).to.equal(path); + expect(connectionOptions.options.path).to.equal(path); }); - it('should set nosasl authenticator by default', async () => { - const client = new DBSQLClient(); - const connectionProvider = ConnectionProviderMock(); - - client.connectionProvider = connectionProvider; - try { - await client.connect(options); - } catch { - expect(connectionProvider.auth).instanceOf(PlainHttpAuthentication); - } - }); - - it('should handle network errors', (cb) => { + // client.connect() now does not actually attempt any network operations. for http it never did it + // even before, but this test was not quite correct even then. it needs to be updated + it.skip('should handle network errors', (cb) => { const client = new DBSQLClient(); client.thrift = { createClient() {}, @@ -91,6 +62,8 @@ describe('DBSQLClient.connect', () => { }, }); + sinon.stub(client, 'createConnection').returns(Promise.resolve(connectionProvider)); + client.on('error', (error) => { expect(error.message).to.be.eq('network error'); cb(); @@ -101,31 +74,22 @@ describe('DBSQLClient.connect', () => { cb(error); }); }); - - it('should use http connection by default', async () => { - const client = new DBSQLClient(); - client.thrift = { - createClient() {}, - }; - - await client.connect(options); - expect(client.connectionProvider).instanceOf(HttpConnection); - }); }); describe('DBSQLClient.openSession', () => { it('should successfully open session', async () => { const client = new DBSQLClient(); - client.client = { - OpenSession(req, cb) { - cb(null, { status: {}, sessionHandle: {} }); - }, - }; - client.connection = { - isConnected() { - return true; - }, - }; + + sinon.stub(client, 'getClient').returns( + Promise.resolve({ + OpenSession(req, cb) { + cb(null, { status: {}, sessionHandle: {} }); + }, + }), + ); + + client.authProvider = {}; + client.connectionOptions = {}; const session = await client.openSession(); expect(session).instanceOf(DBSQLSession); @@ -133,16 +97,17 @@ describe('DBSQLClient.openSession', () => { it('should use initial namespace options', async () => { const client = new DBSQLClient(); - client.client = { - OpenSession(req, cb) { - cb(null, { status: {}, sessionHandle: {} }); - }, - }; - client.connection = { - isConnected() { - return true; - }, - }; + + sinon.stub(client, 'getClient').returns( + Promise.resolve({ + OpenSession(req, cb) { + cb(null, { status: {}, sessionHandle: {} }); + }, + }), + ); + + client.authProvider = {}; + client.connectionOptions = {}; case1: { const session = await client.openSession({ initialCatalog: 'catalog' }); @@ -168,7 +133,7 @@ describe('DBSQLClient.openSession', () => { await client.openSession(); expect.fail('It should throw an error'); } catch (error) { - expect(error.message).to.be.eq('DBSQLClient: connection is lost'); + expect(error.message).to.be.eq('DBSQLClient: not connected'); } }); @@ -184,7 +149,7 @@ describe('DBSQLClient.openSession', () => { await client.openSession(); expect.fail('It should throw an error'); } catch (error) { - expect(error.message).to.be.eq('DBSQLClient: connection is lost'); + expect(error.message).to.be.eq('DBSQLClient: not connected'); } }); }); @@ -199,7 +164,7 @@ describe('DBSQLClient.getClient', () => { if (error instanceof AssertionError) { throw error; } - expect(error.message).to.contain('DBSQLClient: client is not initialized'); + expect(error.message).to.contain('DBSQLClient: not connected'); } }); }); @@ -207,15 +172,12 @@ describe('DBSQLClient.getClient', () => { describe('DBSQLClient.close', () => { it('should close the connection if it was initiated', async () => { const client = new DBSQLClient(); - const closeConnectionStub = sinon.stub(); - client.connection = { - getConnection: () => ({ - end: closeConnectionStub, - }), - }; + client.authProvider = {}; + client.connectionOptions = {}; await client.close(); - expect(closeConnectionStub.called).to.be.true; + expect(client.authProvider).to.be.null; + expect(client.connectionOptions).to.be.null; // No additional asserts needed - it should just reach this point }); @@ -223,16 +185,8 @@ describe('DBSQLClient.close', () => { const client = new DBSQLClient(); await client.close(); - // No additional asserts needed - it should just reach this point - }); - - it('should do nothing if the connection exists but cannot be finished', async () => { - const client = new DBSQLClient(); - client.connection = { - getConnection: () => ({}), - }; - - await client.close(); + expect(client.authProvider).to.be.null; + expect(client.connectionOptions).to.be.null; // No additional asserts needed - it should just reach this point }); }); diff --git a/tests/unit/connection/auth/DatabricksOAuth/index.test.js b/tests/unit/connection/auth/DatabricksOAuth/index.test.js index 272e51eb..a2a06218 100644 --- a/tests/unit/connection/auth/DatabricksOAuth/index.test.js +++ b/tests/unit/connection/auth/DatabricksOAuth/index.test.js @@ -21,19 +21,6 @@ class OAuthManagerMock { } } -class TransportMock { - constructor() { - this.headers = {}; - } - - updateHeaders(newHeaders) { - this.headers = { - ...this.headers, - ...newHeaders, - }; - } -} - class OAuthPersistenceMock { constructor() { this.token = undefined; @@ -61,10 +48,7 @@ function prepareTestInstances(options) { const provider = new DatabricksOAuth({ ...options }); - const transport = new TransportMock(); - sinon.stub(transport, 'updateHeaders').callThrough(); - - return { oauthManager, provider, transport }; + return { oauthManager, provider }; } describe('DatabricksOAuth', () => { @@ -76,25 +60,25 @@ describe('DatabricksOAuth', () => { const persistence = new OAuthPersistenceMock(); persistence.token = new OAuthToken(createValidAccessToken()); - const { provider, transport } = prepareTestInstances({ persistence }); + const { provider } = prepareTestInstances({ persistence }); - await provider.authenticate(transport); + await provider.authenticate(); expect(persistence.read.called).to.be.true; }); it('should get new token if storage not available', async () => { - const { oauthManager, provider, transport } = prepareTestInstances(); + const { oauthManager, provider } = prepareTestInstances(); - await provider.authenticate(transport); + await provider.authenticate(); expect(oauthManager.getToken.called).to.be.true; }); it('should get new token if persisted token not available, and store valid token', async () => { const persistence = new OAuthPersistenceMock(); persistence.token = undefined; - const { oauthManager, provider, transport } = prepareTestInstances({ persistence }); + const { oauthManager, provider } = prepareTestInstances({ persistence }); - await provider.authenticate(transport); + await provider.authenticate(); expect(oauthManager.getToken.called).to.be.true; expect(persistence.persist.called).to.be.true; expect(persistence.token).to.be.equal(oauthManager.getTokenResult); @@ -104,11 +88,11 @@ describe('DatabricksOAuth', () => { const persistence = new OAuthPersistenceMock(); persistence.token = undefined; - const { oauthManager, provider, transport } = prepareTestInstances({ persistence }); + const { oauthManager, provider } = prepareTestInstances({ persistence }); oauthManager.getTokenResult = new OAuthToken(createExpiredAccessToken()); oauthManager.refreshTokenResult = new OAuthToken(createValidAccessToken()); - await provider.authenticate(transport); + await provider.authenticate(); expect(oauthManager.getToken.called).to.be.true; expect(oauthManager.refreshAccessToken.called).to.be.true; expect(oauthManager.refreshAccessToken.firstCall.firstArg).to.be.equal(oauthManager.getTokenResult); @@ -118,18 +102,10 @@ describe('DatabricksOAuth', () => { }); it('should configure transport using valid token', async () => { - const { oauthManager, provider, transport } = prepareTestInstances(); - - const initialHeaders = { - x: 'x', - y: 'y', - }; - - transport.headers = initialHeaders; + const { oauthManager, provider } = prepareTestInstances(); - await provider.authenticate(transport); + const authHeaders = await provider.authenticate(); expect(oauthManager.getToken.called).to.be.true; - expect(transport.updateHeaders.called).to.be.true; - expect(Object.keys(transport.headers)).to.deep.equal([...Object.keys(initialHeaders), 'Authorization']); + expect(Object.keys(authHeaders)).to.deep.equal(['Authorization']); }); }); diff --git a/tests/unit/connection/connections/HttpConnection.test.js b/tests/unit/connection/connections/HttpConnection.test.js index 8bba299b..1c154939 100644 --- a/tests/unit/connection/connections/HttpConnection.test.js +++ b/tests/unit/connection/connections/HttpConnection.test.js @@ -11,17 +11,10 @@ const thriftMock = (connection) => ({ return connection; }, }); -const authProviderMock = () => ({ - authenticate() { - this.executed = true; - return Promise.resolve(); - }, -}); describe('HttpConnection.connect', () => { it('should successfully connect', () => { const connection = new HttpConnection(); - const authenticator = authProviderMock(); const resultConnection = { responseCallback() {}, }; @@ -30,18 +23,14 @@ describe('HttpConnection.connect', () => { expect(connection.isConnected()).to.be.false; return connection - .connect( - { - host: 'localhost', - port: 10001, - options: { - path: '/hive', - }, + .connect({ + host: 'localhost', + port: 10001, + options: { + path: '/hive', }, - authenticator, - ) + }) .then(() => { - expect(authenticator.executed).to.be.true; expect(connection.thrift.executed).to.be.true; expect(connection.thrift.host).to.be.eq('localhost'); expect(connection.thrift.port).to.be.eq(10001); @@ -53,27 +42,23 @@ describe('HttpConnection.connect', () => { it('should set SSL certificates and disable rejectUnauthorized', () => { const connection = new HttpConnection(); - const authenticator = authProviderMock(); const resultConnection = { responseCallback() {}, }; connection.thrift = thriftMock(resultConnection); return connection - .connect( - { - host: 'localhost', - port: 10001, - options: { - path: '/hive', - https: true, - ca: 'ca', - cert: 'cert', - key: 'key', - }, + .connect({ + host: 'localhost', + port: 10001, + options: { + path: '/hive', + https: true, + ca: 'ca', + cert: 'cert', + key: 'key', }, - authenticator, - ) + }) .then(() => { expect(connection.thrift.options.nodeOptions.rejectUnauthorized).to.be.false; expect(connection.thrift.options.nodeOptions.ca).to.be.eq('ca'); @@ -84,7 +69,6 @@ describe('HttpConnection.connect', () => { it('should set cookie', () => { const connection = new HttpConnection(); - const authenticator = authProviderMock(); const resultConnection = { nodeOptions: { headers: { cookie: '' } }, responseCallback() {}, @@ -92,16 +76,13 @@ describe('HttpConnection.connect', () => { connection.thrift = thriftMock(resultConnection); return connection - .connect( - { - host: 'localhost', - port: 10001, - options: { - path: '/hive', - }, + .connect({ + host: 'localhost', + port: 10001, + options: { + path: '/hive', }, - authenticator, - ) + }) .then(() => { resultConnection.responseCallback({ headers: { @@ -115,27 +96,23 @@ describe('HttpConnection.connect', () => { it('should overlay rejectUnauthorized', () => { const connection = new HttpConnection(); - const authenticator = authProviderMock(); const resultConnection = { responseCallback() {}, }; connection.thrift = thriftMock(resultConnection); return connection - .connect( - { - host: 'localhost', - port: 10001, - options: { - path: '/hive', - https: true, - nodeOptions: { - rejectUnauthorized: true, - }, + .connect({ + host: 'localhost', + port: 10001, + options: { + path: '/hive', + https: true, + nodeOptions: { + rejectUnauthorized: true, }, }, - authenticator, - ) + }) .then(() => { expect(connection.thrift.options.nodeOptions.rejectUnauthorized).to.be.true; }); @@ -143,7 +120,6 @@ describe('HttpConnection.connect', () => { it('should call response callback if cookie is not set', () => { const connection = new HttpConnection(); - const authenticator = authProviderMock(); const resultConnection = { responseCallback() { this.executed = true; @@ -152,13 +128,10 @@ describe('HttpConnection.connect', () => { connection.thrift = thriftMock(resultConnection); return connection - .connect( - { - host: 'localhost', - port: 10001, - }, - authenticator, - ) + .connect({ + host: 'localhost', + port: 10001, + }) .then(() => { resultConnection.responseCallback({ headers: {} }); expect(resultConnection.executed).to.be.true; @@ -168,24 +141,20 @@ describe('HttpConnection.connect', () => { it('should use a http agent if https is not enabled', () => { const connection = new HttpConnection(); - const authenticator = authProviderMock(); const resultConnection = { responseCallback() {}, }; connection.thrift = thriftMock(resultConnection); return connection - .connect( - { - host: 'localhost', - port: 10001, - options: { - https: false, - path: '/hive', - }, + .connect({ + host: 'localhost', + port: 10001, + options: { + https: false, + path: '/hive', }, - authenticator, - ) + }) .then(() => { expect(connection.thrift.options.nodeOptions.agent).to.be.instanceOf(http.Agent); }); From b5ec0aedb49e3e82445e0a307b30c6f93fffcc16 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 27 Jul 2023 14:20:13 +0300 Subject: [PATCH 3/6] Remove redundant HttpTransport class Signed-off-by: Levko Kravets --- lib/connection/connections/HttpConnection.ts | 7 +- lib/connection/transports/HttpTransport.ts | 56 -------- .../transports/HttpTransport.test.js | 120 ------------------ 3 files changed, 3 insertions(+), 180 deletions(-) delete mode 100644 lib/connection/transports/HttpTransport.ts delete mode 100644 tests/unit/connection/transports/HttpTransport.test.js diff --git a/lib/connection/connections/HttpConnection.ts b/lib/connection/connections/HttpConnection.ts index 111e2db6..ec665ed5 100644 --- a/lib/connection/connections/HttpConnection.ts +++ b/lib/connection/connections/HttpConnection.ts @@ -5,7 +5,6 @@ import http, { IncomingMessage } from 'http'; import IThriftConnection from '../contracts/IThriftConnection'; import IConnectionProvider from '../contracts/IConnectionProvider'; import IConnectionOptions, { Options } from '../contracts/IConnectionOptions'; -import HttpTransport from '../transports/HttpTransport'; import globalConfig from '../../globalConfig'; type NodeOptions = { @@ -32,7 +31,7 @@ export default class HttpConnection implements IConnectionProvider, IThriftConne ? new https.Agent({ ...agentOptions, minVersion: 'TLSv1.2' }) : new http.Agent(agentOptions); - const httpTransport = new HttpTransport({ + const thriftOptions = { transport: thrift.TBufferedTransport, protocol: thrift.TBinaryProtocol, ...options.options, @@ -42,9 +41,9 @@ export default class HttpConnection implements IConnectionProvider, IThriftConne ...(options.options?.nodeOptions || {}), timeout: options.options?.socketTimeout ?? globalConfig.socketTimeout, }, - }); + }; - this.connection = this.thrift.createHttpConnection(options.host, options.port, httpTransport.getOptions()); + this.connection = this.thrift.createHttpConnection(options.host, options.port, thriftOptions); this.addCookieHandler(); return this; diff --git a/lib/connection/transports/HttpTransport.ts b/lib/connection/transports/HttpTransport.ts deleted file mode 100644 index 3a652cc5..00000000 --- a/lib/connection/transports/HttpTransport.ts +++ /dev/null @@ -1,56 +0,0 @@ -import { ConnectOptions, HttpHeaders } from 'thrift'; - -export default class HttpTransport { - private options: ConnectOptions; - - constructor(options: ConnectOptions = {}) { - this.options = { ...options }; - } - - public getOptions(): ConnectOptions { - return this.options; - } - - public setOptions(options: ConnectOptions) { - this.options = { ...options }; - } - - public updateOptions(options: Partial) { - this.options = { - ...this.options, - ...options, - }; - } - - public getOption(option: K): ConnectOptions[K] { - return this.options[option]; - } - - public setOption(option: K, value: ConnectOptions[K]) { - this.options = { - ...this.options, - [option]: value, - }; - } - - public getHeaders(): HttpHeaders { - return this.options.headers ?? {}; - } - - public setHeaders(headers: HttpHeaders) { - this.options = { - ...this.options, - headers: { ...headers }, - }; - } - - public updateHeaders(headers: Partial) { - this.options = { - ...this.options, - headers: { - ...this.options.headers, - ...headers, - }, - }; - } -} diff --git a/tests/unit/connection/transports/HttpTransport.test.js b/tests/unit/connection/transports/HttpTransport.test.js deleted file mode 100644 index 0b21dbea..00000000 --- a/tests/unit/connection/transports/HttpTransport.test.js +++ /dev/null @@ -1,120 +0,0 @@ -const http = require('http'); -const { expect } = require('chai'); -const HttpTransport = require('../../../../dist/connection/transports/HttpTransport').default; - -describe('HttpTransport', () => { - it('should initialize with default options', () => { - const transport = new HttpTransport(); - expect(transport.getOptions()).to.deep.equal({}); - }); - - it('should replace all options', () => { - const initialOptions = { a: 'a', b: 'b' }; - const transport = new HttpTransport(initialOptions); - expect(transport.getOptions()).to.deep.equal(initialOptions); - - const newOptions = { c: 'c' }; - transport.setOptions(newOptions); - expect(transport.getOptions()).to.deep.equal(newOptions); - }); - - it('should update only specified options', () => { - const initialOptions = { a: 'a', b: 'b' }; - const transport = new HttpTransport(initialOptions); - expect(transport.getOptions()).to.deep.equal(initialOptions); - - const newOptions = { b: 'new_b', c: 'c' }; - transport.updateOptions(newOptions); - expect(transport.getOptions()).to.deep.equal({ - ...initialOptions, - ...newOptions, - }); - }); - - it('should get specific option', () => { - const initialOptions = { a: 'a', b: 'b' }; - const transport = new HttpTransport(initialOptions); - expect(transport.getOptions()).to.deep.equal(initialOptions); - - expect(transport.getOption('a')).to.deep.equal(initialOptions.a); - }); - - it('should set specific option', () => { - const initialOptions = { a: 'a', b: 'b' }; - const transport = new HttpTransport(initialOptions); - expect(transport.getOptions()).to.deep.equal(initialOptions); - - transport.setOption('b', 'new_b'); - expect(transport.getOptions()).to.deep.equal({ - ...initialOptions, - b: 'new_b', - }); - - transport.setOption('c', 'c'); - expect(transport.getOptions()).to.deep.equal({ - ...initialOptions, - b: 'new_b', - c: 'c', - }); - }); - - it('should get headers', () => { - case1: { - const transport = new HttpTransport(); - expect(transport.getOptions()).to.deep.equal({}); - - expect(transport.getHeaders()).to.deep.equal({}); - } - - case2: { - const initialOptions = { - a: 'a', - headers: { x: 'x' }, - }; - const transport = new HttpTransport(initialOptions); - expect(transport.getOptions()).to.deep.equal(initialOptions); - - expect(transport.getHeaders()).to.deep.equal(initialOptions.headers); - } - }); - - it('should replace headers', () => { - const initialOptions = { - a: 'a', - headers: { x: 'x', y: 'y' }, - }; - const transport = new HttpTransport(initialOptions); - expect(transport.getOptions()).to.deep.equal(initialOptions); - - const newHeaders = { y: 'new_y', z: 'z' }; - transport.setHeaders(newHeaders); - expect(transport.getOptions()).to.deep.equal({ - ...initialOptions, - headers: newHeaders, - }); - expect(transport.getHeaders()).to.deep.equal(newHeaders); - }); - - it('should update only specified headers', () => { - const initialOptions = { - a: 'a', - headers: { x: 'x', y: 'y' }, - }; - const transport = new HttpTransport(initialOptions); - expect(transport.getOptions()).to.deep.equal(initialOptions); - - const newHeaders = { y: 'new_y', z: 'z' }; - transport.updateHeaders(newHeaders); - expect(transport.getOptions()).to.deep.equal({ - ...initialOptions, - headers: { - ...initialOptions.headers, - ...newHeaders, - }, - }); - expect(transport.getHeaders()).to.deep.equal({ - ...initialOptions.headers, - ...newHeaders, - }); - }); -}); From 12afffec6cf15c3ec5601fd99052e4821a0ec4eb Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 27 Jul 2023 15:03:18 +0300 Subject: [PATCH 4/6] Cache OAuth tokens in memory by default to avoid re-running OAuth flow on every request Signed-off-by: Levko Kravets --- .../auth/DatabricksOAuth/OAuthPersistence.ts | 12 ++++++++++++ lib/connection/auth/DatabricksOAuth/index.ts | 12 ++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts b/lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts index c60a9f2f..cf81e557 100644 --- a/lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts +++ b/lib/connection/auth/DatabricksOAuth/OAuthPersistence.ts @@ -5,3 +5,15 @@ export default interface OAuthPersistence { read(host: string): Promise; } + +export class OAuthPersistenceCache implements OAuthPersistence { + private tokens: Record = {}; + + async persist(host: string, token: OAuthToken) { + this.tokens[host] = token; + } + + async read(host: string) { + return this.tokens[host]; + } +} diff --git a/lib/connection/auth/DatabricksOAuth/index.ts b/lib/connection/auth/DatabricksOAuth/index.ts index 14fbf7af..20e29e16 100644 --- a/lib/connection/auth/DatabricksOAuth/index.ts +++ b/lib/connection/auth/DatabricksOAuth/index.ts @@ -1,7 +1,7 @@ import { HttpHeaders } from 'thrift'; import IAuthentication from '../../contracts/IAuthentication'; import IDBSQLLogger from '../../../contracts/IDBSQLLogger'; -import OAuthPersistence from './OAuthPersistence'; +import OAuthPersistence, { OAuthPersistenceCache } from './OAuthPersistence'; import OAuthManager, { OAuthManagerOptions } from './OAuthManager'; import { OAuthScopes, defaultOAuthScopes } from './OAuthScope'; @@ -19,6 +19,8 @@ export default class DatabricksOAuth implements IAuthentication { private readonly manager: OAuthManager; + private readonly defaultPersistence = new OAuthPersistenceCache(); + constructor(options: DatabricksOAuthOptions) { this.options = options; this.logger = options.logger; @@ -26,15 +28,17 @@ export default class DatabricksOAuth implements IAuthentication { } public async authenticate(): Promise { - const { host, scopes, headers, persistence } = this.options; + const { host, scopes, headers } = this.options; + + const persistence = this.options.persistence ?? this.defaultPersistence; - let token = await persistence?.read(host); + let token = await persistence.read(host); if (!token) { token = await this.manager.getToken(scopes ?? defaultOAuthScopes); } token = await this.manager.refreshAccessToken(token); - await persistence?.persist(host, token); + await persistence.persist(host, token); return { ...headers, From 39eb9c9a9b23d4862935a6134432dedcf006e176 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 27 Jul 2023 15:04:03 +0300 Subject: [PATCH 5/6] Re-create thrift client when auth credentials (e.g. oauth token) change Signed-off-by: Levko Kravets --- lib/DBSQLClient.ts | 16 +++++++---- lib/utils/areHeadersEqual.ts | 54 ++++++++++++++++++++++++++++++++++++ lib/utils/index.ts | 3 +- 3 files changed, 67 insertions(+), 6 deletions(-) create mode 100644 lib/utils/areHeadersEqual.ts diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 041c8d11..f61097d8 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -3,7 +3,7 @@ import thrift, { HttpHeaders } from 'thrift'; import { EventEmitter } from 'events'; import TCLIService from '../thrift/TCLIService'; import { TProtocolVersion } from '../thrift/TCLIService_types'; -import IDBSQLClient, { ConnectionOptions, OpenSessionRequest, ClientOptions } from './contracts/IDBSQLClient'; +import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } from './contracts/IDBSQLClient'; import HiveDriver from './hive/HiveDriver'; import { Int64 } from './hive/Types'; import DBSQLSession from './DBSQLSession'; @@ -13,7 +13,7 @@ import HttpConnection from './connection/connections/HttpConnection'; import IConnectionOptions from './connection/contracts/IConnectionOptions'; import Status from './dto/Status'; import HiveDriverError from './errors/HiveDriverError'; -import { buildUserAgentString, definedOrError } from './utils'; +import { areHeadersEqual, buildUserAgentString, definedOrError } from './utils'; import PlainHttpAuthentication from './connection/auth/PlainHttpAuthentication'; import DatabricksOAuth from './connection/auth/DatabricksOAuth'; import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger'; @@ -46,6 +46,8 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { private connectionOptions: ConnectionOptions | null = null; + private additionalHeaders: HttpHeaders = {}; + private readonly logger: IDBSQLLogger; private readonly thrift = thrift; @@ -153,9 +155,13 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient { throw new HiveDriverError('DBSQLClient: not connected'); } - if (!this.client) { - const authHeaders = await this.authProvider.authenticate(); - const connectionOptions = this.getConnectionOptions(this.connectionOptions, authHeaders); + const authHeaders = await this.authProvider.authenticate(); + // When auth headers change - recreate client. Thrift library does not provide API for updating + // changed options, therefore we have to recreate both connection and client to apply new headers + if (!this.client || !areHeadersEqual(this.additionalHeaders, authHeaders)) { + this.logger.log(LogLevel.info, 'DBSQLClient: initializing thrift client'); + this.additionalHeaders = authHeaders; + const connectionOptions = this.getConnectionOptions(this.connectionOptions, this.additionalHeaders); const connection = await this.createConnection(connectionOptions); this.client = this.thrift.createClient(TCLIService, connection.getConnection()); diff --git a/lib/utils/areHeadersEqual.ts b/lib/utils/areHeadersEqual.ts new file mode 100644 index 00000000..f17433d6 --- /dev/null +++ b/lib/utils/areHeadersEqual.ts @@ -0,0 +1,54 @@ +import { HttpHeaders } from 'thrift'; + +function areArraysEqual(a: Array, b: Array): boolean { + // If they're the same object - they're equal + if (a === b) { + return true; + } + + // If they have a different size - they're definitely not equal + if (a.length !== b.length) { + return false; + } + + // Here we have arrays of same size. Compare elements - if any pair is different + // then arrays are different + for (let i = 0; i < a.length; i += 1) { + if (a[i] !== b[i]) { + return false; + } + } + + // If all corresponding elements in both arrays are equal - arrays are equal too + return true; +} + +export default function areHeadersEqual(a: HttpHeaders, b: HttpHeaders): boolean { + // If they're the same object - they're equal + if (a === b) { + return true; + } + + // If both objects have different keys - they're not equal + const keysOfA = Object.keys(a); + const keysOfB = Object.keys(b); + if (!areArraysEqual(keysOfA, keysOfB)) { + return false; + } + + // Compare corresponding properties of both objects. If any pair is different - objects are different + for (const key of keysOfA) { + const propA = a[key]; + const propB = b[key]; + + if (Array.isArray(propA) && Array.isArray(propB)) { + if (!areArraysEqual(propA, propB)) { + return false; + } + } else if (propA !== propB) { + return false; + } + } + + return true; +} diff --git a/lib/utils/index.ts b/lib/utils/index.ts index 4603277a..f89f71ea 100644 --- a/lib/utils/index.ts +++ b/lib/utils/index.ts @@ -1,5 +1,6 @@ +import areHeadersEqual from './areHeadersEqual'; import definedOrError from './definedOrError'; import buildUserAgentString from './buildUserAgentString'; import formatProgress, { ProgressUpdateTransformer } from './formatProgress'; -export { definedOrError, buildUserAgentString, formatProgress, ProgressUpdateTransformer }; +export { areHeadersEqual, definedOrError, buildUserAgentString, formatProgress, ProgressUpdateTransformer }; From 45df1e22d1a3a5488d59151c944d5f4468b6a637 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 27 Jul 2023 19:16:26 +0300 Subject: [PATCH 6/6] Update tests Signed-off-by: Levko Kravets --- tests/unit/DBSQLClient.test.js | 157 ++++++++++++++++++++++------- tests/unit/hive/HiveDriver.test.js | 14 ++- tests/unit/utils.test.js | 105 ++++++++++++++++++- 3 files changed, 233 insertions(+), 43 deletions(-) diff --git a/tests/unit/DBSQLClient.test.js b/tests/unit/DBSQLClient.test.js index 3ac1255a..9f450d07 100644 --- a/tests/unit/DBSQLClient.test.js +++ b/tests/unit/DBSQLClient.test.js @@ -7,22 +7,17 @@ const PlainHttpAuthentication = require('../../dist/connection/auth/PlainHttpAut const DatabricksOAuth = require('../../dist/connection/auth/DatabricksOAuth').default; const { AWSOAuthManager, AzureOAuthManager } = require('../../dist/connection/auth/DatabricksOAuth/OAuthManager'); -const ConnectionProviderMock = (connection) => ({ - connect(options, auth) { - this.options = options; - this.auth = auth; - - return Promise.resolve({ - getConnection() { - return ( - connection || { - on: () => {}, - } - ); - }, - }); - }, -}); +const HttpConnectionModule = require('../../dist/connection/connections/HttpConnection'); + +class AuthProviderMock { + constructor() { + this.authResult = {}; + } + + authenticate() { + return Promise.resolve(this.authResult); + } +} describe('DBSQLClient.connect', () => { const options = { @@ -49,30 +44,18 @@ describe('DBSQLClient.connect', () => { expect(connectionOptions.options.path).to.equal(path); }); - // client.connect() now does not actually attempt any network operations. for http it never did it - // even before, but this test was not quite correct even then. it needs to be updated - it.skip('should handle network errors', (cb) => { + it('should initialize connection state', async () => { const client = new DBSQLClient(); - client.thrift = { - createClient() {}, - }; - const connectionProvider = ConnectionProviderMock({ - on(name, handler) { - handler(new Error('network error')); - }, - }); - sinon.stub(client, 'createConnection').returns(Promise.resolve(connectionProvider)); + expect(client.client).to.be.null; + expect(client.authProvider).to.be.null; + expect(client.connectionOptions).to.be.null; - client.on('error', (error) => { - expect(error.message).to.be.eq('network error'); - cb(); - }); + await client.connect(options); - client.connectionProvider = connectionProvider; - client.connect(options).catch((error) => { - cb(error); - }); + expect(client.client).to.be.null; // it should not be initialized at this point + expect(client.authProvider).to.be.instanceOf(PlainHttpAuthentication); + expect(client.connectionOptions).to.be.deep.equal(options); }); }); @@ -155,7 +138,13 @@ describe('DBSQLClient.openSession', () => { }); describe('DBSQLClient.getClient', () => { - it('should throw an error if the client is not set', async () => { + const options = { + host: '127.0.0.1', + path: '', + token: 'dapi********************************', + }; + + it('should throw an error if not connected', async () => { const client = new DBSQLClient(); try { await client.getClient(); @@ -167,15 +156,108 @@ describe('DBSQLClient.getClient', () => { expect(error.message).to.contain('DBSQLClient: not connected'); } }); + + it("should create client if wasn't not initialized yet", async () => { + const client = new DBSQLClient(); + + const thriftClient = {}; + + client.authProvider = new AuthProviderMock(); + client.connectionOptions = { ...options }; + client.thrift = { + createClient: sinon.stub().returns(thriftClient), + }; + sinon.stub(client, 'createConnection').returns({ + getConnection: () => null, + }); + + const result = await client.getClient(); + expect(client.thrift.createClient.called).to.be.true; + expect(client.createConnection.called).to.be.true; + expect(result).to.be.equal(thriftClient); + }); + + it('should re-create client if auth credentials change', async () => { + const client = new DBSQLClient(); + + const thriftClient = {}; + + client.authProvider = new AuthProviderMock(); + client.connectionOptions = { ...options }; + client.thrift = { + createClient: sinon.stub().returns(thriftClient), + }; + sinon.stub(client, 'createConnection').returns({ + getConnection: () => null, + }); + + // initialize client + firstCall: { + const result = await client.getClient(); + expect(client.thrift.createClient.callCount).to.be.equal(1); + expect(client.createConnection.callCount).to.be.equal(1); + expect(result).to.be.equal(thriftClient); + } + + // credentials stay the same, client should not be re-created + secondCall: { + const result = await client.getClient(); + expect(client.thrift.createClient.callCount).to.be.equal(1); + expect(client.createConnection.callCount).to.be.equal(1); + expect(result).to.be.equal(thriftClient); + } + + // change credentials mock - client should be re-created + thirdCall: { + client.authProvider.authResult = { b: 2 }; + + const result = await client.getClient(); + expect(client.thrift.createClient.callCount).to.be.equal(2); + expect(client.createConnection.callCount).to.be.equal(2); + expect(result).to.be.equal(thriftClient); + } + }); +}); + +describe('DBSQLClient.createConnection', () => { + afterEach(() => { + HttpConnectionModule.default.restore?.(); + }); + + it('should create connection', async () => { + const thriftConnection = { + on: sinon.stub(), + }; + + const connectionMock = { + getConnection: sinon.stub().returns(thriftConnection), + }; + + const connectionProviderMock = { + connect: sinon.stub().returns(Promise.resolve(connectionMock)), + }; + + sinon.stub(HttpConnectionModule, 'default').returns(connectionProviderMock); + + const client = new DBSQLClient(); + + const result = await client.createConnection({}); + expect(result).to.be.equal(connectionMock); + expect(connectionProviderMock.connect.called).to.be.true; + expect(connectionMock.getConnection.called).to.be.true; + expect(thriftConnection.on.called).to.be.true; + }); }); describe('DBSQLClient.close', () => { it('should close the connection if it was initiated', async () => { const client = new DBSQLClient(); + client.client = {}; client.authProvider = {}; client.connectionOptions = {}; await client.close(); + expect(client.client).to.be.null; expect(client.authProvider).to.be.null; expect(client.connectionOptions).to.be.null; // No additional asserts needed - it should just reach this point @@ -185,6 +267,7 @@ describe('DBSQLClient.close', () => { const client = new DBSQLClient(); await client.close(); + expect(client.client).to.be.null; expect(client.authProvider).to.be.null; expect(client.connectionOptions).to.be.null; // No additional asserts needed - it should just reach this point diff --git a/tests/unit/hive/HiveDriver.test.js b/tests/unit/hive/HiveDriver.test.js index 54511f23..59c665f8 100644 --- a/tests/unit/hive/HiveDriver.test.js +++ b/tests/unit/hive/HiveDriver.test.js @@ -1,20 +1,24 @@ const { expect } = require('chai'); +const sinon = require('sinon'); const { TCLIService_types } = require('../../../').thrift; const HiveDriver = require('../../../dist/hive/HiveDriver').default; const toTitleCase = (str) => str[0].toUpperCase() + str.slice(1); -const testCommand = (command, request) => { +const testCommand = async (command, request) => { const client = {}; - const driver = new HiveDriver(() => Promise.resolve(client)); + const clientFactory = sinon.stub().returns(Promise.resolve(client)); + const driver = new HiveDriver(clientFactory); + const response = { response: 'value' }; client[toTitleCase(command)] = function (req, cb) { expect(req).to.be.deep.eq(new TCLIService_types[`T${toTitleCase(command)}Req`](request)); cb(null, response); }; - return driver[command](request).then((resp) => { - expect(resp).to.be.deep.eq(response); - }); + + const resp = await driver[command](request); + expect(resp).to.be.deep.eq(response); + expect(clientFactory.called).to.be.true; }; describe('HiveDriver', () => { diff --git a/tests/unit/utils.test.js b/tests/unit/utils.test.js index 1ecf3d49..62aa7668 100644 --- a/tests/unit/utils.test.js +++ b/tests/unit/utils.test.js @@ -1,6 +1,12 @@ const { expect } = require('chai'); -const { buildUserAgentString, definedOrError, formatProgress, ProgressUpdateTransformer } = require('../../dist/utils'); +const { + areHeadersEqual, + buildUserAgentString, + definedOrError, + formatProgress, + ProgressUpdateTransformer, +} = require('../../dist/utils'); describe('buildUserAgentString', () => { // It should follow https://www.rfc-editor.org/rfc/rfc7231#section-5.5.3 and @@ -91,3 +97,100 @@ describe('definedOrError', () => { }).to.throw(); }); }); + +describe('areHeadersEqual', () => { + it('should return true for same objects', () => { + const a = {}; + expect(areHeadersEqual(a, a)).to.be.true; + }); + + it('should return false for objects with different keys', () => { + const a = { a: 1, x: 2 }; + const b = { b: 3, x: 4 }; + const c = { c: 5 }; + + expect(areHeadersEqual(a, b)).to.be.false; + expect(areHeadersEqual(b, a)).to.be.false; + expect(areHeadersEqual(a, c)).to.be.false; + expect(areHeadersEqual(c, a)).to.be.false; + }); + + it('should compare different types of properties', () => { + case1: { + expect( + areHeadersEqual( + { + a: 1, + b: 'b', + c: ['x', 'y', 'z'], + }, + { + a: 1, + b: 'b', + c: ['x', 'y', 'z'], + }, + ), + ).to.be.true; + } + + case2: { + const arr = ['a', 'b']; + + expect( + areHeadersEqual( + { + a: 1, + b: 'b', + c: arr, + }, + { + a: 1, + b: 'b', + c: arr, + }, + ), + ).to.be.true; + } + + case3: { + expect( + areHeadersEqual( + { + arr: ['a'], + }, + { + arr: ['b'], + }, + ), + ).to.be.false; + } + + case4: { + expect( + areHeadersEqual( + { + arr: ['a'], + }, + { + arr: ['a', 'b'], + }, + ), + ).to.be.false; + } + + case5: { + expect( + areHeadersEqual( + { + arr: ['a'], + prop: 'x', + }, + { + arr: ['a'], + prop: 1, + }, + ), + ).to.be.false; + } + }); +});