Skip to content

Commit

Permalink
Merge pull request #154 from planetscale/function-overloading
Browse files Browse the repository at this point in the history
Allow a type parameter for the row results
  • Loading branch information
ayrton authored Dec 14, 2023
2 parents f4f7177 + 06a3d9a commit 0a7eeee
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 42 deletions.
24 changes: 12 additions & 12 deletions __tests__/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import SqlString from 'sqlstring'
import { cast, connect, format, hex, ExecutedQuery, DatabaseError } from '../dist/index'
import { cast, connect, format, hex, DatabaseError } from '../dist/index'
import { fetch, MockAgent, setGlobalDispatcher } from 'undici'
import packageJSON from '../package.json'

Expand Down Expand Up @@ -143,7 +143,7 @@ describe('execute', () => {
timing: 1
}

const want: ExecutedQuery = {
const want = {
headers: [':vtg1', 'null'],
types: { ':vtg1': 'INT32', null: 'NULL' },
fields: [
Expand Down Expand Up @@ -192,7 +192,7 @@ describe('execute', () => {
timing: 1
}

const want: ExecutedQuery = {
const want = {
headers: ['null'],
types: { null: 'NULL' },
fields: [{ name: 'null', type: 'NULL' }],
Expand Down Expand Up @@ -238,7 +238,7 @@ describe('execute', () => {
timing: 1
}

const want: ExecutedQuery = {
const want = {
headers: [':vtg1'],
types: { ':vtg1': 'INT32' },
rows: [[1]],
Expand Down Expand Up @@ -273,7 +273,7 @@ describe('execute', () => {
mockPool.intercept({ path: EXECUTE_PATH, method: 'POST' }).reply(200, mockResponse)

const query = 'CREATE TABLE `foo` (bar json);'
const want: ExecutedQuery = {
const want = {
headers: [],
types: {},
fields: [],
Expand Down Expand Up @@ -303,7 +303,7 @@ describe('execute', () => {
mockPool.intercept({ path: EXECUTE_PATH, method: 'POST' }).reply(200, mockResponse)

const query = "UPDATE `foo` SET bar='planetscale'"
const want: ExecutedQuery = {
const want = {
headers: [],
types: {},
fields: [],
Expand Down Expand Up @@ -334,7 +334,7 @@ describe('execute', () => {
mockPool.intercept({ path: EXECUTE_PATH, method: 'POST' }).reply(200, mockResponse)

const query = "INSERT INTO `foo` (bar) VALUES ('planetscale');"
const want: ExecutedQuery = {
const want = {
headers: [],
types: {},
fields: [],
Expand Down Expand Up @@ -418,7 +418,7 @@ describe('execute', () => {
timing: 1
}

const want: ExecutedQuery = {
const want = {
headers: [':vtg1'],
rows: [{ ':vtg1': 1 }],
types: { ':vtg1': 'INT32' },
Expand Down Expand Up @@ -452,7 +452,7 @@ describe('execute', () => {
timing: 1
}

const want: ExecutedQuery = {
const want = {
headers: [':vtg1'],
types: { ':vtg1': 'INT32' },
fields: [{ name: ':vtg1', type: 'INT32' }],
Expand Down Expand Up @@ -486,7 +486,7 @@ describe('execute', () => {
timing: 1
}

const want: ExecutedQuery = {
const want = {
headers: [':vtg1'],
types: { ':vtg1': 'INT64' },
fields: [{ name: ':vtg1', type: 'INT64' }],
Expand Down Expand Up @@ -521,7 +521,7 @@ describe('execute', () => {
timing: 1
}

const want: ExecutedQuery = {
const want = {
headers: [':vtg1'],
types: { ':vtg1': 'INT64' },
fields: [{ name: ':vtg1', type: 'INT64' }],
Expand Down Expand Up @@ -558,7 +558,7 @@ describe('execute', () => {
timing: 1
}

const want: ExecutedQuery = {
const want = {
headers: ['document'],
types: { document: 'JSON' },
fields: [{ name: 'document', type: 'JSON' }],
Expand Down
81 changes: 51 additions & 30 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export { hex } from './text.js'
import { decode } from './text.js'
import { Version } from './version.js'

type Row = Record<string, any> | any[]
type Row<T extends ExecuteAs = 'object'> = T extends 'array' ? any[] : T extends 'object' ? Record<string, any> : never

interface VitessError {
message: string
Expand All @@ -24,10 +24,10 @@ export class DatabaseError extends Error {

type Types = Record<string, string>

export interface ExecutedQuery {
export interface ExecutedQuery<T = Row<'array'> | Row<'object'>> {
headers: string[]
types: Types
rows: Row[]
rows: T[]
fields: Field[]
size: number
statement: string
Expand Down Expand Up @@ -102,17 +102,8 @@ interface QueryResult {

type ExecuteAs = 'array' | 'object'

type ExecuteOptions = {
as?: ExecuteAs
cast?: Cast
}

type ExecuteArgs = object | any[] | null

const defaultExecuteOptions: ExecuteOptions = {
as: 'object'
}

export class Client {
private config: Config

Expand All @@ -124,12 +115,22 @@ export class Client {
return this.connection().transaction(fn)
}

async execute(
async execute<T = Row<'object'>>(
query: string,
args?: ExecuteArgs,
options?: { as?: 'object'; cast?: Cast }
): Promise<ExecutedQuery<T>>
async execute<T = Row<'array'>>(
query: string,
args: ExecuteArgs,
options: { as: 'array'; cast?: Cast }
): Promise<ExecutedQuery<T>>
async execute<T = Row<'object'> | Row<'array'>>(
query: string,
args: ExecuteArgs = null,
options: ExecuteOptions = defaultExecuteOptions
): Promise<ExecutedQuery> {
return this.connection().execute(query, args, options)
options: any = { as: 'object' }
): Promise<ExecutedQuery<T>> {
return this.connection().execute<T>(query, args, options)
}

connection(): Connection {
Expand All @@ -146,12 +147,22 @@ class Tx {
this.conn = conn
}

async execute(
async execute<T = Row<'object'>>(
query: string,
args?: ExecuteArgs,
options?: { as?: 'object'; cast?: Cast }
): Promise<ExecutedQuery<T>>
async execute<T = Row<'array'>>(
query: string,
args: ExecuteArgs,
options: { as: 'array'; cast?: Cast }
): Promise<ExecutedQuery<T>>
async execute<T = Row<'object'> | Row<'array'>>(
query: string,
args: ExecuteArgs = null,
options: ExecuteOptions = defaultExecuteOptions
): Promise<ExecutedQuery> {
return this.conn.execute(query, args, options)
options: any = { as: 'object' }
): Promise<ExecutedQuery<T>> {
return this.conn.execute<T>(query, args, options)
}
}

Expand Down Expand Up @@ -209,11 +220,21 @@ export class Connection {
await this.createSession()
}

async execute(
async execute<T = Row<'object'>>(
query: string,
args?: ExecuteArgs,
options?: { as?: 'object'; cast?: Cast }
): Promise<ExecutedQuery<T>>
async execute<T = Row<'array'>>(
query: string,
args: ExecuteArgs,
options: { as: 'array'; cast?: Cast }
): Promise<ExecutedQuery<T>>
async execute<T = Row<'object'> | Row<'array'>>(
query: string,
args: ExecuteArgs = null,
options: ExecuteOptions = defaultExecuteOptions
): Promise<ExecutedQuery> {
options: any = { as: 'object' }
): Promise<ExecutedQuery<T>> {
const url = new URL('/psdb.v1alpha1.Database/Execute', this.url)

const formatter = this.config.format || format
Expand Down Expand Up @@ -244,7 +265,7 @@ export class Connection {
}

const castFn = options.cast || this.config.cast || cast
const rows = result ? parse(result, castFn, options.as || 'object') : []
const rows = result ? parse<T>(result, castFn, options.as || 'object') : []
const headers = fields.map((f) => f.name)

const typeByName = (acc, { name, type }) => ({ ...acc, [name]: type })
Expand Down Expand Up @@ -307,28 +328,28 @@ export function connect(config: Config): Connection {
return new Connection(config)
}

function parseArrayRow(fields: Field[], rawRow: QueryResultRow, cast: Cast): Row {
function parseArrayRow<T = Row<'array'>>(fields: Field[], rawRow: QueryResultRow, cast: Cast): T {
const row = decodeRow(rawRow)

return fields.map((field, ix) => {
return cast(field, row[ix])
})
}) as T
}

function parseObjectRow(fields: Field[], rawRow: QueryResultRow, cast: Cast): Row {
function parseObjectRow<T = Row<'object'>>(fields: Field[], rawRow: QueryResultRow, cast: Cast): T {
const row = decodeRow(rawRow)

return fields.reduce((acc, field, ix) => {
acc[field.name] = cast(field, row[ix])
return acc
}, {} as Row)
}, {} as T)
}

function parse(result: QueryResult, cast: Cast, returnAs: ExecuteAs): Row[] {
function parse<T>(result: QueryResult, cast: Cast, returnAs: ExecuteAs): T[] {
const fields = result.fields
const rows = result.rows ?? []
return rows.map((row) =>
returnAs === 'array' ? parseArrayRow(fields, row, cast) : parseObjectRow(fields, row, cast)
returnAs === 'array' ? parseArrayRow<T>(fields, row, cast) : parseObjectRow<T>(fields, row, cast)
)
}

Expand Down

0 comments on commit 0a7eeee

Please sign in to comment.