Skip to content

Commit

Permalink
bigquery: use ArrayQueryParameter for REPEATED mode fields
Browse files Browse the repository at this point in the history
  • Loading branch information
cgsheeh committed Oct 30, 2024
1 parent 0598477 commit d0eca3e
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,13 +432,17 @@ def get_last_run_timestamp(bq_client: bigquery.Client) -> Optional[datetime]:
return rows[0].last_run


def load_table_schema(bq_client: bigquery.Client, table_id: str) -> dict[str, str]:
def load_table_schema(
bq_client: bigquery.Client, table_id: str
) -> dict[str, bigquery.TableFieldSchema]:
"""Load a mapping of `field` -> `bigquery type` for the given table."""
table = bq_client.get_table(table_id)
return {field.name: field.field_type for field in table.schema}
return {field.name: field for field in table.schema}


def load_table_schemas(bq_client: bigquery.Client) -> dict[str, dict[str, str]]:
def load_table_schemas(
bq_client: bigquery.Client,
) -> dict[str, dict[str, bigquery.TableFieldSchema]]:
"""Return a mapping of each table ID to the table field->type schema."""
return {
BQ_REVISIONS_TABLE_ID: load_table_schema(bq_client, BQ_REVISIONS_TABLE_ID),
Expand Down Expand Up @@ -480,7 +484,7 @@ def merge_into_bigquery(
bq_client: bigquery.Client,
table_id: str,
id_column: str,
table_schema_mapping: dict[str, str],
table_schema_mapping: dict[str, bigquery.TableFieldSchema],
row: dict[str, Any],
):
"""Use a `MERGE` statement to upsert rows into BigQuery.
Expand All @@ -505,8 +509,17 @@ def merge_into_bigquery(
"id", table_schema_mapping[id_column], row[id_column]
)
] + [
bigquery.ScalarQueryParameter(f"param_{key}", table_schema_mapping[key], value)
for key, value in row.items()
(
bigquery.ScalarQueryParameter(
f"param_{key}", table_schema_mapping[key].field_type, field
)
if table_schema_mapping[key].mode != "REPEATED"
# Use the `ArrayQueryParameter` for `REPEATED` mode fields.
else bigquery.ArrayQueryParameter(
f"param_{key}", table_schema_mapping[key].field_type, field
)
)
for key, field in row.items()
]

job_config = bigquery.QueryJobConfig(query_parameters=query_parameters)
Expand All @@ -517,7 +530,7 @@ def merge_into_bigquery(

def submit_to_bigquery(
bq_client: bigquery.Client,
table_schema_mappings: dict[str, dict[str, str]],
table_schema_mappings: dict[str, dict[str, bigquery.TableFieldSchema]],
table_id: str,
id_column: str,
rows: list[dict],
Expand Down

0 comments on commit d0eca3e

Please sign in to comment.