Skip to content

Commit

Permalink
Merge branch 'main' into iterable-operation
Browse files Browse the repository at this point in the history
  • Loading branch information
kravets-levko committed Apr 29, 2024
2 parents 0c0b529 + c239fca commit 513c2f8
Show file tree
Hide file tree
Showing 70 changed files with 525 additions and 236 deletions.
5 changes: 4 additions & 1 deletion lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import thrift from 'thrift';
import Int64 from 'node-int64';

import { EventEmitter } from 'events';
import TCLIService from '../thrift/TCLIService';
Expand All @@ -7,7 +8,6 @@ import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } fr
import IDriver from './contracts/IDriver';
import IClientContext, { ClientConfig } from './contracts/IClientContext';
import HiveDriver from './hive/HiveDriver';
import { Int64 } from './hive/Types';
import DBSQLSession from './DBSQLSession';
import IDBSQLSession from './contracts/IDBSQLSession';
import IAuthentication from './connection/contracts/IAuthentication';
Expand Down Expand Up @@ -73,6 +73,9 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I

private static getDefaultConfig(): ClientConfig {
return {
directResultsDefaultMaxRows: 100000,
fetchChunkDefaultMaxRows: 100000,

arrowEnabled: true,
useArrowNativeTypes: true,
socketTimeout: 15 * 60 * 1000, // 15 minutes
Expand Down
8 changes: 4 additions & 4 deletions lib/DBSQLOperation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ import { OperationChunksIterator, OperationRowsIterator } from './utils/Operatio
import HiveDriverError from './errors/HiveDriverError';
import IClientContext from './contracts/IClientContext';

const defaultMaxRows = 100000;

interface DBSQLOperationConstructorOptions {
handle: TOperationHandle;
directResults?: TSparkDirectResults;
Expand Down Expand Up @@ -176,8 +174,10 @@ export default class DBSQLOperation implements IOperation {
setTimeout(resolve, 0);
});

const defaultMaxRows = this.context.getConfig().fetchChunkDefaultMaxRows;

const result = resultHandler.fetchNext({
limit: options?.maxRows || defaultMaxRows,
limit: options?.maxRows ?? defaultMaxRows,
disableBuffering: options?.disableBuffering,
});
await this.failIfClosed();
Expand All @@ -186,7 +186,7 @@ export default class DBSQLOperation implements IOperation {
.getLogger()
.log(
LogLevel.debug,
`Fetched chunk of size: ${options?.maxRows || defaultMaxRows} from operation with id: ${this.id}`,
`Fetched chunk of size: ${options?.maxRows ?? defaultMaxRows} from operation with id: ${this.id}`,
);
return result;
}
Expand Down
57 changes: 39 additions & 18 deletions lib/DBSQLSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import * as path from 'path';
import stream from 'node:stream';
import util from 'node:util';
import { stringify, NIL } from 'uuid';
import Int64 from 'node-int64';
import fetch, { HeadersInit } from 'node-fetch';
import {
TSessionHandle,
Expand All @@ -12,7 +13,6 @@ import {
TSparkArrowTypes,
TSparkParameter,
} from '../thrift/TCLIService_types';
import { Int64 } from './hive/Types';
import IDBSQLSession, {
ExecuteStatementOptions,
TypeInfoRequest,
Expand Down Expand Up @@ -41,22 +41,35 @@ import IClientContext, { ClientConfig } from './contracts/IClientContext';
// Explicitly promisify a callback-style `pipeline` because `node:stream/promises` is not available in Node 14
const pipeline = util.promisify(stream.pipeline);

const defaultMaxRows = 100000;

interface OperationResponseShape {
status: TStatus;
operationHandle?: TOperationHandle;
directResults?: TSparkDirectResults;
}

function getDirectResultsOptions(maxRows: number | null = defaultMaxRows) {
/**
 * Converts a numeric value to a `node-int64` `Int64` instance.
 *
 * @param value - a JS number, a bigint, or an already constructed `Int64`
 * @returns the same object when `value` is already an `Int64`; otherwise a new
 *   `Int64` holding the value
 */
export function numberToInt64(value: number | bigint | Int64): Int64 {
  if (typeof value !== 'bigint') {
    // Either an existing Int64 (passed through untouched) or a plain number
    return value instanceof Int64 ? value : new Int64(value);
  }

  // Serialize the bigint as 8 big-endian bytes. `asIntN` wraps the value into
  // the signed 64-bit range first, so out-of-range input wraps instead of throwing.
  const bytes = Buffer.alloc(BigInt64Array.BYTES_PER_ELEMENT);
  bytes.writeBigInt64BE(BigInt.asIntN(64, value), 0);
  return new Int64(bytes);
}

/**
 * Builds the optional `getDirectResults` portion of a request.
 *
 * @param maxRows - requested row limit. `null` makes the function return an
 *   empty object (no `getDirectResults` field); `undefined` falls back to
 *   `config.directResultsDefaultMaxRows`.
 * @param config - client configuration that supplies the default limit
 * @returns `{}` when `maxRows` is `null`, otherwise
 *   `{ getDirectResults: { maxRows } }` with the limit converted to `Int64`
 */
function getDirectResultsOptions(maxRows: number | bigint | Int64 | null | undefined, config: ClientConfig) {
  if (maxRows === null) {
    return {};
  }

  return {
    getDirectResults: {
      // Single `maxRows` entry: the value is defaulted first and only then converted,
      // so an `undefined` or bigint input never reaches the `Int64` constructor directly.
      maxRows: numberToInt64(maxRows ?? config.directResultsDefaultMaxRows),
    },
  };
}
Expand Down Expand Up @@ -86,7 +99,6 @@ function getArrowOptions(config: ClientConfig): {
}

function getQueryParameters(
sessionHandle: TSessionHandle,
namedParameters?: Record<string, DBSQLParameter | DBSQLParameterValue>,
ordinalParameters?: Array<DBSQLParameter | DBSQLParameterValue>,
): Array<TSparkParameter> {
Expand Down Expand Up @@ -184,12 +196,12 @@ export default class DBSQLSession implements IDBSQLSession {
const operationPromise = driver.executeStatement({
sessionHandle: this.sessionHandle,
statement,
queryTimeout: options.queryTimeout,
queryTimeout: options.queryTimeout ? numberToInt64(options.queryTimeout) : undefined,
runAsync: true,
...getDirectResultsOptions(options.maxRows),
...getDirectResultsOptions(options.maxRows, clientConfig),
...getArrowOptions(clientConfig),
canDownloadResult: options.useCloudFetch ?? clientConfig.useCloudFetch,
parameters: getQueryParameters(this.sessionHandle, options.namedParameters, options.ordinalParameters),
parameters: getQueryParameters(options.namedParameters, options.ordinalParameters),
canDecompressLZ4Result: clientConfig.useLZ4Compression && Boolean(LZ4),
});
const response = await this.handleResponse(operationPromise);
Expand Down Expand Up @@ -339,10 +351,11 @@ export default class DBSQLSession implements IDBSQLSession {
public async getTypeInfo(request: TypeInfoRequest = {}): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getTypeInfo({
sessionHandle: this.sessionHandle,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -357,10 +370,11 @@ export default class DBSQLSession implements IDBSQLSession {
public async getCatalogs(request: CatalogsRequest = {}): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getCatalogs({
sessionHandle: this.sessionHandle,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -375,12 +389,13 @@ export default class DBSQLSession implements IDBSQLSession {
public async getSchemas(request: SchemasRequest = {}): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getSchemas({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -395,14 +410,15 @@ export default class DBSQLSession implements IDBSQLSession {
public async getTables(request: TablesRequest = {}): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getTables({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
tableName: request.tableName,
tableTypes: request.tableTypes,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -417,10 +433,11 @@ export default class DBSQLSession implements IDBSQLSession {
public async getTableTypes(request: TableTypesRequest = {}): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getTableTypes({
sessionHandle: this.sessionHandle,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -435,14 +452,15 @@ export default class DBSQLSession implements IDBSQLSession {
public async getColumns(request: ColumnsRequest = {}): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getColumns({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
tableName: request.tableName,
columnName: request.columnName,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -457,13 +475,14 @@ export default class DBSQLSession implements IDBSQLSession {
public async getFunctions(request: FunctionsRequest): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getFunctions({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
functionName: request.functionName,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -472,13 +491,14 @@ export default class DBSQLSession implements IDBSQLSession {
public async getPrimaryKeys(request: PrimaryKeysRequest): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getPrimaryKeys({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
tableName: request.tableName,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -493,6 +513,7 @@ export default class DBSQLSession implements IDBSQLSession {
public async getCrossReference(request: CrossReferenceRequest): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getCrossReference({
sessionHandle: this.sessionHandle,
parentCatalogName: request.parentCatalogName,
Expand All @@ -502,7 +523,7 @@ export default class DBSQLSession implements IDBSQLSession {
foreignSchemaName: request.foreignSchemaName,
foreignTableName: request.foreignTableName,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand Down
7 changes: 4 additions & 3 deletions lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import open from 'open';
import { LogLevel } from '../../../contracts/IDBSQLLogger';
import { OAuthScopes, scopeDelimiter } from './OAuthScope';
import IClientContext from '../../../contracts/IClientContext';
import AuthenticationError from '../../../errors/AuthenticationError';

export interface AuthorizationCodeOptions {
client: BaseClient;
Expand Down Expand Up @@ -113,9 +114,9 @@ export default class AuthorizationCode {
if (!receivedParams || !receivedParams.code) {
if (receivedParams?.error) {
const errorMessage = `OAuth error: ${receivedParams.error} ${receivedParams.error_description}`;
throw new Error(errorMessage);
throw new AuthenticationError(errorMessage);
}
throw new Error(`No path parameters were returned to the callback at ${redirectUri}`);
throw new AuthenticationError(`No path parameters were returned to the callback at ${redirectUri}`);
}

return { code: receivedParams.code, verifier: verifierString, redirectUri };
Expand Down Expand Up @@ -152,7 +153,7 @@ export default class AuthorizationCode {
}
}

throw new Error('Failed to start server: all ports are in use');
throw new AuthenticationError('Failed to start server: all ports are in use');
}

private renderCallbackResponse(): string {
Expand Down
12 changes: 6 additions & 6 deletions lib/connection/auth/DatabricksOAuth/OAuthManager.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import http from 'http';
import { Issuer, BaseClient, custom } from 'openid-client';
import HiveDriverError from '../../../errors/HiveDriverError';
import AuthenticationError from '../../../errors/AuthenticationError';
import { LogLevel } from '../../../contracts/IDBSQLLogger';
import OAuthToken from './OAuthToken';
import AuthorizationCode from './AuthorizationCode';
Expand Down Expand Up @@ -104,7 +104,7 @@ export default abstract class OAuthManager {
if (!token.refreshToken) {
const message = `OAuth access token expired on ${token.expirationTime}.`;
this.context.getLogger().log(LogLevel.error, message);
throw new HiveDriverError(message);
throw new AuthenticationError(message);
}

// Try to refresh using the refresh token
Expand All @@ -115,7 +115,7 @@ export default abstract class OAuthManager {
const client = await this.getClient();
const { access_token: accessToken, refresh_token: refreshToken } = await client.refresh(token.refreshToken);
if (!accessToken || !refreshToken) {
throw new Error('Failed to refresh token: invalid response');
throw new AuthenticationError('Failed to refresh token: invalid response');
}
return new OAuthToken(accessToken, refreshToken, token.scopes);
}
Expand Down Expand Up @@ -165,7 +165,7 @@ export default abstract class OAuthManager {
});

if (!accessToken) {
throw new Error('Failed to fetch access token');
throw new AuthenticationError('Failed to fetch access token');
}
return new OAuthToken(accessToken, refreshToken, mappedScopes);
}
Expand All @@ -185,7 +185,7 @@ export default abstract class OAuthManager {
});

if (!accessToken) {
throw new Error('Failed to fetch access token');
throw new AuthenticationError('Failed to fetch access token');
}
return new OAuthToken(accessToken, undefined, mappedScopes);
}
Expand Down Expand Up @@ -234,7 +234,7 @@ export default abstract class OAuthManager {
}
}

throw new Error(`OAuth is not supported for ${options.host}`);
throw new AuthenticationError(`OAuth is not supported for ${options.host}`);
}
}

Expand Down
3 changes: 3 additions & 0 deletions lib/contracts/IClientContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import IConnectionProvider from '../connection/contracts/IConnectionProvider';
import TCLIService from '../../thrift/TCLIService';

export interface ClientConfig {
directResultsDefaultMaxRows: number;
fetchChunkDefaultMaxRows: number;

arrowEnabled?: boolean;
useArrowNativeTypes?: boolean;
socketTimeout: number;
Expand Down
Loading

0 comments on commit 513c2f8

Please sign in to comment.