Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore #255

Merged
merged 6 commits into from
Apr 29, 2024
Merged

Chore #255

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import thrift from 'thrift';
import Int64 from 'node-int64';

import { EventEmitter } from 'events';
import TCLIService from '../thrift/TCLIService';
Expand All @@ -7,7 +8,6 @@ import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } fr
import IDriver from './contracts/IDriver';
import IClientContext, { ClientConfig } from './contracts/IClientContext';
import HiveDriver from './hive/HiveDriver';
import { Int64 } from './hive/Types';
import DBSQLSession from './DBSQLSession';
import IDBSQLSession from './contracts/IDBSQLSession';
import IAuthentication from './connection/contracts/IAuthentication';
Expand Down Expand Up @@ -73,6 +73,9 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I

private static getDefaultConfig(): ClientConfig {
return {
directResultsDefaultMaxRows: 100000,
fetchChunkDefaultMaxRows: 100000,

arrowEnabled: true,
useArrowNativeTypes: true,
socketTimeout: 15 * 60 * 1000, // 15 minutes
Expand Down
8 changes: 4 additions & 4 deletions lib/DBSQLOperation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ import { definedOrError } from './utils';
import HiveDriverError from './errors/HiveDriverError';
import IClientContext from './contracts/IClientContext';

const defaultMaxRows = 100000;

interface DBSQLOperationConstructorOptions {
handle: TOperationHandle;
directResults?: TSparkDirectResults;
Expand Down Expand Up @@ -164,8 +162,10 @@ export default class DBSQLOperation implements IOperation {
setTimeout(resolve, 0);
});

const defaultMaxRows = this.context.getConfig().fetchChunkDefaultMaxRows;

const result = resultHandler.fetchNext({
limit: options?.maxRows || defaultMaxRows,
limit: options?.maxRows ?? defaultMaxRows,
disableBuffering: options?.disableBuffering,
});
await this.failIfClosed();
Expand All @@ -174,7 +174,7 @@ export default class DBSQLOperation implements IOperation {
.getLogger()
.log(
LogLevel.debug,
`Fetched chunk of size: ${options?.maxRows || defaultMaxRows} from operation with id: ${this.id}`,
`Fetched chunk of size: ${options?.maxRows ?? defaultMaxRows} from operation with id: ${this.id}`,
);
return result;
}
Expand Down
57 changes: 39 additions & 18 deletions lib/DBSQLSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import * as path from 'path';
import stream from 'node:stream';
import util from 'node:util';
import { stringify, NIL } from 'uuid';
import Int64 from 'node-int64';
import fetch, { HeadersInit } from 'node-fetch';
import {
TSessionHandle,
Expand All @@ -12,7 +13,6 @@ import {
TSparkArrowTypes,
TSparkParameter,
} from '../thrift/TCLIService_types';
import { Int64 } from './hive/Types';
import IDBSQLSession, {
ExecuteStatementOptions,
TypeInfoRequest,
Expand Down Expand Up @@ -41,22 +41,35 @@ import IClientContext, { ClientConfig } from './contracts/IClientContext';
// Explicitly promisify a callback-style `pipeline` because `node:stream/promises` is not available in Node 14
const pipeline = util.promisify(stream.pipeline);

const defaultMaxRows = 100000;

interface OperationResponseShape {
status: TStatus;
operationHandle?: TOperationHandle;
directResults?: TSparkDirectResults;
}

function getDirectResultsOptions(maxRows: number | null = defaultMaxRows) {
/**
 * Normalizes a numeric value into an `Int64` instance.
 *
 * - An existing `Int64` is returned unchanged (same object, no copy).
 * - A `bigint` is serialized into an 8-byte big-endian buffer which is then
 *   handed to the `Int64` constructor.
 * - A plain `number` is passed directly to the `Int64` constructor.
 *
 * @param value - The value to convert.
 * @returns An `Int64` built from `value`, or `value` itself when it already is one.
 */
export function numberToInt64(value: number | bigint | Int64): Int64 {
  if (typeof value === 'bigint') {
    const bytes = new ArrayBuffer(BigInt64Array.BYTES_PER_ELEMENT);
    // Third argument `false` selects big-endian byte order
    new DataView(bytes).setBigInt64(0, value, false);
    return new Int64(Buffer.from(bytes));
  }

  return value instanceof Int64 ? value : new Int64(value);
}

function getDirectResultsOptions(maxRows: number | bigint | Int64 | null | undefined, config: ClientConfig) {
if (maxRows === null) {
return {};
}

return {
getDirectResults: {
maxRows: new Int64(maxRows),
maxRows: numberToInt64(maxRows ?? config.directResultsDefaultMaxRows),
},
};
}
Expand Down Expand Up @@ -86,7 +99,6 @@ function getArrowOptions(config: ClientConfig): {
}

function getQueryParameters(
sessionHandle: TSessionHandle,
namedParameters?: Record<string, DBSQLParameter | DBSQLParameterValue>,
ordinalParameters?: Array<DBSQLParameter | DBSQLParameterValue>,
): Array<TSparkParameter> {
Expand Down Expand Up @@ -184,12 +196,12 @@ export default class DBSQLSession implements IDBSQLSession {
const operationPromise = driver.executeStatement({
sessionHandle: this.sessionHandle,
statement,
queryTimeout: options.queryTimeout,
queryTimeout: options.queryTimeout ? numberToInt64(options.queryTimeout) : undefined,
runAsync: true,
...getDirectResultsOptions(options.maxRows),
...getDirectResultsOptions(options.maxRows, clientConfig),
...getArrowOptions(clientConfig),
canDownloadResult: options.useCloudFetch ?? clientConfig.useCloudFetch,
parameters: getQueryParameters(this.sessionHandle, options.namedParameters, options.ordinalParameters),
parameters: getQueryParameters(options.namedParameters, options.ordinalParameters),
canDecompressLZ4Result: clientConfig.useLZ4Compression && Boolean(LZ4),
});
const response = await this.handleResponse(operationPromise);
Expand Down Expand Up @@ -339,10 +351,11 @@ export default class DBSQLSession implements IDBSQLSession {
public async getTypeInfo(request: TypeInfoRequest = {}): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getTypeInfo({
sessionHandle: this.sessionHandle,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -357,10 +370,11 @@ export default class DBSQLSession implements IDBSQLSession {
public async getCatalogs(request: CatalogsRequest = {}): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getCatalogs({
sessionHandle: this.sessionHandle,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -375,12 +389,13 @@ export default class DBSQLSession implements IDBSQLSession {
public async getSchemas(request: SchemasRequest = {}): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getSchemas({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -395,14 +410,15 @@ export default class DBSQLSession implements IDBSQLSession {
public async getTables(request: TablesRequest = {}): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getTables({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
tableName: request.tableName,
tableTypes: request.tableTypes,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -417,10 +433,11 @@ export default class DBSQLSession implements IDBSQLSession {
public async getTableTypes(request: TableTypesRequest = {}): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getTableTypes({
sessionHandle: this.sessionHandle,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -435,14 +452,15 @@ export default class DBSQLSession implements IDBSQLSession {
public async getColumns(request: ColumnsRequest = {}): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getColumns({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
tableName: request.tableName,
columnName: request.columnName,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -457,13 +475,14 @@ export default class DBSQLSession implements IDBSQLSession {
public async getFunctions(request: FunctionsRequest): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getFunctions({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
functionName: request.functionName,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -472,13 +491,14 @@ export default class DBSQLSession implements IDBSQLSession {
public async getPrimaryKeys(request: PrimaryKeysRequest): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getPrimaryKeys({
sessionHandle: this.sessionHandle,
catalogName: request.catalogName,
schemaName: request.schemaName,
tableName: request.tableName,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand All @@ -493,6 +513,7 @@ export default class DBSQLSession implements IDBSQLSession {
public async getCrossReference(request: CrossReferenceRequest): Promise<IOperation> {
await this.failIfClosed();
const driver = await this.context.getDriver();
const clientConfig = this.context.getConfig();
const operationPromise = driver.getCrossReference({
sessionHandle: this.sessionHandle,
parentCatalogName: request.parentCatalogName,
Expand All @@ -502,7 +523,7 @@ export default class DBSQLSession implements IDBSQLSession {
foreignSchemaName: request.foreignSchemaName,
foreignTableName: request.foreignTableName,
runAsync: true,
...getDirectResultsOptions(request.maxRows),
...getDirectResultsOptions(request.maxRows, clientConfig),
});
const response = await this.handleResponse(operationPromise);
return this.createOperation(response);
Expand Down
7 changes: 4 additions & 3 deletions lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import open from 'open';
import { LogLevel } from '../../../contracts/IDBSQLLogger';
import { OAuthScopes, scopeDelimiter } from './OAuthScope';
import IClientContext from '../../../contracts/IClientContext';
import AuthenticationError from '../../../errors/AuthenticationError';

export interface AuthorizationCodeOptions {
client: BaseClient;
Expand Down Expand Up @@ -113,9 +114,9 @@ export default class AuthorizationCode {
if (!receivedParams || !receivedParams.code) {
if (receivedParams?.error) {
const errorMessage = `OAuth error: ${receivedParams.error} ${receivedParams.error_description}`;
throw new Error(errorMessage);
throw new AuthenticationError(errorMessage);
}
throw new Error(`No path parameters were returned to the callback at ${redirectUri}`);
throw new AuthenticationError(`No path parameters were returned to the callback at ${redirectUri}`);
}

return { code: receivedParams.code, verifier: verifierString, redirectUri };
Expand Down Expand Up @@ -152,7 +153,7 @@ export default class AuthorizationCode {
}
}

throw new Error('Failed to start server: all ports are in use');
throw new AuthenticationError('Failed to start server: all ports are in use');
}

private renderCallbackResponse(): string {
Expand Down
12 changes: 6 additions & 6 deletions lib/connection/auth/DatabricksOAuth/OAuthManager.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import http from 'http';
import { Issuer, BaseClient, custom } from 'openid-client';
import HiveDriverError from '../../../errors/HiveDriverError';
import AuthenticationError from '../../../errors/AuthenticationError';
import { LogLevel } from '../../../contracts/IDBSQLLogger';
import OAuthToken from './OAuthToken';
import AuthorizationCode from './AuthorizationCode';
Expand Down Expand Up @@ -104,7 +104,7 @@ export default abstract class OAuthManager {
if (!token.refreshToken) {
const message = `OAuth access token expired on ${token.expirationTime}.`;
this.context.getLogger().log(LogLevel.error, message);
throw new HiveDriverError(message);
throw new AuthenticationError(message);
}

// Try to refresh using the refresh token
Expand All @@ -115,7 +115,7 @@ export default abstract class OAuthManager {
const client = await this.getClient();
const { access_token: accessToken, refresh_token: refreshToken } = await client.refresh(token.refreshToken);
if (!accessToken || !refreshToken) {
throw new Error('Failed to refresh token: invalid response');
throw new AuthenticationError('Failed to refresh token: invalid response');
}
return new OAuthToken(accessToken, refreshToken, token.scopes);
}
Expand Down Expand Up @@ -165,7 +165,7 @@ export default abstract class OAuthManager {
});

if (!accessToken) {
throw new Error('Failed to fetch access token');
throw new AuthenticationError('Failed to fetch access token');
}
return new OAuthToken(accessToken, refreshToken, mappedScopes);
}
Expand All @@ -185,7 +185,7 @@ export default abstract class OAuthManager {
});

if (!accessToken) {
throw new Error('Failed to fetch access token');
throw new AuthenticationError('Failed to fetch access token');
}
return new OAuthToken(accessToken, undefined, mappedScopes);
}
Expand Down Expand Up @@ -234,7 +234,7 @@ export default abstract class OAuthManager {
}
}

throw new Error(`OAuth is not supported for ${options.host}`);
throw new AuthenticationError(`OAuth is not supported for ${options.host}`);
}
}

Expand Down
3 changes: 3 additions & 0 deletions lib/contracts/IClientContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import IConnectionProvider from '../connection/contracts/IConnectionProvider';
import TCLIService from '../../thrift/TCLIService';

export interface ClientConfig {
directResultsDefaultMaxRows: number;
fetchChunkDefaultMaxRows: number;

arrowEnabled?: boolean;
useArrowNativeTypes?: boolean;
socketTimeout: number;
Expand Down
Loading
Loading