From f26d3f29d40ee30395ffc4eaf46da1608dfeb947 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Tue, 27 Feb 2024 15:18:40 +0200 Subject: [PATCH] Infer Arrow schema when it is not available Signed-off-by: Levko Kravets --- lib/result/ArrowResultHandler.ts | 7 +- lib/result/ResultSlicer.ts | 6 +- lib/result/utils.ts | 67 ++++++++++++++++++++ tests/unit/result/ArrowResultHandler.test.js | 27 ++++++++ tests/unit/result/compatibility.test.js | 39 ++++++++++-- 5 files changed, 135 insertions(+), 11 deletions(-) diff --git a/lib/result/ArrowResultHandler.ts b/lib/result/ArrowResultHandler.ts index 693750ce..2b9a3238 100644 --- a/lib/result/ArrowResultHandler.ts +++ b/lib/result/ArrowResultHandler.ts @@ -2,6 +2,7 @@ import LZ4 from 'lz4'; import { TGetResultSetMetadataResp, TRowSet } from '../../thrift/TCLIService_types'; import IClientContext from '../contracts/IClientContext'; import IResultsProvider, { ResultsProviderFetchNextOptions } from './IResultsProvider'; +import { hiveSchemaToArrowSchema } from './utils'; export default class ArrowResultHandler implements IResultsProvider> { protected readonly context: IClientContext; @@ -15,11 +16,13 @@ export default class ArrowResultHandler implements IResultsProvider, - { arrowSchema, lz4Compressed }: TGetResultSetMetadataResp, + { schema, arrowSchema, lz4Compressed }: TGetResultSetMetadataResp, ) { this.context = context; this.source = source; - this.arrowSchema = arrowSchema; + // Arrow schema is not available in old DBR versions, which also don't support native Arrow types, + // so it's possible to infer Arrow schema from Hive schema ignoring `useArrowNativeTypes` option + this.arrowSchema = arrowSchema ?? hiveSchemaToArrowSchema(schema); this.isLZ4Compressed = lz4Compressed ?? false; } diff --git a/lib/result/ResultSlicer.ts b/lib/result/ResultSlicer.ts index 0f640a9a..11a2a15f 100644 --- a/lib/result/ResultSlicer.ts +++ b/lib/result/ResultSlicer.ts @@ -52,11 +52,13 @@ export default class ResultSlicer implements IResultsProvider> { // Fetch items from source results provider until we reach a requested count while (resultsCount < options.limit) { // eslint-disable-next-line no-await-in-loop - const chunk = await this.source.fetchNext(options); - if (chunk.length === 0) { + const hasMore = await this.source.hasMore(); + if (!hasMore) { break; } + // eslint-disable-next-line no-await-in-loop + const chunk = await this.source.fetchNext(options); result.push(chunk); resultsCount += chunk.length; } diff --git a/lib/result/utils.ts b/lib/result/utils.ts index b4351df0..25e4d067 100644 --- a/lib/result/utils.ts +++ b/lib/result/utils.ts @@ -1,5 +1,23 @@ import Int64 from 'node-int64'; +import { + Schema, + Field, + DataType, + Bool as ArrowBool, + Int8 as ArrowInt8, + Int16 as ArrowInt16, + Int32 as ArrowInt32, + Int64 as ArrowInt64, + Float32 as ArrowFloat32, + Float64 as ArrowFloat64, + Utf8 as ArrowString, + Date_ as ArrowDate, + Binary as ArrowBinary, + DateUnit, + RecordBatchWriter, +} from 'apache-arrow'; import { TTableSchema, TColumnDesc, TPrimitiveTypeEntry, TTypeId } from '../../thrift/TCLIService_types'; +import HiveDriverError from '../errors/HiveDriverError'; export function getSchemaColumns(schema?: TTableSchema): Array { if (!schema) { @@ -73,3 +91,52 @@ export function convertThriftValue(typeDescriptor: TPrimitiveTypeEntry | undefin return value; } } + +// This type map corresponds to Arrow without native types support (most complex types are serialized as strings) +const hiveTypeToArrowType: Record = { + [TTypeId.BOOLEAN_TYPE]: new ArrowBool(), + [TTypeId.TINYINT_TYPE]: new ArrowInt8(), + [TTypeId.SMALLINT_TYPE]: new ArrowInt16(), + [TTypeId.INT_TYPE]: new ArrowInt32(), + [TTypeId.BIGINT_TYPE]: new ArrowInt64(), + [TTypeId.FLOAT_TYPE]: new ArrowFloat32(), + [TTypeId.DOUBLE_TYPE]: new ArrowFloat64(), + [TTypeId.STRING_TYPE]: new ArrowString(), + [TTypeId.TIMESTAMP_TYPE]: new ArrowString(), + [TTypeId.BINARY_TYPE]: new ArrowBinary(), + [TTypeId.ARRAY_TYPE]: new ArrowString(), + [TTypeId.MAP_TYPE]: new ArrowString(), + [TTypeId.STRUCT_TYPE]: new ArrowString(), + [TTypeId.UNION_TYPE]: new ArrowString(), + [TTypeId.USER_DEFINED_TYPE]: new ArrowString(), + [TTypeId.DECIMAL_TYPE]: new ArrowString(), + [TTypeId.NULL_TYPE]: null, + [TTypeId.DATE_TYPE]: new ArrowDate(DateUnit.DAY), + [TTypeId.VARCHAR_TYPE]: new ArrowString(), + [TTypeId.CHAR_TYPE]: new ArrowString(), + [TTypeId.INTERVAL_YEAR_MONTH_TYPE]: new ArrowString(), + [TTypeId.INTERVAL_DAY_TIME_TYPE]: new ArrowString(), +}; + +export function hiveSchemaToArrowSchema(schema?: TTableSchema): Buffer | undefined { + if (!schema) { + return undefined; + } + + const columns = getSchemaColumns(schema); + + const arrowFields = columns.map((column) => { + const hiveType = column.typeDesc.types[0].primitiveEntry?.type ?? undefined; + const arrowType = hiveType !== undefined ? hiveTypeToArrowType[hiveType] : undefined; + if (!arrowType) { + throw new HiveDriverError(`Unsupported column type: ${hiveType ? TTypeId[hiveType] : 'undefined'}`); + } + return new Field(column.columnName, arrowType, true); + }); + + const arrowSchema = new Schema(arrowFields); + const writer = new RecordBatchWriter(); + writer.reset(undefined, arrowSchema); + writer.finish(); + return Buffer.from(writer.toUint8Array(true)); +} diff --git a/tests/unit/result/ArrowResultHandler.test.js b/tests/unit/result/ArrowResultHandler.test.js index 0eeccf32..74bf37c3 100644 --- a/tests/unit/result/ArrowResultHandler.test.js +++ b/tests/unit/result/ArrowResultHandler.test.js @@ -127,6 +127,33 @@ describe('ArrowResultHandler', () => { } }); + it('should infer arrow schema from thrift schema', async () => { + const context = {}; + const rowSetProvider = new ResultsProviderMock([sampleRowSet2]); + + const sampleThriftSchema = { + columns: [ + { + columnName: '1', + typeDesc: { + types: [ + { + primitiveEntry: { + type: 3, + typeQualifiers: null, + }, + }, + ], + }, + position: 1, + }, + ], + }; + + const result = new ArrowResultHandler(context, rowSetProvider, { schema: sampleThriftSchema }); + expect(result.arrowSchema).to.not.be.undefined; + }); + it('should return empty array if no schema available', async () => { const context = {}; const rowSetProvider = new ResultsProviderMock([sampleRowSet2]); diff --git a/tests/unit/result/compatibility.test.js b/tests/unit/result/compatibility.test.js index 91686cd5..eb4119b5 100644 --- a/tests/unit/result/compatibility.test.js +++ b/tests/unit/result/compatibility.test.js @@ -2,6 +2,7 @@ const { expect } = require('chai'); const ArrowResultHandler = require('../../../dist/result/ArrowResultHandler').default; const ArrowResultConverter = require('../../../dist/result/ArrowResultConverter').default; const JsonResultHandler = require('../../../dist/result/JsonResultHandler').default; +const ResultSlicer = require('../../../dist/result/ResultSlicer').default; const { fixArrowResult } = require('../../fixtures/compatibility'); const fixtureColumn = require('../../fixtures/compatibility/column'); @@ -14,7 +15,10 @@ describe('Result handlers compatibility tests', () => { it('colum-based data', async () => { const context = {}; const rowSetProvider = new ResultsProviderMock(fixtureColumn.rowSets); - const result = new JsonResultHandler(context, rowSetProvider, { schema: fixtureColumn.schema }); + const result = new ResultSlicer( + context, + new JsonResultHandler(context, rowSetProvider, { schema: fixtureColumn.schema }), + ); const rows = await result.fetchNext({ limit: 10000 }); expect(rows).to.deep.equal(fixtureColumn.expected); }); @@ -22,10 +26,13 @@ describe('Result handlers compatibility tests', () => { it('arrow-based data without native types', async () => { const context = {}; const rowSetProvider = new ResultsProviderMock(fixtureArrow.rowSets); - const result = new ArrowResultConverter( + const result = new ResultSlicer( context, - new ArrowResultHandler(context, rowSetProvider, { arrowSchema: fixtureArrow.arrowSchema }), - { schema: fixtureArrow.schema }, + new ArrowResultConverter( + context, + new ArrowResultHandler(context, rowSetProvider, { arrowSchema: fixtureArrow.arrowSchema }), + { schema: fixtureArrow.schema }, + ), ); const rows = await result.fetchNext({ limit: 10000 }); expect(fixArrowResult(rows)).to.deep.equal(fixtureArrow.expected); @@ -34,12 +41,30 @@ describe('Result handlers compatibility tests', () => { it('arrow-based data with native types', async () => { const context = {}; const rowSetProvider = new ResultsProviderMock(fixtureArrowNT.rowSets); - const result = new ArrowResultConverter( + const result = new ResultSlicer( context, - new ArrowResultHandler(context, rowSetProvider, { arrowSchema: fixtureArrowNT.arrowSchema }), - { schema: fixtureArrowNT.schema }, + new ArrowResultConverter( + context, + new ArrowResultHandler(context, rowSetProvider, { arrowSchema: fixtureArrowNT.arrowSchema }), + { schema: fixtureArrowNT.schema }, + ), ); const rows = await result.fetchNext({ limit: 10000 }); expect(fixArrowResult(rows)).to.deep.equal(fixtureArrowNT.expected); }); + + it('should infer arrow schema from thrift schema', async () => { + const context = {}; + const rowSetProvider = new ResultsProviderMock(fixtureArrow.rowSets); + const result = new ResultSlicer( + context, + new ArrowResultConverter( + context, + new ArrowResultHandler(context, rowSetProvider, { schema: fixtureArrow.schema }), + { schema: fixtureArrow.schema }, + ), + ); + const rows = await result.fetchNext({ limit: 10000 }); + expect(fixArrowResult(rows)).to.deep.equal(fixtureArrow.expected); + }); });