From 628ed674b40dc80e87b58507d8690b122ebdba16 Mon Sep 17 00:00:00 2001
From: Pat Buxton
Date: Fri, 13 Dec 2024 15:37:36 +0000
Subject: [PATCH] Adds COPY INTO support
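
COPY INTO is Databend's bulk load/unload statement. This patch adds
CopyIntoTable and CopyIntoLocation constructs, file format clauses
(CSV, TSV, NDJSON, Parquet, ORC), storage clauses for Amazon S3, Azure
Blob Storage and Google Cloud Storage, and the matching compiler support
in the dialect.

A minimal usage sketch of the table-load path (bucket, credentials and
table name below are placeholders, not values used elsewhere in this
patch):

    from sqlalchemy import MetaData, Table, Column, Integer
    from databend_sqlalchemy.dml import (
        AmazonS3, CSVFormat, Compression, CopyIntoTable, CopyIntoTableOptions,
    )

    meta = MetaData()
    events = Table("events", meta, Column("id", Integer))

    # Load gzip-compressed CSV files from an S3 prefix into the table,
    # re-copying files even if they were loaded before (FORCE = TRUE).
    load = CopyIntoTable(
        target=events,
        from_=AmazonS3(
            uri="s3://my-bucket/landing/events/",      # placeholder bucket/prefix
            access_key_id="<access-key-id>",           # placeholder credentials
            secret_access_key="<secret-access-key>",
        ),
        file_format=CSVFormat(skip_header=1, compression=Compression.GZIP),
        options=CopyIntoTableOptions(force=True),
    )
    # connection.execute(load) on a Databend connection emits the COPY INTO SQL.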
---
 databend_sqlalchemy/databend_dialect.py | 111 +++++-
 databend_sqlalchemy/dml.py              | 435 ++++++++++++++++++++++++
 tests/test_copy_into.py                 | 139 ++++++++
 3 files changed, 684 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_copy_into.py

diff --git a/databend_sqlalchemy/databend_dialect.py b/databend_sqlalchemy/databend_dialect.py
index 127ff14..f1ecd64 100644
--- a/databend_sqlalchemy/databend_dialect.py
+++ b/databend_sqlalchemy/databend_dialect.py
@@ -28,6 +28,8 @@
 import re
 import operator
 import datetime
+from types import NoneType
+
 import sqlalchemy.types as sqltypes
 from typing import Any, Dict, Optional, Union
 from sqlalchemy import util as sa_util
@@ -50,7 +52,8 @@
 )
 from sqlalchemy.engine import ExecutionContext, default
 from sqlalchemy.exc import DBAPIError, NoSuchTableError
-from .dml import Merge
+from .dml import Merge, StageClause, _StorageClause, CopyIntoTable, CopyIntoLocation, GoogleCloudStorage, \
+    AzureBlobStorage, AmazonS3, FileColumnClause
 
 RESERVED_WORDS = {
     'Error', 'EOI', 'Whitespace', 'Comment', 'CommentBlock', 'Ident', 'ColumnPosition', 'LiteralString',
@@ -490,6 +493,112 @@ def visit_when_merge_unmatched(self, merge_unmatched, **kw):
             ", ".join(map(lambda e: e._compiler_dispatch(self, **kw), sets_vals)),
         )
 
+    def visit_copy_into(self, copy_into, **kw):
+        target = (
+            self.preparer.format_table(copy_into.target)
+            if isinstance(copy_into.target, (TableClause,))
+            else copy_into.target._compiler_dispatch(self, **kw)
+        )
+
+        if isinstance(copy_into.from_, (TableClause,)):
+            source = self.preparer.format_table(copy_into.from_)
+        elif isinstance(copy_into.from_, (_StorageClause, StageClause)):
+            source = copy_into.from_._compiler_dispatch(self, **kw)
+        # elif isinstance(copy_into.from_, (FileColumnClause)):
+        #     source = f"({copy_into.from_._compiler_dispatch(self, **kw)})"
+        else:
+            source = f"({copy_into.from_._compiler_dispatch(self, **kw)})"
+
+        # Build the statement clause by clause; optional clauses are only
+        # emitted when they were actually supplied.
+        result = f"COPY INTO {target}\n FROM {source}\n"
+        if isinstance(copy_into.files, list):
+            quoted_files = ", ".join(f"'{f}'" for f in copy_into.files)
+            result += f" FILES = ({quoted_files})\n"
+        if copy_into.pattern:
+            result += f" PATTERN = '{copy_into.pattern}'\n"
+        if not isinstance(copy_into.file_format, NoneType):
+            result += f" {copy_into.file_format._compiler_dispatch(self, **kw)}\n"
+        if not isinstance(copy_into.options, NoneType) and copy_into.options.options:
+            result += f" {copy_into.options._compiler_dispatch(self, **kw)}\n"
+        return result
+
+    def visit_copy_format(self, file_format, **kw):
+        options_list = list(file_format.options.items())
+        if kw.get("deterministic", False):
+            options_list.sort(key=operator.itemgetter(0))
+        # predefined format name
+        if "format_name" in file_format.options:
+            return f"FILE_FORMAT = (format_name = {file_format.options['format_name']})"
+        # format specifics
+        format_options = [f"TYPE = {file_format.format_type}"]
+        format_options.extend([
+            "{} = {}".format(
+                option,
+                (
+                    value._compiler_dispatch(self, **kw)
+                    if hasattr(value, "_compiler_dispatch")
+                    else str(value)
+                ),
+            )
+            for option, value in options_list
+        ])
+        return f"FILE_FORMAT = ({', '.join(format_options)})"
+
+    def visit_copy_into_options(self, copy_into_options, **kw):
+        options_list = list(copy_into_options.options.items())
+        # if kw.get("deterministic", False):
+        #     options_list.sort(key=operator.itemgetter(0))
+        return "\n".join([
+            f"{k} = {v}"
+            for k, v in options_list
+        ])
+
+    def visit_file_column(self, file_column_clause, **kw):
+        if isinstance(file_column_clause.from_, (TableClause,)):
+            source = self.preparer.format_table(file_column_clause.from_)
+        elif isinstance(file_column_clause.from_, (_StorageClause, StageClause)):
+            source = file_column_clause.from_._compiler_dispatch(self, **kw)
+        else:
+            source = f"({file_column_clause.from_._compiler_dispatch(self, **kw)})"
+        if isinstance(file_column_clause.columns, str):
+            select_str = file_column_clause.columns
+        else:
+            select_str = ",".join([col._compiler_dispatch(self, **kw) for col in file_column_clause.columns])
+        return (
+            f"SELECT {select_str}"
+            f" FROM {source}"
+        )
+
+    def visit_amazon_s3(self, amazon_s3: AmazonS3, **kw):
+        # Only render the connection parameters that were provided.
+        connection_options = [
+            f" ACCESS_KEY_ID = '{amazon_s3.access_key_id}'",
+            f" SECRET_ACCESS_KEY = '{amazon_s3.secret_access_key}'",
+        ]
+        if amazon_s3.endpoint_url:
+            connection_options.insert(0, f" ENDPOINT_URL = '{amazon_s3.endpoint_url}'")
+        if amazon_s3.enable_virtual_host_style:
+            connection_options.append(f" ENABLE_VIRTUAL_HOST_STYLE = '{amazon_s3.enable_virtual_host_style}'")
+        if amazon_s3.master_key:
+            connection_options.append(f" MASTER_KEY = '{amazon_s3.master_key}'")
+        if amazon_s3.region:
+            connection_options.append(f" REGION = '{amazon_s3.region}'")
+        if amazon_s3.security_token:
+            connection_options.append(f" SECURITY_TOKEN = '{amazon_s3.security_token}'")
+        return (
+            f"{amazon_s3.uri} \n"
+            "CONNECTION = (\n"
+            + "\n".join(connection_options)
+            + "\n)"
+        )
+
+    def visit_azure_blob_storage(self, azure: AzureBlobStorage, **kw):
+        return (
+            f"{azure.uri} \n"
+            f"CONNECTION = (\n"
+            f" ENDPOINT_URL = 'https://{azure.account_name}.blob.core.windows.net' \n"
+            f" ACCOUNT_NAME = '{azure.account_name}' \n"
+            f" ACCOUNT_KEY = '{azure.account_key}'\n"
+            f")"
+        )
+
+    def visit_google_cloud_storage(self, gcs: GoogleCloudStorage, **kw):
+        return (
+            f"{gcs.uri} \n"
+            f"CONNECTION = (\n"
+            f" ENDPOINT_URL = 'https://storage.googleapis.com' \n"
+            f" CREDENTIAL = '{gcs.credentials}' \n"
+            f")"
+        )
+
 
 class DatabendExecutionContext(default.DefaultExecutionContext):
     @sa_util.memoized_property
diff --git a/databend_sqlalchemy/dml.py b/databend_sqlalchemy/dml.py
index ab71da7..7290a0f 100644
--- a/databend_sqlalchemy/dml.py
+++ b/databend_sqlalchemy/dml.py
@@ -2,11 +2,15 @@
 #
 # Note: parts of the file come from https://github.com/snowflakedb/snowflake-sqlalchemy
 # licensed under the same Apache 2.0 License
+from enum import Enum
+from types import NoneType
+from urllib.parse import urlparse
 
 from sqlalchemy.sql.selectable import Select, Subquery, TableClause
 from sqlalchemy.sql.dml import UpdateBase
 from sqlalchemy.sql.elements import ClauseElement
 from sqlalchemy.sql.expression import select
+from sqlalchemy.sql.roles import FromClauseRole
 
 
 class _OnMergeBaseClause(ClauseElement):
@@ -98,3 +102,434 @@ def when_not_matched_then_insert(self):
         clause = WhenMergeUnMatchedClause()
         self.clauses.append(clause)
         return clause
+
+
+# class FilesOption:
+#     """
+#     Class to represent FILES option for the snowflake COPY INTO statement
+#     """
+#
+#     def __init__(self, file_names: List[str]):
+#         self.file_names = file_names
+#
+#     def __str__(self):
+#         the_files = ["'" + f.replace("'", "\\'") + "'" for f in self.file_names]
+#         return f"({','.join(the_files)})"
+
+
+class _CopyIntoBase(UpdateBase):
+    __visit_name__ = "copy_into"
+    _bind = None
+
+    def __init__(self, target: ['TableClause', 'StageClause', '_StorageClause'], from_,
+                 file_format: 'CopyFormat' = None,
+                 options: ['CopyIntoLocationOptions', 'CopyIntoTableOptions'] = None):
+        self.target = target
+        self.from_ = from_
+        self.file_format = file_format
+        self.options = options
+
+    def __repr__(self):
+        """
+        repr for debugging / logging purposes only. For compilation logic, see
+        the corresponding visitor in the dialect
+        """
+        val = f"COPY INTO {self.target} FROM {repr(self.from_)}"
+        return val + f" {repr(self.file_format)} ({self.options})"
+
+    def bind(self):
+        return None
+
+
+class CopyIntoLocation(_CopyIntoBase):
+    inherit_cache = False
+
+    def __init__(self, *, target: ['StageClause', '_StorageClause'], from_,
+                 file_format: 'CopyFormat' = None, options: 'CopyIntoLocationOptions' = None):
+        super().__init__(target, from_, file_format, options)
+
+
+class CopyIntoTable(_CopyIntoBase):
+    inherit_cache = False
+
+    def __init__(self, *, target: [TableClause], from_: ['StageClause', '_StorageClause', 'FileColumnClause'],
+                 files: list = None, pattern: str = None, file_format: 'CopyFormat' = None,
+                 options: 'CopyIntoTableOptions' = None):
+        super().__init__(target, from_, file_format, options)
+        self.files = files
+        self.pattern = pattern
+
+
+class _CopyIntoOptions(ClauseElement):
+    __visit_name__ = "copy_into_options"
+
+    def __init__(self):
+        self.options = dict()
+
+    def __repr__(self):
+        return "\n".join([
+            f"{k} = {v}"
+            for k, v in self.options.items()
+        ])
+
+
+class CopyIntoLocationOptions(_CopyIntoOptions):
+    # __visit_name__ = "copy_into_location_options"
+
+    def __init__(self, *, single: bool = None, max_file_size_bytes: int = None, overwrite: bool = None,
+                 include_query_id: bool = None, use_raw_path: bool = None):
+        super().__init__()
+        if not isinstance(single, NoneType):
+            self.options['SINGLE'] = "TRUE" if single else "FALSE"
+        if not isinstance(max_file_size_bytes, NoneType):
+            self.options["MAX_FILE_SIZE"] = max_file_size_bytes
+        if not isinstance(overwrite, NoneType):
+            self.options["OVERWRITE"] = "TRUE" if overwrite else "FALSE"
+        if not isinstance(include_query_id, NoneType):
+            self.options["INCLUDE_QUERY_ID"] = "TRUE" if include_query_id else "FALSE"
+        if not isinstance(use_raw_path, NoneType):
+            self.options["USE_RAW_PATH"] = "TRUE" if use_raw_path else "FALSE"
+
+
+class CopyIntoTableOptions(_CopyIntoOptions):
+    # __visit_name__ = "copy_into_table_options"
+
+    def __init__(self, *, size_limit: int = None, purge: bool = None, force: bool = None,
+                 disable_variant_check: bool = None, on_error: str = None, max_files: int = None,
+                 return_failed_only: bool = None, column_match_mode: str = None):
+        super().__init__()
+        if not isinstance(size_limit, NoneType):
+            self.options['SIZE_LIMIT'] = size_limit
+        if not isinstance(purge, NoneType):
+            self.options["PURGE"] = "TRUE" if purge else "FALSE"
+        if not isinstance(force, NoneType):
+            self.options["FORCE"] = "TRUE" if force else "FALSE"
+        if not isinstance(disable_variant_check, NoneType):
+            self.options["DISABLE_VARIANT_CHECK"] = "TRUE" if disable_variant_check else "FALSE"
+        if not isinstance(on_error, NoneType):
+            self.options["ON_ERROR"] = on_error
+        if not isinstance(max_files, NoneType):
+            self.options["MAX_FILES"] = max_files
+        if not isinstance(return_failed_only, NoneType):
+            self.options["RETURN_FAILED_ONLY"] = "TRUE" if return_failed_only else "FALSE"
+        if not isinstance(column_match_mode, NoneType):
+            self.options["COLUMN_MATCH_MODE"] = column_match_mode
+
+
+class Compression(Enum):
+    NONE = "NONE"
+    AUTO = "AUTO"
+    GZIP = "GZIP"
+    BZ2 = "BZ2"
+    BROTLI = "BROTLI"
+    ZSTD = "ZSTD"
+    DEFLATE = "DEFLATE"
+    RAW_DEFLATE = "RAW_DEFLATE"
+    XZ = "XZ"
+
+
+class CopyFormat(ClauseElement):
+    """
+    Base class for Format specifications inside a COPY INTO statement. May also
+    be used to create a named format.
+    """
+
+    __visit_name__ = "copy_format"
+
+    def __init__(self, format_name=None):
+        self.options = dict()
+        if format_name:
+            self.options["format_name"] = format_name
+
+    def __repr__(self):
+        """
+        repr for debugging / logging purposes only. For compilation logic, see
+        the respective visitor in the dialect
+        """
+        return f"FILE_FORMAT=({self.options})"
+
+
+class CSVFormat(CopyFormat):
+    format_type = "CSV"
+
+    def __init__(self, *,
+                 record_delimiter: str = None,
+                 field_delimiter: str = None,
+                 quote: str = None,
+                 escape: str = None,
+                 skip_header: int = None,
+                 nan_display: str = None,
+                 null_display: str = None,
+                 error_on_column_mismatch: bool = None,
+                 empty_field_as: str = None,
+                 output_header: bool = None,
+                 binary_format: str = None,
+                 compression: Compression = None,
+                 ):
+        super().__init__()
+        if record_delimiter:
+            if len(str(record_delimiter).encode().decode('unicode_escape')) != 1 and record_delimiter != '\r\n':
+                raise TypeError(
+                    'Record Delimiter should be a single character.'
+                )
+            self.options['RECORD_DELIMITER'] = f"'{record_delimiter}'"
+        if field_delimiter:
+            if len(str(field_delimiter).encode().decode('unicode_escape')) != 1:
+                raise TypeError(
+                    'Field Delimiter should be a single character.'
+                )
+            self.options["FIELD_DELIMITER"] = f"'{field_delimiter}'"
+        if quote:
+            if quote not in ['\'', '"', '`']:
+                raise TypeError('Quote character must be one of [\', ", `].')
+            self.options["QUOTE"] = f"'{quote}'"
+        if escape:
+            if escape not in ['\\', '']:
+                raise TypeError('Escape character must be "\\" or "".')
+            self.options["ESCAPE"] = f"'{escape}'"
+        if skip_header:
+            if skip_header < 0:
+                raise TypeError('Skip header must be a positive integer.')
+            self.options["SKIP_HEADER"] = skip_header
+        if nan_display:
+            if nan_display not in ['NULL', 'NaN']:
+                raise TypeError('NaN Display should be "NULL" or "NaN".')
+            self.options["NAN_DISPLAY"] = f"'{nan_display}'"
+        if null_display:
+            self.options["NULL_DISPLAY"] = f"'{null_display}'"
+        if error_on_column_mismatch:
+            self.options["ERROR_ON_COLUMN_MISMATCH"] = str(error_on_column_mismatch).upper()
+        if empty_field_as:
+            if empty_field_as not in ['NULL', 'STRING', 'FIELD_DEFAULT']:
+                raise TypeError('Empty Field As should be "NULL", "STRING" or "FIELD_DEFAULT".')
+            self.options["EMPTY_FIELD_AS"] = f"'{empty_field_as}'"
+        if output_header:
+            self.options["OUTPUT_HEADER"] = str(output_header).upper()
+        if binary_format:
+            if binary_format not in ['HEX', 'BASE64']:
+                raise TypeError('Binary Format should be "HEX" or "BASE64".')
+            self.options["BINARY_FORMAT"] = binary_format
+        if compression:
+            self.options["COMPRESSION"] = compression.value
+
+
+class TSVFormat(CopyFormat):
+    format_type = "TSV"
+
+    def __init__(self, *,
+                 record_delimiter: str = None,
+                 field_delimiter: str = None,
+                 compression: Compression = None,
+                 ):
+        super().__init__()
+        if record_delimiter:
+            if len(str(record_delimiter).encode().decode('unicode_escape')) != 1 and record_delimiter != '\r\n':
+                raise TypeError(
+                    'Record Delimiter should be a single character.'
+                )
+            self.options['RECORD_DELIMITER'] = f"'{record_delimiter}'"
+        if field_delimiter:
+            if len(str(field_delimiter).encode().decode('unicode_escape')) != 1:
+                raise TypeError(
+                    'Field Delimiter should be a single character.'
+                )
+            self.options["FIELD_DELIMITER"] = f"'{field_delimiter}'"
+        if compression:
+            self.options["COMPRESSION"] = compression.value
+
+
+class NDJSONFormat(CopyFormat):
+    format_type = "NDJSON"
+
+    def __init__(self, *,
+                 null_field_as: str = None,
+                 missing_field_as: str = None,
+                 compression: Compression = None,
+                 ):
+        super().__init__()
+        if null_field_as:
+            if null_field_as not in ['NULL', 'FIELD_DEFAULT']:
+                raise TypeError('Null Field As should be "NULL" or "FIELD_DEFAULT".')
+            self.options["NULL_FIELD_AS"] = f"'{null_field_as}'"
+        if missing_field_as:
+            if missing_field_as not in ['ERROR', 'NULL', 'FIELD_DEFAULT', 'TYPE_DEFAULT']:
+                raise TypeError('Missing Field As should be "ERROR", "NULL", "FIELD_DEFAULT" or "TYPE_DEFAULT".')
+            self.options["MISSING_FIELD_AS"] = f"'{missing_field_as}'"
+        if compression:
+            self.options["COMPRESSION"] = compression.value
+
+
+class ParquetFormat(CopyFormat):
+    format_type = "PARQUET"
+
+    def __init__(self, *,
+                 missing_field_as: str = None,
+                 compression: Compression = None,
+                 ):
+        super().__init__()
+        if missing_field_as:
+            if missing_field_as not in ['ERROR', 'FIELD_DEFAULT']:
+                raise TypeError('Missing Field As should be "ERROR" or "FIELD_DEFAULT".')
+            self.options["MISSING_FIELD_AS"] = f"'{missing_field_as}'"
+
+
+class ORCFormat(CopyFormat):
+    format_type = "ORC"
+
+
+class StageClause(ClauseElement, FromClauseRole):
+    """Stage Clause"""
+
+    __visit_name__ = "stage"
+    _hide_froms = ()
+
+    def __init__(self, *, name, path=None):
+        self.name = name
+        self.path = path
+
+    def __repr__(self):
+        return f"@{self.name}/{self.path}"
+
+
+class FileColumnClause(ClauseElement, FromClauseRole):
+    """Clause for selecting file columns from a Stage/Location"""
+
+    __visit_name__ = "file_column"
+
+    def __init__(self, *, columns, from_: ['StageClause', '_StorageClause']):
+        # columns need to be expressions over the column index, e.g. $1 or
+        # IF($1 == 't', True, False), or a string of such expressions that is
+        # used verbatim
+        self.columns = columns
+        self.from_ = from_
+
+    def __repr__(self):
+        return (
+            f"SELECT {self.columns if isinstance(self.columns, str) else ','.join(repr(col) for col in self.columns)}"
+            f" FROM {repr(self.from_)}"
+        )
+
+
+# class CreateFileFormat(DDLElement):
+#     """
+#     Encapsulates a CREATE FILE FORMAT statement; using a format description (as in
+#     a COPY INTO statement) and a format name.
+#     """
+#
+#     __visit_name__ = "create_file_format"
+#
+#     def __init__(self, format_name, formatter, replace_if_exists=False):
+#         super().__init__()
+#         self.format_name = format_name
+#         self.formatter = formatter
+#         self.replace_if_exists = replace_if_exists
+#
+#
+# class CreateStage(DDLElement):
+#     """
+#     Encapsulates a CREATE STAGE statement, using a container (physical base for the
+#     stage) and the actual ExternalStage object.
+#     """
+#
+#     __visit_name__ = "create_stage"
+#
+#     def __init__(self, container, stage, replace_if_exists=False, *, temporary=False):
+#         super().__init__()
+#         self.container = container
+#         self.temporary = temporary
+#         self.stage = stage
+#         self.replace_if_exists = replace_if_exists
+
+
+class _StorageClause(ClauseElement):
+    pass
+
+
+class AmazonS3(_StorageClause):
+    """Amazon S3"""
+
+    __visit_name__ = "amazon_s3"
+
+    def __init__(self, uri: str, access_key_id: str, secret_access_key: str, endpoint_url: str = None,
+                 enable_virtual_host_style: bool = None, master_key: str = None,
+                 region: str = None, security_token: str = None):
+        r = urlparse(uri)
+        if r.scheme != 's3':
+            raise ValueError(f'Invalid S3 URI: {uri}')
+
+        self.uri = uri
+        self.access_key_id = access_key_id
+        self.secret_access_key = secret_access_key
+        self.bucket = r.netloc
+        self.path = r.path
+        # Optional connection parameters are always assigned so that the
+        # compiler and __repr__ can test them safely.
+        self.endpoint_url = endpoint_url
+        self.enable_virtual_host_style = enable_virtual_host_style
+        self.master_key = master_key
+        self.region = region
+        self.security_token = security_token
+
+    def __repr__(self):
+        connection_options = [
+            f" ACCESS_KEY_ID = '{self.access_key_id}'",
+            f" SECRET_ACCESS_KEY = '{self.secret_access_key}'",
+        ]
+        if self.endpoint_url:
+            connection_options.insert(0, f" ENDPOINT_URL = '{self.endpoint_url}'")
+        if self.enable_virtual_host_style:
+            connection_options.append(f" ENABLE_VIRTUAL_HOST_STYLE = '{self.enable_virtual_host_style}'")
+        if self.master_key:
+            connection_options.append(f" MASTER_KEY = '{self.master_key}'")
+        if self.region:
+            connection_options.append(f" REGION = '{self.region}'")
+        if self.security_token:
+            connection_options.append(f" SECURITY_TOKEN = '{self.security_token}'")
+        return (
+            f"{self.uri} \n"
+            "CONNECTION = (\n"
+            + "\n".join(connection_options)
+            + "\n)"
+        )
+
+
+class AzureBlobStorage(_StorageClause):
+    """Microsoft Azure Blob Storage"""
+
+    __visit_name__ = "azure_blob_storage"
+
+    def __init__(self, *, uri: str, account_name: str, account_key: str):
+        r = urlparse(uri)
+        if r.scheme != 'azblob':
+            raise ValueError(f'Invalid Azure URI: {uri}')
+
+        self.uri = uri
+        self.account_name = account_name
+        self.account_key = account_key
+        self.container = r.netloc
+        self.path = r.path
+
+    def __repr__(self):
+        return (
+            f"{self.uri} \n"
+            f"CONNECTION = (\n"
+            f" ENDPOINT_URL = 'https://{self.account_name}.blob.core.windows.net' \n"
+            f" ACCOUNT_NAME = '{self.account_name}' \n"
+            f" ACCOUNT_KEY = '{self.account_key}'\n"
+            f")"
+        )
+
+
+class GoogleCloudStorage(_StorageClause):
+    """Google Cloud Storage"""
+
+    __visit_name__ = "google_cloud_storage"
+
+    def __init__(self, *, uri, credentials):
+        r = urlparse(uri)
+        if r.scheme != 'gcs':
+            raise ValueError(f'Invalid Google Cloud Storage URI: {uri}')
+
+        self.uri = uri
+        self.credentials = credentials
+
+    def __repr__(self):
+        return (
+            f"{self.uri} \n"
+            f"CONNECTION = (\n"
+            f" ENDPOINT_URL = 'https://storage.googleapis.com' \n"
+            f" CREDENTIAL = '{self.credentials}' \n"
+            f")"
+        )
diff --git a/tests/test_copy_into.py b/tests/test_copy_into.py
new file mode 100644
index 0000000..2c09802
--- /dev/null
+++ b/tests/test_copy_into.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python
+
+from sqlalchemy.testing import config, fixture, fixtures, util
+from sqlalchemy.testing.assertions import AssertsCompiledSQL
+from sqlalchemy import Table, Column, Integer, String, func, MetaData, schema, cast, literal_column
+
+from databend_sqlalchemy.dml import (CopyIntoTable, CopyIntoLocation, CopyIntoTableOptions, CSVFormat, ParquetFormat,
+                                     GoogleCloudStorage, Compression, FileColumnClause)
+
+
+class CompileDatabendCopyIntoTableTest(fixtures.TestBase, AssertsCompiledSQL):
+
+    __only_on__ = "databend"
+
+    def test_copy_into_table(self):
+        m = MetaData()
+        tbl = Table(
+            'atable', m, Column("id", Integer),
+            schema="test_schema",
+        )
+
+        copy_into = CopyIntoTable(
+            target=tbl,
+            from_=GoogleCloudStorage(
+                uri='gcs://some-bucket/a/path/to/files',
+                credentials='XYZ',
+            ),
+            # files='',
+            # pattern='',
+            file_format=CSVFormat(
+                record_delimiter='\n',
+                field_delimiter=',',
+                quote='"',
+                # escape='\\',
+                # skip_header=1,
+                # nan_display='',
+                # null_display='',
+                error_on_column_mismatch=False,
+                # empty_field_as='STRING',
+                output_header=True,
+                # binary_format='',
+                compression=Compression.GZIP,
+            ),
+            options=CopyIntoTableOptions(
+                size_limit=None,
+                purge=None,
+                force=None,
+                disable_variant_check=None,
+                on_error=None,
+                max_files=None,
+                return_failed_only=None,
+                column_match_mode=None,
+            )
+        )
+
+        self.assert_compile(
+            copy_into,
+            (
+                "COPY INTO test_schema.atable"
+                " FROM gcs://some-bucket/a/path/to/files "
+                "CONNECTION = ("
+                " ENDPOINT_URL = 'https://storage.googleapis.com' "
+                " CREDENTIAL = 'XYZ' "
+                ")"
+                " FILE_FORMAT = (TYPE = CSV, "
+                "RECORD_DELIMITER = '', FIELD_DELIMITER = ',', QUOTE = '\"', "
+                "OUTPUT_HEADER = TRUE, COMPRESSION = GZIP)"
+            )
+        )
+
+    def test_copy_into_table_sub_select_string_columns(self):
+        m = MetaData()
+        tbl = Table(
+            'atable', m, Column("id", Integer),
+            schema="test_schema",
+        )
+
+        copy_into = CopyIntoTable(
+            target=tbl,
+            from_=FileColumnClause(
+                columns='$1, $2, $3',
+                from_=GoogleCloudStorage(
+                    uri='gcs://some-bucket/a/path/to/files',
+                    credentials='XYZ',
+                )
+            ),
+            file_format=CSVFormat(),
+        )
+
+        self.assert_compile(
+            copy_into,
+            (
+                "COPY INTO test_schema.atable"
+                " FROM (SELECT $1, $2, $3"
+                " FROM gcs://some-bucket/a/path/to/files "
+                "CONNECTION = ("
+                " ENDPOINT_URL = 'https://storage.googleapis.com' "
+                " CREDENTIAL = 'XYZ' "
+                ")"
+                ") FILE_FORMAT = (TYPE = CSV)"
+            )
+        )
+
+    def test_copy_into_table_sub_select_column_clauses(self):
+        m = MetaData()
+        tbl = Table(
+            'atable', m, Column("id", Integer),
+            schema="test_schema",
+        )
+
+        copy_into = CopyIntoTable(
+            target=tbl,
+            from_=FileColumnClause(
+                columns=[func.IF(literal_column("$1") == 'xyz', 'NULL', 'NOTNULL')],
+                # columns='$1, $2, $3',
+                from_=GoogleCloudStorage(
+                    uri='gcs://some-bucket/a/path/to/files',
+                    credentials='XYZ',
+                )
+            ),
+            file_format=CSVFormat(),
+        )
+
+        self.assert_compile(
+            copy_into,
+            (
+                "COPY INTO test_schema.atable"
+                " FROM (SELECT IF($1 = %(1_1)s, %(IF_1)s, %(IF_2)s)"
+                " FROM gcs://some-bucket/a/path/to/files "
+                "CONNECTION = ("
+                " ENDPOINT_URL = 'https://storage.googleapis.com' "
+                " CREDENTIAL = 'XYZ' "
+                ")"
+                ") FILE_FORMAT = (TYPE = CSV)"
+            ),
+            checkparams={
+                "1_1": "xyz",
+                "IF_1": "NULL",
+                "IF_2": "NOTNULL",
+            },
+        )
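
The tests above only exercise CopyIntoTable. A minimal sketch of the unload
direction via CopyIntoLocation follows; bucket, credential values and table
name are placeholders, not values taken from the patch.

    from sqlalchemy import MetaData, Table, Column, Integer, select
    from databend_sqlalchemy.dml import (
        AmazonS3, CopyIntoLocation, CopyIntoLocationOptions, ParquetFormat,
    )

    meta = MetaData()
    events = Table("events", meta, Column("id", Integer))

    # Export the table to an S3 prefix as Parquet, overwriting existing files.
    unload = CopyIntoLocation(
        target=AmazonS3(
            uri="s3://my-bucket/exports/events/",      # placeholder bucket/prefix
            access_key_id="<access-key-id>",           # placeholder credentials
            secret_access_key="<secret-access-key>",
        ),
        from_=select(events),
        file_format=ParquetFormat(),
        options=CopyIntoLocationOptions(overwrite=True),
    )
    # connection.execute(unload) on a Databend connection emits the COPY INTO SQL.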