From d0eca3ed7f41e70e7264388d1f9bf97ec153fb7e Mon Sep 17 00:00:00 2001
From: Connor Sheehan
Date: Wed, 30 Oct 2024 19:34:08 -0400
Subject: [PATCH] bigquery: use `ArrayQueryParameter` for `REPEATED` mode fields

The table schema mappings now carry the whole schema field object
instead of only its type string, so the field `mode` is available when
the query parameters are built. Fields whose mode is `REPEATED` are sent
as an `ArrayQueryParameter`; every other field stays a
`ScalarQueryParameter`.

Because the mapping values are no longer type strings, the `id`
parameter must also read `.field_type`; otherwise it would receive the
schema field object instead of its type. The `load_table_schema`
docstring is updated to describe the new return value.
---
 stats.py | 31 ++++++++++++++++++++++---------
 1 file changed, 22 insertions(+), 9 deletions(-)

diff --git a/stats.py b/stats.py
index c0304fa..b0e576c 100755
--- a/stats.py
+++ b/stats.py
@@ -432,13 +432,17 @@ def get_last_run_timestamp(bq_client: bigquery.Client) -> Optional[datetime]:
     return rows[0].last_run
 
 
-def load_table_schema(bq_client: bigquery.Client, table_id: str) -> dict[str, str]:
-    """Load a mapping of `field` -> `bigquery type` for the given table."""
+def load_table_schema(
+    bq_client: bigquery.Client, table_id: str
+) -> dict[str, bigquery.TableFieldSchema]:
+    """Load a mapping of `field` -> `bigquery schema field` for the given table."""
     table = bq_client.get_table(table_id)
-    return {field.name: field.field_type for field in table.schema}
+    return {field.name: field for field in table.schema}
 
 
-def load_table_schemas(bq_client: bigquery.Client) -> dict[str, dict[str, str]]:
+def load_table_schemas(
+    bq_client: bigquery.Client,
+) -> dict[str, dict[str, bigquery.TableFieldSchema]]:
     """Return a mapping of each table ID to the table field->type schema."""
     return {
         BQ_REVISIONS_TABLE_ID: load_table_schema(bq_client, BQ_REVISIONS_TABLE_ID),
@@ -480,7 +484,7 @@ def merge_into_bigquery(
     bq_client: bigquery.Client,
     table_id: str,
     id_column: str,
-    table_schema_mapping: dict[str, str],
+    table_schema_mapping: dict[str, bigquery.TableFieldSchema],
     row: dict[str, Any],
 ):
     """Use a `MERGE` statement to upsert rows into BigQuery.
@@ -505,8 +509,17 @@ def merge_into_bigquery(
-            "id", table_schema_mapping[id_column], row[id_column]
+            "id", table_schema_mapping[id_column].field_type, row[id_column]
         )
     ] + [
-        bigquery.ScalarQueryParameter(f"param_{key}", table_schema_mapping[key], value)
-        for key, value in row.items()
+        (
+            bigquery.ScalarQueryParameter(
+                f"param_{key}", table_schema_mapping[key].field_type, value
+            )
+            if table_schema_mapping[key].mode != "REPEATED"
+            # Use the `ArrayQueryParameter` for `REPEATED` mode fields.
+            else bigquery.ArrayQueryParameter(
+                f"param_{key}", table_schema_mapping[key].field_type, value
+            )
+        )
+        for key, value in row.items()
     ]
 
     job_config = bigquery.QueryJobConfig(query_parameters=query_parameters)
@@ -517,7 +530,7 @@ def merge_into_bigquery(
 
 def submit_to_bigquery(
     bq_client: bigquery.Client,
-    table_schema_mappings: dict[str, dict[str, str]],
+    table_schema_mappings: dict[str, dict[str, bigquery.TableFieldSchema]],
     table_id: str,
     id_column: str,
     rows: list[dict],