Skip to content

Commit

Permalink
Merge branch 'candidate/rb20240201' into dan/dates
Browse files Browse the repository at this point in the history
  • Loading branch information
ohaibbq authored Apr 7, 2024
2 parents d50a4c0 + 7e669c6 commit d050edd
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 23 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ A list of ZetaSQL ( Google Standard SQL ) specifications and features supported
- [x] UNNEST and STRUCTs
- [ ] Explicit and implicit UNNEST
- [ ] UNNEST and NULLs
- [ ] UNNEST and WITH OFFSET
- [ ] PIVOT operator
- [ ] UNPIVOT operator
- [x] UNNEST and WITH OFFSET
- [x] PIVOT operator
- [x] UNPIVOT operator
- [ ] TABLESAMPLE operator
- [x] JOIN operation
- [x] INNER JOIN
Expand Down
6 changes: 5 additions & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,13 @@ func newZetaSQLiteConn(db *sql.DB, catalog *internal.Catalog) (*ZetaSQLiteConn,
if err != nil {
return nil, fmt.Errorf("failed to get sqlite3 connection: %w", err)
}
analyzer, err := internal.NewAnalyzer(catalog)
if err != nil {
return nil, fmt.Errorf("failed to create analyzer: %w", err)
}
return &ZetaSQLiteConn{
conn: conn,
analyzer: internal.NewAnalyzer(catalog),
analyzer: analyzer,
}, nil
}

Expand Down
22 changes: 17 additions & 5 deletions internal/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@ type Analyzer struct {
opt *zetasql.AnalyzerOptions
}

func NewAnalyzer(catalog *Catalog) *Analyzer {
func NewAnalyzer(catalog *Catalog) (*Analyzer, error) {
opt, err := newAnalyzerOptions()
if err != nil {
return nil, err
}
return &Analyzer{
catalog: catalog,
opt: newAnalyzerOptions(),
opt: opt,
namePath: &NamePath{},
}
}, nil
}

func newAnalyzerOptions() *zetasql.AnalyzerOptions {
func newAnalyzerOptions() (*zetasql.AnalyzerOptions, error) {
langOpt := zetasql.NewLanguageOptions()
langOpt.SetNameResolutionMode(zetasql.NameResolutionDefault)
langOpt.SetProductMode(types.ProductInternal)
Expand Down Expand Up @@ -69,6 +73,8 @@ func newAnalyzerOptions() *zetasql.AnalyzerOptions {
zetasql.FeatureV13ExtendedGeographyParsers,
zetasql.FeatureTemplateFunctions,
zetasql.FeatureV11WithOnSubquery,
zetasql.FeatureV13Pivot,
zetasql.FeatureV13Unpivot,
})
langOpt.SetSupportedStatementKinds([]ast.Kind{
ast.BeginStmt,
Expand All @@ -88,11 +94,17 @@ func newAnalyzerOptions() *zetasql.AnalyzerOptions {
ast.CreateViewStmt,
ast.DropFunctionStmt,
})
// Reserve the QUALIFY keyword so that QUALIFY can be used without a WHERE clause.
// See https://github.com/google/zetasql/issues/124
err := langOpt.EnableReservableKeyword("QUALIFY", true)
if err != nil {
return nil, err
}
opt := zetasql.NewAnalyzerOptions()
opt.SetAllowUndeclaredParameters(true)
opt.SetLanguage(langOpt)
opt.SetParseLocationRecordType(zetasql.ParseLocationRecordFullNodeScope)
return opt
return opt, nil
}

func (a *Analyzer) SetAutoIndexMode(enabled bool) {
Expand Down
66 changes: 53 additions & 13 deletions internal/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,12 @@ func (n *ArrayScanNode) FormatSQL(ctx context.Context) (string, error) {
return "", err
}
colName := uniqueColumnName(ctx, n.node.ElementColumn())
columns := []string{fmt.Sprintf("json_each.value AS `%s`", colName)}

if offsetColumn := n.node.ArrayOffsetColumn(); offsetColumn != nil {
offsetColName := uniqueColumnName(ctx, offsetColumn.Column())
columns = append(columns, fmt.Sprintf("json_each.key AS `%s`", offsetColName))
}
if n.node.InputScan() != nil {
input, err := newNode(n.node.InputScan()).FormatSQL(ctx)
if err != nil {
Expand Down Expand Up @@ -762,15 +768,15 @@ func (n *ArrayScanNode) FormatSQL(ctx context.Context) (string, error) {
}

return fmt.Sprintf(
"SELECT *, json_each.value AS `%s` %s %s",
colName,
"SELECT *, %s %s %s",
strings.Join(columns, ","),
formattedInput,
arrayJoinExpr,
), nil
}
return fmt.Sprintf(
"SELECT json_each.value AS `%s` FROM json_each(zetasqlite_decode_array(%s))",
colName,
"SELECT %s FROM json_each(zetasqlite_decode_array(%s))",
strings.Join(columns, ","),
arrayExpr,
), nil
}
Expand All @@ -794,12 +800,6 @@ func (n *FilterScanNode) FormatSQL(ctx context.Context) (string, error) {
if err != nil {
return "", err
}
nodeMap := nodeMapFromContext(ctx)
for _, node := range nodeMap.FindNodeFromResolvedNode(n.node) {
if _, ok := node.(*parsed_ast.HavingNode); ok {
return fmt.Sprintf("%s HAVING %s", input, filter), nil
}
}
currentQuery := string(removeExpressions.ReplaceAllString(input, ""))

// Qualify the statement if the input is not wrapped in parens
Expand Down Expand Up @@ -974,15 +974,31 @@ func (n *SetOperationScanNode) FormatSQL(ctx context.Context) (string, error) {
case ast.SetOperationTypeExceptDistinct:
opType = "EXCEPT"
default:
opType = "UNKONWN"
opType = "UNKNOWN"
}
var queries []string
for _, item := range n.node.InputItemList() {
var outputColumns []string
for _, outputColumn := range item.OutputColumnList() {
outputColumns = append(outputColumns, fmt.Sprintf("`%s`", uniqueColumnName(ctx, outputColumn)))
}
query, err := newNode(item).FormatSQL(ctx)
if err != nil {
return "", err
}
queries = append(queries, query)

formattedInput, err := formatInput(query)
if err != nil {
return "", err
}

queries = append(
queries,
fmt.Sprintf("SELECT %s %s",
strings.Join(outputColumns, ", "),
formattedInput,
),
)
}
columnMaps := []string{}
if len(n.node.InputItemList()) != 0 {
Expand Down Expand Up @@ -1348,11 +1364,35 @@ func (n *ExplainStmtNode) FormatSQL(ctx context.Context) (string, error) {
return "", nil
}

// FormatSQL formats the outermost query statement, i.e. the statement that
// runs and produces rows of output (such as a SELECT).
//
// The node's OutputColumnList() gives the user-visible column names that
// should be returned. There may be duplicate names, and multiple output
// columns may reference the same column from Query(), so each output column is
// projected explicitly as `<unique column name>` AS `<output name>` on top of
// the formatted inner query.
//
// It returns an empty string when the wrapped node is nil and propagates any
// error produced while formatting the inner query.
//
// https://github.com/google/zetasql/blob/master/docs/resolved_ast.md#ResolvedQueryStmt
func (n *QueryStmtNode) FormatSQL(ctx context.Context) (string, error) {
	if n.node == nil {
		return "", nil
	}

	// The inner query must not be returned directly: doing so would skip the
	// projection below and lose the user-visible output column names.
	input, err := newNode(n.node.Query()).FormatSQL(ctx)
	if err != nil {
		return "", err
	}

	outputColumns := n.node.OutputColumnList()
	columns := make([]string, 0, len(outputColumns))
	for _, outputColumnNode := range outputColumns {
		columns = append(
			columns,
			fmt.Sprintf("`%s` AS `%s`",
				uniqueColumnName(ctx, outputColumnNode.Column()),
				outputColumnNode.Name(),
			),
		)
	}

	return fmt.Sprintf(
		"SELECT %s FROM (%s)",
		strings.Join(columns, ", "),
		input,
	), nil
}

func (n *CreateDatabaseStmtNode) FormatSQL(ctx context.Context) (string, error) {
Expand Down
12 changes: 12 additions & 0 deletions internal/function_bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -3707,6 +3707,18 @@ func bindWindowCountIf() func() *WindowAggregator {
}
}

// bindWindowLogicalAnd returns a factory for the window aggregator that backs
// the "logical_and" window function. Every invocation of the factory wraps a
// fresh WINDOW_LOGICAL_AND via newSingleItemWindowAggregator.
func bindWindowLogicalAnd() func() *WindowAggregator {
	factory := func() *WindowAggregator {
		fn := &WINDOW_LOGICAL_AND{}
		return newSingleItemWindowAggregator(fn)
	}
	return factory
}

// bindWindowLogicalOr returns a factory for the window aggregator that backs
// the "logical_or" window function. Every invocation of the factory wraps a
// fresh WINDOW_LOGICAL_OR via newSingleItemWindowAggregator.
func bindWindowLogicalOr() func() *WindowAggregator {
	factory := func() *WindowAggregator {
		fn := &WINDOW_LOGICAL_OR{}
		return newSingleItemWindowAggregator(fn)
	}
	return factory
}

func bindWindowMax() func() *WindowAggregator {
return func() *WindowAggregator {
return newSingleItemWindowAggregator(&WINDOW_MAX{})
Expand Down
2 changes: 2 additions & 0 deletions internal/function_register.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ var windowFuncs = []*WindowFuncInfo{
{Name: "count", BindFunc: bindWindowCount},
{Name: "count_star", BindFunc: bindWindowCountStar},
{Name: "countif", BindFunc: bindWindowCountIf},
{Name: "logical_and", BindFunc: bindWindowLogicalAnd},
{Name: "logical_or", BindFunc: bindWindowLogicalOr},
{Name: "max", BindFunc: bindWindowMax},
{Name: "min", BindFunc: bindWindowMin},
{Name: "string_agg", BindFunc: bindWindowStringAgg},
Expand Down
46 changes: 46 additions & 0 deletions internal/function_window.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,52 @@ func (f *WINDOW_LAG) Done(agg *WindowFuncAggregatedStatus) (Value, error) {
return agg.Values[len(agg.Values)-f.offset-1], nil
}

// WINDOW_LOGICAL_AND implements the LOGICAL_AND window function. It carries no
// state of its own; everything it needs is read from the aggregated status
// passed to Done.
type WINDOW_LOGICAL_AND struct{}

// Done computes the logical AND over the values reported by
// agg.RelevantValues(). It yields BoolValue(false) as soon as one value
// converts to false and BoolValue(true) otherwise, including when there are no
// values at all. Errors from RelevantValues or from a ToBool conversion are
// returned unchanged.
//
// NOTE(review): BigQuery documents LOGICAL_AND as returning NULL for zero
// input rows or all-NULL input; here an empty set yields true — confirm this
// is intended.
// NOTE(review): presumably RelevantValues never yields nil entries; verify,
// since ToBool is called on every element without a nil check.
func (f *WINDOW_LOGICAL_AND) Done(agg *WindowFuncAggregatedStatus) (Value, error) {
	conds, err := agg.RelevantValues()
	if err != nil {
		return nil, err
	}

	// Stop scanning once a false value has been seen; the result cannot change.
	allTrue := true
	for i := 0; i < len(conds) && allTrue; i++ {
		holds, err := conds[i].ToBool()
		if err != nil {
			return nil, err
		}
		allTrue = holds
	}

	return BoolValue(allTrue), nil
}

// WINDOW_LOGICAL_OR implements the LOGICAL_OR window function. It carries no
// state of its own; everything it needs is read from the aggregated status
// passed to Done.
type WINDOW_LOGICAL_OR struct{}

// Done computes the logical OR over the values reported by
// agg.RelevantValues(). It yields BoolValue(true) as soon as one value
// converts to true and BoolValue(false) otherwise, including when there are no
// values at all. Errors from RelevantValues or from a ToBool conversion are
// returned unchanged.
//
// NOTE(review): BigQuery documents LOGICAL_OR as returning NULL for zero
// input rows or all-NULL input; here an empty set yields false — confirm this
// is intended.
// NOTE(review): presumably RelevantValues never yields nil entries; verify,
// since ToBool is called on every element without a nil check.
func (f *WINDOW_LOGICAL_OR) Done(agg *WindowFuncAggregatedStatus) (Value, error) {
	conds, err := agg.RelevantValues()
	if err != nil {
		return nil, err
	}

	// Stop scanning once a true value has been seen; the result cannot change.
	anyTrue := false
	for i := 0; i < len(conds) && !anyTrue; i++ {
		holds, err := conds[i].ToBool()
		if err != nil {
			return nil, err
		}
		anyTrue = holds
	}

	return BoolValue(anyTrue), nil
}

type WINDOW_PERCENTILE_CONT struct {
percentile Value
}
Expand Down
Loading

0 comments on commit d050edd

Please sign in to comment.