Skip to content

Commit

Permalink
Adds COPY INTO support
Browse files Browse the repository at this point in the history
  • Loading branch information
rad-pat committed Dec 16, 2024
1 parent ef5bd67 commit fee0775
Show file tree
Hide file tree
Showing 5 changed files with 723 additions and 2 deletions.
23 changes: 23 additions & 0 deletions databend_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,26 @@

# Package version as a (major, minor, patch) tuple.
VERSION = (0, 4, 8)
# Dotted string form of VERSION, built by joining its components.
__version__ = ".".join(map(str, VERSION))


from .dml import (
Merge,
WhenMergeUnMatchedClause,
WhenMergeMatchedDeleteClause,
WhenMergeMatchedUpdateClause,
CopyIntoTable,
CopyIntoLocation,
CopyIntoTableOptions,
CopyIntoLocationOptions,
CSVFormat,
TSVFormat,
NDJSONFormat,
ParquetFormat,
ORCFormat,
AmazonS3,
AzureBlobStorage,
GoogleCloudStorage,
FileColumnClause,
StageClause,
Compression,
)
126 changes: 125 additions & 1 deletion databend_sqlalchemy/databend_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import re
import operator
import datetime
from types import NoneType

import sqlalchemy.types as sqltypes
from typing import Any, Dict, Optional, Union
from sqlalchemy import util as sa_util
Expand All @@ -50,7 +52,11 @@
)
from sqlalchemy.engine import ExecutionContext, default
from sqlalchemy.exc import DBAPIError, NoSuchTableError
from .dml import Merge

from .dml import (
Merge, StageClause, _StorageClause, GoogleCloudStorage,
AzureBlobStorage, AmazonS3
)

RESERVED_WORDS = {
'Error', 'EOI', 'Whitespace', 'Comment', 'CommentBlock', 'Ident', 'ColumnPosition', 'LiteralString',
Expand Down Expand Up @@ -490,6 +496,124 @@ def visit_when_merge_unmatched(self, merge_unmatched, **kw):
", ".join(map(lambda e: e._compiler_dispatch(self, **kw), sets_vals)),
)

def visit_copy_into(self, copy_into, **kw):
    """Render a ``COPY INTO <target> FROM <source>`` statement.

    The target and the source are each rendered one of two ways: a
    ``TableClause`` goes through ``self.preparer.format_table``, anything
    else is compiled via its ``_compiler_dispatch``.  A source that is
    neither a table nor a storage/stage clause is wrapped in parentheses.

    Optional ``FILES``, ``PATTERN``, file-format and copy-option clauses
    are appended in that order when present on ``copy_into``.

    :param copy_into: the COPY INTO construct being compiled.
    :param kw: compiler keyword arguments, forwarded to nested dispatches.
    :return: the SQL text of the statement.
    """
    if isinstance(copy_into.target, (TableClause,)):
        target = self.preparer.format_table(copy_into.target)
    else:
        target = copy_into.target._compiler_dispatch(self, **kw)

    if isinstance(copy_into.from_, (TableClause,)):
        source = self.preparer.format_table(copy_into.from_)
    elif isinstance(copy_into.from_, (_StorageClause, StageClause)):
        source = copy_into.from_._compiler_dispatch(self, **kw)
    else:
        source = f"({copy_into.from_._compiler_dispatch(self, **kw)})"

    result = f"COPY INTO {target} FROM {source}"

    # FILES takes a parenthesised list of quoted names and must be
    # separated from the preceding clause by a space.  An empty list is
    # skipped so that no dangling "FILES = ()" is emitted.
    # NOTE(review): names are interpolated as-is, assuming they carry no
    # quotes of their own -- confirm against how callers populate `files`.
    files = getattr(copy_into, "files", None)
    if isinstance(files, list) and files:
        quoted_files = ", ".join(f"'{name}'" for name in files)
        result += f" FILES = ({quoted_files})"

    if getattr(copy_into, "pattern", None):
        result += f" PATTERN = '{copy_into.pattern}'"

    # `is not None` instead of isinstance(..., NoneType): types.NoneType
    # only exists on Python 3.10+.
    if copy_into.file_format is not None:
        result += f" {copy_into.file_format._compiler_dispatch(self, **kw)}\n"
    if copy_into.options is not None:
        result += f" {copy_into.options._compiler_dispatch(self, **kw)}\n"

    return result

def visit_copy_format(self, file_format, **kw):
    """Render the ``FILE_FORMAT`` clause for a COPY statement.

    When the options contain ``format_name`` the clause refers to that
    predefined format only.  Otherwise it lists ``TYPE`` followed by every
    option; options are sorted by name when ``deterministic`` is passed as
    a truthy compiler keyword.

    :param file_format: object exposing ``format_type`` and an ``options``
        mapping.
    :param kw: compiler keyword arguments, forwarded to option values that
        are themselves compilable.
    :return: the rendered clause.
    """
    pairs = list(file_format.options.items())
    if kw.get("deterministic", False):
        pairs.sort(key=operator.itemgetter(0))

    # A named, predefined format takes precedence over inline settings.
    if "format_name" in file_format.options:
        return f"FILE_FORMAT=(format_name = {file_format.options['format_name']})"

    def render(value):
        # Compilable values render themselves; everything else is str()'d.
        if hasattr(value, "_compiler_dispatch"):
            return value._compiler_dispatch(self, **kw)
        return str(value)

    parts = [f"TYPE = {file_format.format_type}"]
    for name, value in pairs:
        parts.append(f"{name} = {render(value)}")
    return f"FILE_FORMAT = ({', '.join(parts)})"

def visit_copy_into_options(self, copy_into_options, **kw):
    """Render copy options as ``NAME = value`` lines.

    Options are emitted in the iteration order of the ``options`` mapping,
    one per line, joined with newlines.

    :param copy_into_options: object exposing an ``options`` mapping.
    :param kw: compiler keyword arguments (unused here).
    :return: the rendered options, or an empty string when there are none.
    """
    rendered = []
    for name, value in copy_into_options.options.items():
        rendered.append(f"{name} = {value}")
    return "\n".join(rendered)

def visit_file_column(self, file_column_clause, **kw):
    """Render a ``SELECT <columns> FROM <source>`` over a file source.

    The source is formatted through the preparer when it is a
    ``TableClause``; otherwise it is compiled, and additionally wrapped in
    parentheses unless it is a storage or stage clause.  ``columns`` may
    be a ready-made string or an iterable of compilable column elements.

    :param file_column_clause: object exposing ``from_`` and ``columns``.
    :param kw: compiler keyword arguments, forwarded to nested dispatches.
    :return: the rendered SELECT text.
    """
    origin = file_column_clause.from_
    if isinstance(origin, (TableClause,)):
        source = self.preparer.format_table(origin)
    else:
        source = origin._compiler_dispatch(self, **kw)
        if not isinstance(origin, (_StorageClause, StageClause)):
            source = f"({source})"

    columns = file_column_clause.columns
    if isinstance(columns, str):
        select_str = columns
    else:
        select_str = ",".join(
            column._compiler_dispatch(self, **kw) for column in columns
        )

    return f"SELECT {select_str} FROM {source}"

def visit_amazon_s3(self, amazon_s3: AmazonS3, **kw):
    """Render an S3 location followed by its ``CONNECTION = (...)`` block.

    The access key id and secret access key are always emitted.  The
    endpoint URL, virtual-host-style flag, master key, region and security
    token are emitted only when their attribute is truthy.  Values are
    interpolated into single-quoted literals without escaping.

    :param amazon_s3: object carrying ``uri`` and the connection settings.
    :param kw: compiler keyword arguments (unused here).
    :return: the rendered location and connection text.
    """
    pieces = [
        f" ACCESS_KEY_ID = '{amazon_s3.access_key_id}' \n",
        f" SECRET_ACCESS_KEY = '{amazon_s3.secret_access_key}'\n",
    ]

    # Optional settings, in emission order: (attribute, line template).
    optional_settings = (
        ("endpoint_url", " ENDPOINT_URL = '{}' \n"),
        ("enable_virtual_host_style", " ENABLE_VIRTUAL_HOST_STYLE = '{}'\n"),
        ("master_key", " MASTER_KEY = '{}'\n"),
        ("region", " REGION = '{}'\n"),
        ("security_token", " SECURITY_TOKEN = '{}'\n"),
    )
    for attribute, template in optional_settings:
        setting = getattr(amazon_s3, attribute)
        if setting:
            pieces.append(template.format(setting))

    connection_params_str = "".join(pieces)
    return f"{amazon_s3.uri} \nCONNECTION = (\n{connection_params_str}\n)"

def visit_azure_blob_storage(self, azure: AzureBlobStorage, **kw):
    """Render an Azure Blob location followed by its ``CONNECTION`` block.

    The endpoint URL is derived from the account name as
    ``https://<account_name>.blob.core.windows.net``.  Values are
    interpolated into single-quoted literals without escaping.

    :param azure: object carrying ``uri``, ``account_name`` and
        ``account_key``.
    :param kw: compiler keyword arguments (unused here).
    :return: the rendered location and connection text.
    """
    location = azure.uri
    account = azure.account_name
    connection_lines = "".join([
        f" ENDPOINT_URL = 'https://{account}.blob.core.windows.net' \n",
        f" ACCOUNT_NAME = '{account}' \n",
        f" ACCOUNT_KEY = '{azure.account_key}'\n",
    ])
    return f"{location} \nCONNECTION = (\n{connection_lines})"

def visit_google_cloud_storage(self, gcs: GoogleCloudStorage, **kw):
    """Render a GCS location followed by its ``CONNECTION`` block.

    The endpoint URL is fixed to ``https://storage.googleapis.com``; the
    credential is interpolated into a single-quoted literal without
    escaping.

    :param gcs: object carrying ``uri`` and ``credentials``.
    :param kw: compiler keyword arguments (unused here).
    :return: the rendered location and connection text.
    """
    location = gcs.uri
    connection_lines = "".join([
        " ENDPOINT_URL = 'https://storage.googleapis.com' \n",
        f" CREDENTIAL = '{gcs.credentials}' \n",
    ])
    return f"{location} \nCONNECTION = (\n{connection_lines})"


class DatabendExecutionContext(default.DefaultExecutionContext):
@sa_util.memoized_property
Expand Down
Loading

0 comments on commit fee0775

Please sign in to comment.