From 3cadc5b88f4e43e0b7ea187660a53a04aed448b4 Mon Sep 17 00:00:00 2001 From: lichi Date: Tue, 24 Dec 2024 18:43:21 +0800 Subject: [PATCH 1/4] [feature](nereids)support create function command in nereids --- .../org/apache/doris/nereids/DorisParser.g4 | 23 +- .../nereids/parser/LogicalPlanBuilder.java | 114 +++ .../doris/nereids/trees/plans/PlanType.java | 2 + .../plans/commands/CreateFunctionCommand.java | 908 ++++++++++++++++++ .../plans/commands/DropFunctionCommand.java | 108 +++ .../commands/info/FunctionArgsDefInfo.java | 60 ++ .../trees/plans/visitor/CommandVisitor.java | 10 + .../doris/nereids/util/TypeCoercionUtils.java | 242 +++++ 8 files changed, 1455 insertions(+), 12 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java diff --git a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 index 7045733bafdba7..23f08eba90ff47 100644 --- a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 +++ b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 @@ -199,6 +199,15 @@ supportedCreateStatement | CREATE SQL_BLOCK_RULE (IF NOT EXISTS)? name=identifier properties=propertyClause? #createSqlBlockRule | CREATE ENCRYPTKEY (IF NOT EXISTS)? multipartIdentifier AS STRING_LITERAL #createEncryptkey + | CREATE (GLOBAL | SESSION | LOCAL)? + (TABLES | AGGREGATE)? FUNCTION (IF NOT EXISTS)? + functionIdentifier LEFT_PAREN functionArguments? RIGHT_PAREN + RETURNS returnType=dataType (INTERMEDIATE intermediateType=dataType)? + properties=propertyClause? #createUserDefineFunction + | CREATE (GLOBAL | SESSION | LOCAL)? ALIAS FUNCTION (IF NOT EXISTS)? + functionIdentifier LEFT_PAREN functionArguments? RIGHT_PAREN + WITH PARAMETER LEFT_PAREN parameters=identifierSeq? RIGHT_PAREN + AS expression #createAliasFunction ; supportedAlterStatement @@ -239,7 +248,8 @@ supportedDropStatement ((FROM | IN) database=identifier)? properties=propertyClause #dropFile | DROP WORKLOAD POLICY (IF EXISTS)? name=identifierOrText #dropWorkloadPolicy | DROP REPOSITORY name=identifier #dropRepository - + | DROP (GLOBAL | SESSION | LOCAL)? FUNCTION (IF EXISTS)? + functionIdentifier LEFT_PAREN functionArguments? RIGHT_PAREN #dropFunction ; supportedShowStatement @@ -696,8 +706,6 @@ fromRollup unsupportedDropStatement : DROP (DATABASE | SCHEMA) (IF EXISTS)? name=multipartIdentifier FORCE? #dropDatabase - | DROP (GLOBAL | SESSION | LOCAL)? FUNCTION (IF EXISTS)? - functionIdentifier LEFT_PAREN functionArguments? RIGHT_PAREN #dropFunction | DROP TABLE (IF EXISTS)? name=multipartIdentifier FORCE? #dropTable | DROP VIEW (IF EXISTS)? name=multipartIdentifier #dropView | DROP INDEX (IF EXISTS)? name=identifier ON tableName=multipartIdentifier #dropIndex @@ -753,15 +761,6 @@ analyzeProperties unsupportedCreateStatement : CREATE (DATABASE | SCHEMA) (IF NOT EXISTS)? name=multipartIdentifier properties=propertyClause? #createDatabase - | CREATE (GLOBAL | SESSION | LOCAL)? - (TABLES | AGGREGATE)? FUNCTION (IF NOT EXISTS)? - functionIdentifier LEFT_PAREN functionArguments? RIGHT_PAREN - RETURNS returnType=dataType (INTERMEDIATE intermediateType=dataType)? - properties=propertyClause? #createUserDefineFunction - | CREATE (GLOBAL | SESSION | LOCAL)? ALIAS FUNCTION (IF NOT EXISTS)? - functionIdentifier LEFT_PAREN functionArguments? RIGHT_PAREN - WITH PARAMETER LEFT_PAREN parameters=identifierSeq? RIGHT_PAREN - AS expression #createAliasFunction | CREATE USER (IF NOT EXISTS)? grantUserIdentify (SUPERUSER | DEFAULT ROLE role=STRING_LITERAL)? passwordOption (COMMENT STRING_LITERAL)? #createUser diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java index b3a62e604c53c4..0ce8459e8db282 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java @@ -23,6 +23,7 @@ import org.apache.doris.analysis.ColumnPosition; import org.apache.doris.analysis.DbName; import org.apache.doris.analysis.EncryptKeyName; +import org.apache.doris.analysis.FunctionName; import org.apache.doris.analysis.PassVar; import org.apache.doris.analysis.SetType; import org.apache.doris.analysis.StorageBackend; @@ -109,6 +110,7 @@ import org.apache.doris.nereids.DorisParser.ComplexColTypeListContext; import org.apache.doris.nereids.DorisParser.ComplexDataTypeContext; import org.apache.doris.nereids.DorisParser.ConstantContext; +import org.apache.doris.nereids.DorisParser.CreateAliasFunctionContext; import org.apache.doris.nereids.DorisParser.CreateCatalogContext; import org.apache.doris.nereids.DorisParser.CreateEncryptkeyContext; import org.apache.doris.nereids.DorisParser.CreateFileContext; @@ -120,6 +122,7 @@ import org.apache.doris.nereids.DorisParser.CreateSqlBlockRuleContext; import org.apache.doris.nereids.DorisParser.CreateTableContext; import org.apache.doris.nereids.DorisParser.CreateTableLikeContext; +import org.apache.doris.nereids.DorisParser.CreateUserDefineFunctionContext; import org.apache.doris.nereids.DorisParser.CreateViewContext; import org.apache.doris.nereids.DorisParser.CreateWorkloadGroupContext; import org.apache.doris.nereids.DorisParser.CteContext; @@ -133,6 +136,7 @@ import org.apache.doris.nereids.DorisParser.DropConstraintContext; import org.apache.doris.nereids.DorisParser.DropEncryptkeyContext; import org.apache.doris.nereids.DorisParser.DropFileContext; +import org.apache.doris.nereids.DorisParser.DropFunctionContext; import org.apache.doris.nereids.DorisParser.DropIndexClauseContext; import org.apache.doris.nereids.DorisParser.DropMTMVContext; import org.apache.doris.nereids.DorisParser.DropPartitionClauseContext; @@ -154,6 +158,8 @@ import org.apache.doris.nereids.DorisParser.ExportContext; import org.apache.doris.nereids.DorisParser.FixedPartitionDefContext; import org.apache.doris.nereids.DorisParser.FromClauseContext; +import org.apache.doris.nereids.DorisParser.FunctionArgumentContext; +import org.apache.doris.nereids.DorisParser.FunctionArgumentsContext; import org.apache.doris.nereids.DorisParser.GroupingElementContext; import org.apache.doris.nereids.DorisParser.GroupingSetContext; import org.apache.doris.nereids.DorisParser.HavingClauseContext; @@ -509,6 +515,7 @@ import org.apache.doris.nereids.trees.plans.commands.CreateCatalogCommand; import org.apache.doris.nereids.trees.plans.commands.CreateEncryptkeyCommand; import org.apache.doris.nereids.trees.plans.commands.CreateFileCommand; +import org.apache.doris.nereids.trees.plans.commands.CreateFunctionCommand; import org.apache.doris.nereids.trees.plans.commands.CreateJobCommand; import org.apache.doris.nereids.trees.plans.commands.CreateMTMVCommand; import org.apache.doris.nereids.trees.plans.commands.CreatePolicyCommand; @@ -527,6 +534,7 @@ import org.apache.doris.nereids.trees.plans.commands.DropConstraintCommand; import org.apache.doris.nereids.trees.plans.commands.DropEncryptkeyCommand; import org.apache.doris.nereids.trees.plans.commands.DropFileCommand; +import org.apache.doris.nereids.trees.plans.commands.DropFunctionCommand; import org.apache.doris.nereids.trees.plans.commands.DropJobCommand; import org.apache.doris.nereids.trees.plans.commands.DropMTMVCommand; import org.apache.doris.nereids.trees.plans.commands.DropProcedureCommand; @@ -652,6 +660,7 @@ import org.apache.doris.nereids.trees.plans.commands.info.EnableFeatureOp; import org.apache.doris.nereids.trees.plans.commands.info.FixedRangePartition; import org.apache.doris.nereids.trees.plans.commands.info.FuncNameInfo; +import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgsDefInfo; import org.apache.doris.nereids.trees.plans.commands.info.GeneratedColumnDesc; import org.apache.doris.nereids.trees.plans.commands.info.InPartition; import org.apache.doris.nereids.trees.plans.commands.info.IndexDefinition; @@ -4129,6 +4138,111 @@ public LogicalPlan visitCreateTableLike(CreateTableLikeContext ctx) { return new CreateTableLikeCommand(info); } + @Override + public Command visitCreateUserDefineFunction(CreateUserDefineFunctionContext ctx) { + SetType setType; + if (ctx.GLOBAL() != null) { + setType = SetType.GLOBAL; + } else if (ctx.LOCAL() != null || ctx.SESSION() != null) { + setType = SetType.SESSION; + } else { + setType = SetType.DEFAULT; + } + boolean ifNotExists = ctx.EXISTS() != null; + boolean isAggFunction = ctx.AGGREGATE() != null; + boolean isTableFunction = ctx.TABLES() != null; + String functionName = ctx.functionIdentifier().functionNameIdentifier().getText(); + String dbName = ctx.functionIdentifier().dbName != null ? ctx.functionIdentifier().dbName.getText() : null; + FunctionName function = new FunctionName(dbName, functionName); + FunctionArgsDefInfo functionArgsDefInfo; + if (ctx.functionArguments() != null) { + functionArgsDefInfo = visitFunctionArguments(ctx.functionArguments()); + } else { + functionArgsDefInfo = new FunctionArgsDefInfo(new ArrayList<>(), false); + } + DataType returnType = typedVisit(ctx.returnType); + DataType intermediateType = ctx.intermediateType != null ? typedVisit(ctx.intermediateType) : null; + Map properties = ctx.propertyClause() != null + ? Maps.newHashMap(visitPropertyClause(ctx.propertyClause())) + : Maps.newHashMap(); + if (isTableFunction) { + return new CreateFunctionCommand(setType, ifNotExists, function, functionArgsDefInfo, returnType, + intermediateType, properties); + } else { + return new CreateFunctionCommand(setType, ifNotExists, isAggFunction, function, functionArgsDefInfo, + returnType, intermediateType, properties); + } + } + + @Override + public Command visitCreateAliasFunction(CreateAliasFunctionContext ctx) { + SetType setType; + if (ctx.GLOBAL() != null) { + setType = SetType.GLOBAL; + } else if (ctx.LOCAL() != null || ctx.SESSION() != null) { + setType = SetType.SESSION; + } else { + setType = SetType.DEFAULT; + } + boolean ifNotExists = ctx.EXISTS() != null; + String functionName = ctx.functionIdentifier().functionNameIdentifier().getText(); + String dbName = ctx.functionIdentifier().dbName != null ? ctx.functionIdentifier().dbName.getText() : null; + FunctionName function = new FunctionName(dbName, functionName); + FunctionArgsDefInfo functionArgsDefInfo; + if (ctx.functionArguments() != null) { + functionArgsDefInfo = visitFunctionArguments(ctx.functionArguments()); + } else { + functionArgsDefInfo = new FunctionArgsDefInfo(new ArrayList<>(), false); + } + List parameters = ctx.parameters != null ? visitIdentifierSeq(ctx.parameters) : new ArrayList<>(); + Expression originFunction = getExpression(ctx.expression()); + return new CreateFunctionCommand(setType, ifNotExists, function, functionArgsDefInfo, parameters, + originFunction); + } + + @Override + public Command visitDropFunction(DropFunctionContext ctx) { + SetType setType; + if (ctx.GLOBAL() != null) { + setType = SetType.GLOBAL; + } else if (ctx.LOCAL() != null || ctx.SESSION() != null) { + setType = SetType.SESSION; + } else { + setType = SetType.DEFAULT; + } + boolean ifExists = ctx.EXISTS() != null; + String functionName = ctx.functionIdentifier().functionNameIdentifier().getText(); + String dbName = ctx.functionIdentifier().dbName != null ? ctx.functionIdentifier().dbName.getText() : null; + FunctionName function = new FunctionName(dbName, functionName); + FunctionArgsDefInfo functionArgsDefInfo = null; + if (ctx.functionArguments() != null) { + functionArgsDefInfo = visitFunctionArguments(ctx.functionArguments()); + } + return new DropFunctionCommand(setType, ifExists, function, functionArgsDefInfo); + } + + @Override + public FunctionArgsDefInfo visitFunctionArguments(FunctionArgumentsContext ctx) { + boolean isVariadic = false; + List argTypeDefs = new ArrayList<>(4); + for (Object child : ctx.children) { + if (child instanceof FunctionArgumentContext) { + DataType dataType = visitFunctionArgument((FunctionArgumentContext) child); + if (dataType != null) { + argTypeDefs.add(dataType); + } else { + isVariadic = true; + } + } + } + return new FunctionArgsDefInfo(argTypeDefs, isVariadic); + } + + @Override + public DataType visitFunctionArgument(FunctionArgumentContext ctx) { + return ctx.dataType() != null ? typedVisit(ctx.dataType()) : null; + } + @Override public LogicalPlan visitShowAuthors(ShowAuthorsContext ctx) { return new ShowAuthorsCommand(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java index adbf3720c05c2d..fd753018ed45ea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java @@ -146,6 +146,8 @@ public enum PlanType { LOAD_COMMAND, SELECT_INTO_OUTFILE_COMMAND, UPDATE_COMMAND, + CREATE_FUNCTION_COMMAND, + DROP_FUNCTION_COMMAND, CREATE_MTMV_COMMAND, CREATE_JOB_COMMAND, PAUSE_JOB_COMMAND, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java new file mode 100644 index 00000000000000..2c2a803f477fa5 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java @@ -0,0 +1,908 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.plans.commands; + +import org.apache.doris.analysis.Expr; +import org.apache.doris.analysis.FunctionName; +import org.apache.doris.analysis.SetType; +import org.apache.doris.analysis.StmtType; +import org.apache.doris.catalog.AggregateFunction; +import org.apache.doris.catalog.AliasFunction; +import org.apache.doris.catalog.ArrayType; +import org.apache.doris.catalog.Database; +import org.apache.doris.catalog.Env; +import org.apache.doris.catalog.Function; +import org.apache.doris.catalog.Function.NullableMode; +import org.apache.doris.catalog.FunctionUtil; +import org.apache.doris.catalog.MapType; +import org.apache.doris.catalog.ScalarFunction; +import org.apache.doris.catalog.ScalarType; +import org.apache.doris.catalog.StructType; +import org.apache.doris.catalog.Type; +import org.apache.doris.common.AnalysisException; +import org.apache.doris.common.Config; +import org.apache.doris.common.ErrorCode; +import org.apache.doris.common.ErrorReport; +import org.apache.doris.common.FeConstants; +import org.apache.doris.common.util.URI; +import org.apache.doris.common.util.Util; +import org.apache.doris.mysql.privilege.PrivPredicate; +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.glue.translator.ExpressionTranslator; +import org.apache.doris.nereids.glue.translator.PlanTranslatorContext; +import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.plans.PlanType; +import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgsDefInfo; +import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; +import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.VarcharType; +import org.apache.doris.nereids.util.TypeCoercionUtils; +import org.apache.doris.proto.FunctionService; +import org.apache.doris.proto.PFunctionServiceGrpc; +import org.apache.doris.proto.Types; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.StmtExecutor; +import org.apache.doris.thrift.TFunctionBinaryType; + +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSortedMap; +import io.grpc.ManagedChannel; +import io.grpc.netty.NettyChannelBuilder; +import org.apache.commons.codec.binary.Hex; +import org.apache.commons.lang3.StringUtils; + +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.Parameter; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLClassLoader; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * create a alias or user defined function + */ +public class CreateFunctionCommand extends Command implements ForwardWithSync { + @Deprecated + public static final String OBJECT_FILE_KEY = "object_file"; + public static final String FILE_KEY = "file"; + public static final String SYMBOL_KEY = "symbol"; + public static final String PREPARE_SYMBOL_KEY = "prepare_fn"; + public static final String CLOSE_SYMBOL_KEY = "close_fn"; + public static final String MD5_CHECKSUM = "md5"; + public static final String INIT_KEY = "init_fn"; + public static final String UPDATE_KEY = "update_fn"; + public static final String MERGE_KEY = "merge_fn"; + public static final String SERIALIZE_KEY = "serialize_fn"; + public static final String FINALIZE_KEY = "finalize_fn"; + public static final String GET_VALUE_KEY = "get_value_fn"; + public static final String REMOVE_KEY = "remove_fn"; + public static final String BINARY_TYPE = "type"; + public static final String EVAL_METHOD_KEY = "evaluate"; + public static final String CREATE_METHOD_NAME = "create"; + public static final String DESTROY_METHOD_NAME = "destroy"; + public static final String ADD_METHOD_NAME = "add"; + public static final String SERIALIZE_METHOD_NAME = "serialize"; + public static final String MERGE_METHOD_NAME = "merge"; + public static final String GETVALUE_METHOD_NAME = "getValue"; + public static final String STATE_CLASS_NAME = "State"; + // add for java udf check return type nullable mode, always_nullable or always_not_nullable + public static final String IS_RETURN_NULL = "always_nullable"; + // iff is static load, BE will be cache the udf class load, so only need load once + public static final String IS_STATIC_LOAD = "static_load"; + public static final String EXPIRATION_TIME = "expiration_time"; + + // timeout for both connection and read. 10 seconds is long enough. + private static final int HTTP_TIMEOUT_MS = 10000; + private SetType setType = SetType.DEFAULT; + private final boolean ifNotExists; + private final FunctionName functionName; + private final boolean isAggregate; + private final boolean isAlias; + private boolean isTableFunction; + private final FunctionArgsDefInfo argsDef; + private final DataType returnType; + private DataType intermediateType; + private final Map properties; + private final List parameters; + private final Expression originFunction; + private TFunctionBinaryType binaryType = TFunctionBinaryType.JAVA_UDF; + // needed item set after analyzed + private String userFile; + private Function function; + private String checksum = ""; + private boolean isStaticLoad = false; + private long expirationTime = 360; // default 6 hours = 360 minutes + // now set udf default NullableMode is ALWAYS_NULLABLE + // if not, will core dump when input is not null column, but need return null + // like https://github.com/apache/doris/pull/14002/files + private NullableMode returnNullMode = NullableMode.ALWAYS_NULLABLE; + + /** + * CreateFunctionCommand + */ + public CreateFunctionCommand(SetType setType, boolean ifNotExists, boolean isAggregate, FunctionName functionName, + FunctionArgsDefInfo argsDef, + DataType returnType, DataType intermediateType, Map properties) { + super(PlanType.CREATE_FUNCTION_COMMAND); + this.setType = setType; + this.ifNotExists = ifNotExists; + this.functionName = functionName; + this.isAggregate = isAggregate; + this.argsDef = argsDef; + this.returnType = returnType; + this.intermediateType = intermediateType; + if (properties == null) { + this.properties = ImmutableSortedMap.of(); + } else { + this.properties = ImmutableSortedMap.copyOf(properties, String.CASE_INSENSITIVE_ORDER); + } + this.isAlias = false; + this.isTableFunction = false; + this.parameters = ImmutableList.of(); + this.originFunction = null; + } + + public CreateFunctionCommand(SetType setType, boolean ifNotExists, FunctionName functionName, + FunctionArgsDefInfo argsDef, + DataType returnType, DataType intermediateType, Map properties) { + this(setType, ifNotExists, false, functionName, argsDef, returnType, intermediateType, properties); + this.isTableFunction = true; + } + + /** + * CreateFunctionCommand + */ + public CreateFunctionCommand(SetType setType, boolean ifNotExists, FunctionName functionName, + FunctionArgsDefInfo argsDef, + List parameters, Expression originFunction) { + super(PlanType.CREATE_FUNCTION_COMMAND); + this.setType = setType; + this.ifNotExists = ifNotExists; + this.functionName = functionName; + this.isAlias = true; + this.argsDef = argsDef; + if (parameters == null) { + this.parameters = ImmutableList.of(); + } else { + this.parameters = ImmutableList.copyOf(parameters); + } + this.originFunction = originFunction; + this.isAggregate = false; + this.isTableFunction = false; + this.returnType = VarcharType.MAX_VARCHAR_TYPE; + this.properties = ImmutableSortedMap.of(); + } + + @Override + public void run(ConnectContext ctx, StmtExecutor executor) throws Exception { + validate(ctx); + if (SetType.GLOBAL.equals(setType)) { + Env.getCurrentEnv().getGlobalFunctionMgr().addFunction(function, ifNotExists); + } else { + String dbName = functionName.getDb(); + if (dbName == null) { + dbName = ctx.getDatabase(); + } + Database db = Env.getCurrentInternalCatalog().getDbOrDdlException(dbName); + db.addFunction(function, ifNotExists); + if (function.isUDTFunction()) { + // all of the table function in doris will have two function + // one is the noraml, and another is outer, the different of them is deal with + // empty: whether need to insert NULL result value + Function outerFunction = function.clone(); + FunctionName name = outerFunction.getFunctionName(); + name.setFn(name.getFunction() + "_outer"); + db.addFunction(outerFunction, ifNotExists); + } + } + } + + @Override + public R accept(PlanVisitor visitor, C context) { + return visitor.visitCreateFunctionCommand(this, context); + } + + @Override + public StmtType stmtType() { + return StmtType.CREATE; + } + + private void validate(ConnectContext ctx) throws Exception { + // https://github.com/apache/doris/issues/17810 + // this error report in P0 test, so we suspect that it is related to concurrency + // add this change to test it. + if (Config.use_fuzzy_session_variable) { + synchronized (CreateFunctionCommand.class) { + analyzeCommon(ctx); + // check + if (isAggregate) { + analyzeUda(); + } else if (isAlias) { + analyzeAliasFunction(ctx); + } else if (isTableFunction) { + analyzeTableFunction(); + } else { + analyzeUdf(); + } + } + } else { + analyzeCommon(ctx); + // check + if (isAggregate) { + analyzeUda(); + } else if (isAlias) { + analyzeAliasFunction(ctx); + } else if (isTableFunction) { + analyzeTableFunction(); + } else { + analyzeUdf(); + } + } + } + + private void analyzeCommon(ConnectContext ctx) throws AnalysisException { + // check function name + if (functionName.getDb() == null) { + String db = ctx.getDatabase(); + if (Strings.isNullOrEmpty(db) && setType != SetType.GLOBAL) { + ErrorReport.reportAnalysisException(ErrorCode.ERR_NO_DB_ERROR); + } + } + + // check operation privilege + if (!Env.getCurrentEnv().getAccessManager().checkGlobalPriv(ConnectContext.get(), PrivPredicate.ADMIN)) { + ErrorReport.reportAnalysisException(ErrorCode.ERR_SPECIFIC_ACCESS_DENIED_ERROR, "ADMIN"); + } + // check argument + argsDef.validate(); + + // alias function does not need analyze following params + if (isAlias) { + return; + } + TypeCoercionUtils.validateDataType(returnType); + if (intermediateType != null) { + TypeCoercionUtils.validateDataType(intermediateType); + } else { + intermediateType = returnType; + } + + String type = properties.getOrDefault(BINARY_TYPE, "JAVA_UDF"); + binaryType = getFunctionBinaryType(type); + if (binaryType == null) { + throw new AnalysisException("unknown function type"); + } + if (type.equals("NATIVE")) { + throw new AnalysisException("do not support 'NATIVE' udf type after doris version 1.2.0," + + "please use JAVA_UDF or RPC instead"); + } + + userFile = properties.getOrDefault(FILE_KEY, properties.get(OBJECT_FILE_KEY)); + if (!Strings.isNullOrEmpty(userFile) && binaryType != TFunctionBinaryType.RPC) { + try { + computeObjectChecksum(); + } catch (IOException | NoSuchAlgorithmException e) { + throw new AnalysisException("cannot to compute object's checksum. err: " + e.getMessage()); + } + String md5sum = properties.get(MD5_CHECKSUM); + if (md5sum != null && !md5sum.equalsIgnoreCase(checksum)) { + throw new AnalysisException("library's checksum is not equal with input, checksum=" + checksum); + } + } + if (binaryType == TFunctionBinaryType.JAVA_UDF) { + FunctionUtil.checkEnableJavaUdf(); + + // always_nullable the default value is true, equal null means true + Boolean isReturnNull = parseBooleanFromProperties(IS_RETURN_NULL); + if (isReturnNull != null && !isReturnNull) { + returnNullMode = NullableMode.ALWAYS_NOT_NULLABLE; + } + // static_load the default value is false, equal null means false + Boolean staticLoad = parseBooleanFromProperties(IS_STATIC_LOAD); + if (staticLoad != null && staticLoad) { + isStaticLoad = true; + } + String expirationTimeString = properties.get(EXPIRATION_TIME); + if (expirationTimeString != null) { + long timeMinutes = Long.parseLong(expirationTimeString); + if (timeMinutes <= 0) { + throw new AnalysisException("expirationTime should greater than zero: "); + } + this.expirationTime = timeMinutes; + } + } + } + + private Boolean parseBooleanFromProperties(String propertyString) throws AnalysisException { + String valueOfString = properties.get(propertyString); + if (valueOfString == null) { + return null; + } + if (!valueOfString.equalsIgnoreCase("false") && !valueOfString.equalsIgnoreCase("true")) { + throw new AnalysisException(propertyString + " in properties, you should set it false or true"); + } + return Boolean.parseBoolean(valueOfString); + } + + private void computeObjectChecksum() throws IOException, NoSuchAlgorithmException { + if (FeConstants.runningUnitTest) { + // skip checking checksum when running ut + return; + } + + try (InputStream inputStream = Util.getInputStreamFromUrl(userFile, null, HTTP_TIMEOUT_MS, HTTP_TIMEOUT_MS)) { + MessageDigest digest = MessageDigest.getInstance("MD5"); + byte[] buf = new byte[4096]; + int bytesRead = 0; + do { + bytesRead = inputStream.read(buf); + if (bytesRead < 0) { + break; + } + digest.update(buf, 0, bytesRead); + } while (true); + + checksum = Hex.encodeHexString(digest.digest()); + } + } + + private void analyzeTableFunction() throws AnalysisException { + String symbol = properties.get(SYMBOL_KEY); + if (Strings.isNullOrEmpty(symbol)) { + throw new AnalysisException("No 'symbol' in properties"); + } + if (!returnType.isArrayType()) { + throw new AnalysisException("JAVA_UDF OF UDTF return type must be array type"); + } + analyzeJavaUdf(symbol); + URI location; + if (!Strings.isNullOrEmpty(userFile)) { + location = URI.create(userFile); + } else { + location = null; + } + function = ScalarFunction.createUdf(binaryType, + functionName, argsDef.getArgTypes(), + ((ArrayType) (returnType.toCatalogDataType())).getItemType(), argsDef.isVariadic(), + location, symbol, null, null); + function.setChecksum(checksum); + function.setNullableMode(returnNullMode); + function.setUDTFunction(true); + // Todo: maybe in create tables function, need register two function, one is + // normal and one is outer as those have different result when result is NULL. + } + + private void analyzeUda() throws AnalysisException { + AggregateFunction.AggregateFunctionBuilder builder = AggregateFunction.AggregateFunctionBuilder + .createUdfBuilder(); + URI location; + if (!Strings.isNullOrEmpty(userFile)) { + location = URI.create(userFile); + } else { + location = null; + } + builder.name(functionName).argsType(argsDef.getArgTypes()).retType(returnType.toCatalogDataType()) + .hasVarArgs(argsDef.isVariadic()).intermediateType(intermediateType.toCatalogDataType()) + .location(location); + String initFnSymbol = properties.get(INIT_KEY); + if (initFnSymbol == null && !(binaryType == TFunctionBinaryType.JAVA_UDF + || binaryType == TFunctionBinaryType.RPC)) { + throw new AnalysisException("No 'init_fn' in properties"); + } + String updateFnSymbol = properties.get(UPDATE_KEY); + if (updateFnSymbol == null && !(binaryType == TFunctionBinaryType.JAVA_UDF)) { + throw new AnalysisException("No 'update_fn' in properties"); + } + String mergeFnSymbol = properties.get(MERGE_KEY); + if (mergeFnSymbol == null && !(binaryType == TFunctionBinaryType.JAVA_UDF)) { + throw new AnalysisException("No 'merge_fn' in properties"); + } + String serializeFnSymbol = properties.get(SERIALIZE_KEY); + String finalizeFnSymbol = properties.get(FINALIZE_KEY); + String getValueFnSymbol = properties.get(GET_VALUE_KEY); + String removeFnSymbol = properties.get(REMOVE_KEY); + String symbol = properties.get(SYMBOL_KEY); + if (binaryType == TFunctionBinaryType.RPC && !userFile.contains("://")) { + if (initFnSymbol != null) { + checkRPCUdf(initFnSymbol); + } + checkRPCUdf(updateFnSymbol); + checkRPCUdf(mergeFnSymbol); + if (serializeFnSymbol != null) { + checkRPCUdf(serializeFnSymbol); + } + if (finalizeFnSymbol != null) { + checkRPCUdf(finalizeFnSymbol); + } + if (getValueFnSymbol != null) { + checkRPCUdf(getValueFnSymbol); + } + if (removeFnSymbol != null) { + checkRPCUdf(removeFnSymbol); + } + } else if (binaryType == TFunctionBinaryType.JAVA_UDF) { + if (Strings.isNullOrEmpty(symbol)) { + throw new AnalysisException("No 'symbol' in properties of java-udaf"); + } + analyzeJavaUdaf(symbol); + } + function = builder.initFnSymbol(initFnSymbol).updateFnSymbol(updateFnSymbol).mergeFnSymbol(mergeFnSymbol) + .serializeFnSymbol(serializeFnSymbol).finalizeFnSymbol(finalizeFnSymbol) + .getValueFnSymbol(getValueFnSymbol).removeFnSymbol(removeFnSymbol).symbolName(symbol).build(); + function.setLocation(location); + function.setBinaryType(binaryType); + function.setChecksum(checksum); + function.setNullableMode(returnNullMode); + } + + private void analyzeUdf() throws AnalysisException { + String symbol = properties.get(SYMBOL_KEY); + if (Strings.isNullOrEmpty(symbol)) { + throw new AnalysisException("No 'symbol' in properties"); + } + String prepareFnSymbol = properties.get(PREPARE_SYMBOL_KEY); + String closeFnSymbol = properties.get(CLOSE_SYMBOL_KEY); + // TODO(yangzhg) support check function in FE when function service behind load balancer + // the format for load balance can ref https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#connect-to-a-cluster + if (binaryType == TFunctionBinaryType.RPC && !userFile.contains("://")) { + if (StringUtils.isNotBlank(prepareFnSymbol) || StringUtils.isNotBlank(closeFnSymbol)) { + throw new AnalysisException("prepare and close in RPC UDF are not supported."); + } + checkRPCUdf(symbol); + } else if (binaryType == TFunctionBinaryType.JAVA_UDF) { + analyzeJavaUdf(symbol); + } + URI location; + if (!Strings.isNullOrEmpty(userFile)) { + location = URI.create(userFile); + } else { + location = null; + } + function = ScalarFunction.createUdf(binaryType, + functionName, argsDef.getArgTypes(), + returnType.toCatalogDataType(), argsDef.isVariadic(), + location, symbol, prepareFnSymbol, closeFnSymbol); + function.setChecksum(checksum); + function.setNullableMode(returnNullMode); + function.setStaticLoad(isStaticLoad); + function.setExpirationTime(expirationTime); + } + + private void analyzeJavaUdaf(String clazz) throws AnalysisException { + HashMap allMethods = new HashMap<>(); + + try { + if (Strings.isNullOrEmpty(userFile)) { + try { + ClassLoader cl = this.getClass().getClassLoader(); + checkUdafClass(clazz, cl, allMethods); + return; + } catch (ClassNotFoundException e) { + throw new AnalysisException("Class [" + clazz + "] not found in classpath"); + } + } + URL[] urls = { new URL("jar:" + userFile + "!/") }; + try (URLClassLoader cl = URLClassLoader.newInstance(urls)) { + checkUdafClass(clazz, cl, allMethods); + } catch (ClassNotFoundException e) { + throw new AnalysisException( + "Class [" + clazz + "] or inner class [State] not found in file :" + userFile); + } catch (IOException e) { + throw new AnalysisException("Failed to load file: " + userFile); + } + } catch (MalformedURLException e) { + throw new AnalysisException("Failed to load file: " + userFile); + } + } + + private void checkUdafClass(String clazz, ClassLoader cl, HashMap allMethods) + throws ClassNotFoundException, AnalysisException { + Class udfClass = cl.loadClass(clazz); + String udfClassName = udfClass.getCanonicalName(); + String stateClassName = udfClassName + "$" + STATE_CLASS_NAME; + Class stateClass = cl.loadClass(stateClassName); + + for (Method m : udfClass.getMethods()) { + if (!m.getDeclaringClass().equals(udfClass)) { + continue; + } + String name = m.getName(); + if (allMethods.containsKey(name)) { + throw new AnalysisException( + String.format("UDF class '%s' has multiple methods with name '%s' ", udfClassName, + name)); + } + allMethods.put(name, m); + } + + if (allMethods.get(CREATE_METHOD_NAME) == null) { + throw new AnalysisException( + String.format("No method '%s' in class '%s'!", CREATE_METHOD_NAME, udfClassName)); + } else { + checkMethodNonStaticAndPublic(CREATE_METHOD_NAME, allMethods.get(CREATE_METHOD_NAME), udfClassName); + checkArgumentCount(allMethods.get(CREATE_METHOD_NAME), 0, udfClassName); + checkReturnJavaType(udfClassName, allMethods.get(CREATE_METHOD_NAME), stateClass); + } + + if (allMethods.get(DESTROY_METHOD_NAME) == null) { + throw new AnalysisException( + String.format("No method '%s' in class '%s'!", DESTROY_METHOD_NAME, udfClassName)); + } else { + checkMethodNonStaticAndPublic(DESTROY_METHOD_NAME, allMethods.get(DESTROY_METHOD_NAME), + udfClassName); + checkArgumentCount(allMethods.get(DESTROY_METHOD_NAME), 1, udfClassName); + checkReturnJavaType(udfClassName, allMethods.get(DESTROY_METHOD_NAME), void.class); + } + + if (allMethods.get(ADD_METHOD_NAME) == null) { + throw new AnalysisException( + String.format("No method '%s' in class '%s'!", ADD_METHOD_NAME, udfClassName)); + } else { + checkMethodNonStaticAndPublic(ADD_METHOD_NAME, allMethods.get(ADD_METHOD_NAME), udfClassName); + checkArgumentCount(allMethods.get(ADD_METHOD_NAME), argsDef.getArgTypes().length + 1, udfClassName); + checkReturnJavaType(udfClassName, allMethods.get(ADD_METHOD_NAME), void.class); + for (int i = 0; i < argsDef.getArgTypes().length; i++) { + Parameter p = allMethods.get(ADD_METHOD_NAME).getParameters()[i + 1]; + checkUdfType(udfClass, allMethods.get(ADD_METHOD_NAME), argsDef.getArgTypes()[i], p.getType(), + p.getName()); + } + } + + if (allMethods.get(SERIALIZE_METHOD_NAME) == null) { + throw new AnalysisException( + String.format("No method '%s' in class '%s'!", SERIALIZE_METHOD_NAME, udfClassName)); + } else { + checkMethodNonStaticAndPublic(SERIALIZE_METHOD_NAME, allMethods.get(SERIALIZE_METHOD_NAME), + udfClassName); + checkArgumentCount(allMethods.get(SERIALIZE_METHOD_NAME), 2, udfClassName); + checkReturnJavaType(udfClassName, allMethods.get(SERIALIZE_METHOD_NAME), void.class); + } + + if (allMethods.get(MERGE_METHOD_NAME) == null) { + throw new AnalysisException( + String.format("No method '%s' in class '%s'!", MERGE_METHOD_NAME, udfClassName)); + } else { + checkMethodNonStaticAndPublic(MERGE_METHOD_NAME, allMethods.get(MERGE_METHOD_NAME), udfClassName); + checkArgumentCount(allMethods.get(MERGE_METHOD_NAME), 2, udfClassName); + checkReturnJavaType(udfClassName, allMethods.get(MERGE_METHOD_NAME), void.class); + } + + if (allMethods.get(GETVALUE_METHOD_NAME) == null) { + throw new AnalysisException( + String.format("No method '%s' in class '%s'!", GETVALUE_METHOD_NAME, udfClassName)); + } else { + checkMethodNonStaticAndPublic(GETVALUE_METHOD_NAME, allMethods.get(GETVALUE_METHOD_NAME), + udfClassName); + checkArgumentCount(allMethods.get(GETVALUE_METHOD_NAME), 1, udfClassName); + checkReturnUdfType(udfClass, allMethods.get(GETVALUE_METHOD_NAME), returnType.toCatalogDataType()); + } + + if (!Modifier.isPublic(stateClass.getModifiers()) || !Modifier.isStatic(stateClass.getModifiers())) { + throw new AnalysisException( + String.format( + "UDAF '%s' should have one public & static 'State' class to Construction data ", + udfClassName)); + } + } + + private void checkMethodNonStaticAndPublic(String methoName, Method method, String udfClassName) + throws AnalysisException { + if (Modifier.isStatic(method.getModifiers())) { + throw new AnalysisException( + String.format("Method '%s' in class '%s' should be non-static", methoName, udfClassName)); + } + if (!Modifier.isPublic(method.getModifiers())) { + throw new AnalysisException( + String.format("Method '%s' in class '%s' should be public", methoName, udfClassName)); + } + } + + private void checkArgumentCount(Method method, int argumentCount, String udfClassName) throws AnalysisException { + if (method.getParameters().length != argumentCount) { + throw new AnalysisException( + String.format("The number of parameters for method '%s' in class '%s' should be %d", + method.getName(), udfClassName, argumentCount)); + } + } + + private void checkReturnJavaType(String udfClassName, Method method, Class expType) throws AnalysisException { + checkJavaType(udfClassName, method, expType, method.getReturnType(), "return"); + } + + private void checkJavaType(String udfClassName, Method method, Class expType, Class ptype, String pname) + throws AnalysisException { + if (!expType.equals(ptype)) { + throw new AnalysisException( + String.format("UDF class '%s' method '%s' parameter %s[%s] expect type %s", udfClassName, + method.getName(), pname, ptype.getCanonicalName(), expType.getCanonicalName())); + } + } + + private void checkReturnUdfType(Class clazz, Method method, Type expType) throws AnalysisException { + checkUdfType(clazz, method, expType, method.getReturnType(), "return"); + } + + private void analyzeJavaUdf(String clazz) throws AnalysisException { + try { + if (Strings.isNullOrEmpty(userFile)) { + try { + ClassLoader cl = this.getClass().getClassLoader(); + checkUdfClass(clazz, cl); + return; + } catch (ClassNotFoundException e) { + throw new AnalysisException("Class [" + clazz + "] not found in classpath"); + } + } + URL[] urls = { new URL("jar:" + userFile + "!/") }; + try (URLClassLoader cl = URLClassLoader.newInstance(urls)) { + checkUdfClass(clazz, cl); + } catch (ClassNotFoundException e) { + throw new AnalysisException("Class [" + clazz + "] not found in file :" + userFile); + } catch (IOException e) { + throw new AnalysisException("Failed to load file: " + userFile); + } + } catch (MalformedURLException e) { + throw new AnalysisException("Failed to load file: " + userFile); + } + } + + private void checkUdfClass(String clazz, ClassLoader cl) throws ClassNotFoundException, AnalysisException { + Class udfClass = cl.loadClass(clazz); + List evalList = Arrays.stream(udfClass.getMethods()) + .filter(m -> m.getDeclaringClass().equals(udfClass) && EVAL_METHOD_KEY.equals(m.getName())) + .collect(Collectors.toList()); + if (evalList.size() == 0) { + throw new AnalysisException(String.format( + "No method '%s' in class '%s'!", EVAL_METHOD_KEY, udfClass.getCanonicalName())); + } + List evalNonStaticAndPublicList = evalList.stream() + .filter(m -> !Modifier.isStatic(m.getModifiers()) && Modifier.isPublic(m.getModifiers())) + .collect(Collectors.toList()); + if (evalNonStaticAndPublicList.size() == 0) { + throw new AnalysisException( + String.format("Method '%s' in class '%s' should be non-static and public", EVAL_METHOD_KEY, + udfClass.getCanonicalName())); + } + List evalArgLengthMatchList = evalNonStaticAndPublicList.stream().filter( + m -> m.getParameters().length == argsDef.getArgTypes().length).collect(Collectors.toList()); + if (evalArgLengthMatchList.size() == 0) { + throw new AnalysisException( + String.format("The number of parameters for method '%s' in class '%s' should be %d", + EVAL_METHOD_KEY, udfClass.getCanonicalName(), argsDef.getArgTypes().length)); + } else if (evalArgLengthMatchList.size() == 1) { + Method method = evalArgLengthMatchList.get(0); + checkUdfType(udfClass, method, returnType.toCatalogDataType(), method.getReturnType(), "return"); + for (int i = 0; i < method.getParameters().length; i++) { + Parameter p = method.getParameters()[i]; + checkUdfType(udfClass, method, argsDef.getArgTypes()[i], p.getType(), p.getName()); + } + } else { + // If multiple methods have the same parameters, + // the error message returned cannot be as specific as a single method + boolean hasError = false; + for (Method method : evalArgLengthMatchList) { + try { + checkUdfType(udfClass, method, returnType.toCatalogDataType(), method.getReturnType(), "return"); + for (int i = 0; i < method.getParameters().length; i++) { + Parameter p = method.getParameters()[i]; + checkUdfType(udfClass, method, argsDef.getArgTypes()[i], p.getType(), p.getName()); + } + hasError = false; + break; + } catch (AnalysisException e) { + hasError = true; + } + } + if (hasError) { + throw new AnalysisException(String.format( + "Multi methods '%s' in class '%s' and no one passed parameter matching verification", + EVAL_METHOD_KEY, udfClass.getCanonicalName())); + } + } + } + + private void checkUdfType(Class clazz, Method method, Type expType, Class pType, String pname) + throws AnalysisException { + Set javaTypes; + if (expType instanceof ScalarType) { + ScalarType scalarType = (ScalarType) expType; + javaTypes = Type.PrimitiveTypeToJavaClassType.get(scalarType.getPrimitiveType()); + } else if (expType instanceof ArrayType) { + ArrayType arrayType = (ArrayType) expType; + javaTypes = Type.PrimitiveTypeToJavaClassType.get(arrayType.getPrimitiveType()); + } else if (expType instanceof MapType) { + MapType mapType = (MapType) expType; + javaTypes = Type.PrimitiveTypeToJavaClassType.get(mapType.getPrimitiveType()); + } else if (expType instanceof StructType) { + StructType structType = (StructType) expType; + javaTypes = Type.PrimitiveTypeToJavaClassType.get(structType.getPrimitiveType()); + } else { + throw new AnalysisException( + String.format("Method '%s' in class '%s' does not support type '%s'", + method.getName(), clazz.getCanonicalName(), expType)); + } + + if (javaTypes == null) { + throw new AnalysisException( + String.format("Method '%s' in class '%s' does not support type '%s'", + method.getName(), clazz.getCanonicalName(), expType.toString())); + } + if (!javaTypes.contains(pType)) { + throw new AnalysisException( + String.format("UDF class '%s' method '%s' %s[%s] type is not supported!", + clazz.getCanonicalName(), method.getName(), pname, pType.getCanonicalName())); + } + } + + private void checkRPCUdf(String symbol) throws AnalysisException { + // TODO(yangzhg) support check function in FE when function service behind load balancer + // the format for load balance can ref https://github.com/apache/incubator-brpc/blob/master/docs/en/client.md#connect-to-a-cluster + String[] url = userFile.split(":"); + if (url.length != 2) { + throw new AnalysisException("function server address invalid."); + } + String host = url[0]; + int port = Integer.valueOf(url[1]); + ManagedChannel channel = NettyChannelBuilder.forAddress(host, port) + .flowControlWindow(Config.grpc_max_message_size_bytes) + .maxInboundMessageSize(Config.grpc_max_message_size_bytes) + .enableRetry().maxRetryAttempts(3) + .usePlaintext().build(); + PFunctionServiceGrpc.PFunctionServiceBlockingStub stub = PFunctionServiceGrpc.newBlockingStub(channel); + FunctionService.PCheckFunctionRequest.Builder builder = FunctionService.PCheckFunctionRequest.newBuilder(); + builder.getFunctionBuilder().setFunctionName(symbol); + for (Type arg : argsDef.getArgTypes()) { + builder.getFunctionBuilder().addInputs(convertToPParameterType(arg)); + } + builder.getFunctionBuilder().setOutput(convertToPParameterType(returnType.toCatalogDataType())); + FunctionService.PCheckFunctionResponse response = stub.checkFn(builder.build()); + if (response == null || !response.hasStatus()) { + throw new AnalysisException("cannot access function server"); + } + if (response.getStatus().getStatusCode() != 0) { + throw new AnalysisException("check function [" + symbol + "] failed: " + response.getStatus()); + } + } + + private Types.PGenericType convertToPParameterType(Type arg) throws AnalysisException { + Types.PGenericType.Builder typeBuilder = Types.PGenericType.newBuilder(); + switch (arg.getPrimitiveType()) { + case INVALID_TYPE: + typeBuilder.setId(Types.PGenericType.TypeId.UNKNOWN); + break; + case BOOLEAN: + typeBuilder.setId(Types.PGenericType.TypeId.BOOLEAN); + break; + case SMALLINT: + typeBuilder.setId(Types.PGenericType.TypeId.INT16); + break; + case TINYINT: + typeBuilder.setId(Types.PGenericType.TypeId.INT8); + break; + case INT: + typeBuilder.setId(Types.PGenericType.TypeId.INT32); + break; + case BIGINT: + typeBuilder.setId(Types.PGenericType.TypeId.INT64); + break; + case FLOAT: + typeBuilder.setId(Types.PGenericType.TypeId.FLOAT); + break; + case DOUBLE: + typeBuilder.setId(Types.PGenericType.TypeId.DOUBLE); + break; + case CHAR: + case VARCHAR: + typeBuilder.setId(Types.PGenericType.TypeId.STRING); + break; + case HLL: + typeBuilder.setId(Types.PGenericType.TypeId.HLL); + break; + case BITMAP: + typeBuilder.setId(Types.PGenericType.TypeId.BITMAP); + break; + case QUANTILE_STATE: + typeBuilder.setId(Types.PGenericType.TypeId.QUANTILE_STATE); + break; + case AGG_STATE: + typeBuilder.setId(Types.PGenericType.TypeId.AGG_STATE); + break; + case DATE: + typeBuilder.setId(Types.PGenericType.TypeId.DATE); + break; + case DATEV2: + typeBuilder.setId(Types.PGenericType.TypeId.DATEV2); + break; + case DATETIME: + case TIME: + typeBuilder.setId(Types.PGenericType.TypeId.DATETIME); + break; + case DATETIMEV2: + case TIMEV2: + typeBuilder.setId(Types.PGenericType.TypeId.DATETIMEV2); + break; + case DECIMALV2: + case DECIMAL128: + typeBuilder.setId(Types.PGenericType.TypeId.DECIMAL128) + .getDecimalTypeBuilder() + .setPrecision(((ScalarType) arg).getScalarPrecision()) + .setScale(((ScalarType) arg).getScalarScale()); + break; + case DECIMAL32: + typeBuilder.setId(Types.PGenericType.TypeId.DECIMAL32) + .getDecimalTypeBuilder() + .setPrecision(((ScalarType) arg).getScalarPrecision()) + .setScale(((ScalarType) arg).getScalarScale()); + break; + case DECIMAL64: + typeBuilder.setId(Types.PGenericType.TypeId.DECIMAL64) + .getDecimalTypeBuilder() + .setPrecision(((ScalarType) arg).getScalarPrecision()) + .setScale(((ScalarType) arg).getScalarScale()); + break; + case LARGEINT: + typeBuilder.setId(Types.PGenericType.TypeId.INT128); + break; + default: + throw new AnalysisException("type " + arg.getPrimitiveType().toString() + " is not supported"); + } + return typeBuilder.build(); + } + + private TFunctionBinaryType getFunctionBinaryType(String type) { + TFunctionBinaryType binaryType = null; + try { + binaryType = TFunctionBinaryType.valueOf(type); + } catch (IllegalArgumentException e) { + // ignore enum Exception + } + return binaryType; + } + + private void analyzeAliasFunction(ConnectContext ctx) throws AnalysisException { + function = AliasFunction.createFunction(functionName, argsDef.getArgTypes(), + Type.VARCHAR, argsDef.isVariadic(), parameters, translateToLegacyExpr(originFunction, ctx)); + ((AliasFunction) function).analyze(); + } + + /** + * translate to legacy expr, which do not need complex expression and table columns + */ + private Expr translateToLegacyExpr(Expression expression, ConnectContext ctx) { + LogicalEmptyRelation plan = new LogicalEmptyRelation( + ConnectContext.get().getStatementContext().getNextRelationId(), new ArrayList<>()); + CascadesContext cascadesContext = CascadesContext.initContext(ctx.getStatementContext(), plan, + PhysicalProperties.ANY); + PlanTranslatorContext translatorContext = new PlanTranslatorContext(cascadesContext); + return ExpressionTranslator.translate(expression, translatorContext); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java new file mode 100644 index 00000000000000..75dde73e742cc5 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.plans.commands; + +import org.apache.doris.analysis.FunctionName; +import org.apache.doris.analysis.SetType; +import org.apache.doris.analysis.StmtType; +import org.apache.doris.catalog.Database; +import org.apache.doris.catalog.Env; +import org.apache.doris.catalog.FunctionSearchDesc; +import org.apache.doris.nereids.trees.plans.PlanType; +import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgsDefInfo; +import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.StmtExecutor; +import org.apache.doris.system.Backend; +import org.apache.doris.task.AgentBatchTask; +import org.apache.doris.task.AgentTaskExecutor; +import org.apache.doris.task.CleanUDFCacheTask; + +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableMap; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +/** + * drop a alias or user defined function + */ +public class DropFunctionCommand extends Command implements ForwardWithSync { + private static final Logger LOG = LogManager.getLogger(DropFunctionCommand.class); + private SetType setType; + private final boolean ifExists; + private final FunctionName functionName; + private final FunctionArgsDefInfo argsDef; + // set after analyzed + private FunctionSearchDesc function; + + /** + * DropFunctionCommand + */ + public DropFunctionCommand(SetType setType, boolean ifExists, FunctionName functionName, + FunctionArgsDefInfo argsDef) { + super(PlanType.CREATE_FUNCTION_COMMAND); + this.setType = setType; + this.ifExists = ifExists; + this.functionName = functionName; + this.argsDef = argsDef; + } + + @Override + public void run(ConnectContext ctx, StmtExecutor executor) throws Exception { + argsDef.validate(); + function = new FunctionSearchDesc(functionName, argsDef.getArgTypes(), argsDef.isVariadic()); + if (SetType.GLOBAL.equals(setType)) { + Env.getCurrentEnv().getGlobalFunctionMgr().dropFunction(function, ifExists); + } else { + String dbName = functionName.getDb(); + if (dbName == null) { + dbName = ctx.getDatabase(); + } + Database db = Env.getCurrentInternalCatalog().getDbOrDdlException(dbName); + db.dropFunction(function, ifExists); + } + // BE will cache classload, when drop function, BE need clear cache + ImmutableMap backendsInfo = Env.getCurrentSystemInfo().getAllBackendsByAllCluster(); + String functionSignature = getSignatureString(); + AgentBatchTask batchTask = new AgentBatchTask(); + for (Backend backend : backendsInfo.values()) { + CleanUDFCacheTask cleanUDFCacheTask = new CleanUDFCacheTask(backend.getId(), functionSignature); + batchTask.addTask(cleanUDFCacheTask); + LOG.info("clean udf cache in be {}, beId {}", backend.getHost(), backend.getId()); + } + AgentTaskExecutor.submit(batchTask); + + } + + @Override + public R accept(PlanVisitor visitor, C context) { + return visitor.visitDropFunctionCommand(this, context); + } + + @Override + public StmtType stmtType() { + return StmtType.DROP; + } + + private String getSignatureString() { + StringBuilder sb = new StringBuilder(); + sb.append(functionName.getFunction()).append("(").append(Joiner.on(", ").join(argsDef.getArgTypes())); + sb.append(")"); + return sb.toString(); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java new file mode 100644 index 00000000000000..b8f93c0e6ca278 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.plans.commands.info; + +import org.apache.doris.catalog.Type; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.util.TypeCoercionUtils; + +import java.util.List; + +/** + * represent function arguments + */ +public class FunctionArgsDefInfo { + private final List argTypeDefs; + private final boolean isVariadic; + + // set after analyze + private Type[] argTypes; + + public FunctionArgsDefInfo(List argTypeDefs, boolean isVariadic) { + this.argTypeDefs = argTypeDefs; + this.isVariadic = isVariadic; + } + + public Type[] getArgTypes() { + return argTypes; + } + + public boolean isVariadic() { + return isVariadic; + } + + /** + * validate + */ + public void validate() { + argTypes = new Type[argTypeDefs.size()]; + int i = 0; + for (DataType dataType : argTypeDefs) { + TypeCoercionUtils.validateDataType(dataType); + argTypes[i++] = dataType.toCatalogDataType(); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/CommandVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/CommandVisitor.java index 2c926c1de1a9a2..cee1c8476109de 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/CommandVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/visitor/CommandVisitor.java @@ -46,6 +46,7 @@ import org.apache.doris.nereids.trees.plans.commands.CreateCatalogCommand; import org.apache.doris.nereids.trees.plans.commands.CreateEncryptkeyCommand; import org.apache.doris.nereids.trees.plans.commands.CreateFileCommand; +import org.apache.doris.nereids.trees.plans.commands.CreateFunctionCommand; import org.apache.doris.nereids.trees.plans.commands.CreateJobCommand; import org.apache.doris.nereids.trees.plans.commands.CreateMTMVCommand; import org.apache.doris.nereids.trees.plans.commands.CreatePolicyCommand; @@ -63,6 +64,7 @@ import org.apache.doris.nereids.trees.plans.commands.DropConstraintCommand; import org.apache.doris.nereids.trees.plans.commands.DropEncryptkeyCommand; import org.apache.doris.nereids.trees.plans.commands.DropFileCommand; +import org.apache.doris.nereids.trees.plans.commands.DropFunctionCommand; import org.apache.doris.nereids.trees.plans.commands.DropJobCommand; import org.apache.doris.nereids.trees.plans.commands.DropMTMVCommand; import org.apache.doris.nereids.trees.plans.commands.DropProcedureCommand; @@ -218,6 +220,14 @@ default R visitCreateEncryptKeyCommand(CreateEncryptkeyCommand createEncryptKeyC return visitCommand(createEncryptKeyCommand, context); } + default R visitCreateFunctionCommand(CreateFunctionCommand createFunctionCommand, C context) { + return visitCommand(createFunctionCommand, context); + } + + default R visitDropFunctionCommand(DropFunctionCommand dropFunctionCommand, C context) { + return visitCommand(dropFunctionCommand, context); + } + default R visitCreateTableCommand(CreateTableCommand createTableCommand, C context) { return visitCommand(createTableCommand, context); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index 5f8be613552c78..15142a77e80e01 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -112,6 +112,7 @@ import org.apache.doris.nereids.types.coercion.IntegralType; import org.apache.doris.nereids.types.coercion.NumericType; import org.apache.doris.nereids.types.coercion.PrimitiveType; +import org.apache.doris.qe.SessionVariable; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; @@ -125,10 +126,12 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.ListIterator; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; @@ -1847,4 +1850,243 @@ private static boolean supportCompare(DataType dataType) { } return true; } + + public static void validateDataType(DataType dataType) { + validateCatalogDataType(dataType.toCatalogDataType()); + } + + private static void validateCatalogDataType(Type catalogType) { + if (catalogType.exceedsMaxNestingDepth()) { + throw new AnalysisException( + String.format("Type exceeds the maximum nesting depth of %s:\n%s", + Type.MAX_NESTING_DEPTH, catalogType.toSql())); + } + if (!catalogType.isSupported()) { + throw new AnalysisException("Unsupported data type: " + catalogType.toSql()); + } + + if (catalogType.isScalarType()) { + validateScalarType((ScalarType) catalogType); + } else if (catalogType.isComplexType()) { + // now we not support array / map / struct nesting complex type + if (catalogType.isArrayType()) { + Type itemType = ((org.apache.doris.catalog.ArrayType) catalogType).getItemType(); + if (itemType instanceof ScalarType) { + validateNestedType(catalogType, (ScalarType) itemType); + } + } + if (catalogType.isMapType()) { + org.apache.doris.catalog.MapType mt = + (org.apache.doris.catalog.MapType) catalogType; + if (mt.getKeyType() instanceof ScalarType) { + validateNestedType(catalogType, (ScalarType) mt.getKeyType()); + } + if (mt.getValueType() instanceof ScalarType) { + validateNestedType(catalogType, (ScalarType) mt.getValueType()); + } + } + if (catalogType.isStructType()) { + ArrayList fields = + ((org.apache.doris.catalog.StructType) catalogType).getFields(); + Set fieldNames = new HashSet<>(); + for (org.apache.doris.catalog.StructField field : fields) { + Type fieldType = field.getType(); + if (fieldType instanceof ScalarType) { + validateNestedType(catalogType, (ScalarType) fieldType); + if (!fieldNames.add(field.getName())) { + throw new AnalysisException("Duplicate field name " + field.getName() + + " in struct " + catalogType.toSql()); + } + } + } + } + } + } + + private static void validateNestedType(Type parent, Type child) throws AnalysisException { + if (child.isNull()) { + throw new AnalysisException("Unsupported data type: " + child.toSql()); + } + // check whether the sub-type is supported + if (!parent.supportSubType(child)) { + throw new AnalysisException( + parent.getPrimitiveType() + " unsupported sub-type: " + child.toSql()); + } + validateCatalogDataType(child); + } + + private static void validateScalarType(ScalarType scalarType) { + org.apache.doris.catalog.PrimitiveType type = scalarType.getPrimitiveType(); + // When string type length is not assigned, it needs to be assigned to 1. + if (scalarType.getPrimitiveType().isStringType() && !scalarType.isLengthSet()) { + if (scalarType.getPrimitiveType() == org.apache.doris.catalog.PrimitiveType.VARCHAR) { + // always set varchar length MAX_VARCHAR_LENGTH + scalarType.setLength(ScalarType.MAX_VARCHAR_LENGTH); + } else if (scalarType.getPrimitiveType() == org.apache.doris.catalog.PrimitiveType.STRING) { + // always set text length MAX_STRING_LENGTH + scalarType.setLength(ScalarType.MAX_STRING_LENGTH); + } else { + scalarType.setLength(1); + } + } + switch (type) { + case CHAR: + case VARCHAR: { + String name; + int maxLen; + if (type == org.apache.doris.catalog.PrimitiveType.VARCHAR) { + name = "VARCHAR"; + maxLen = ScalarType.MAX_VARCHAR_LENGTH; + } else { + name = "CHAR"; + maxLen = ScalarType.MAX_CHAR_LENGTH; + } + int len = scalarType.getLength(); + // len is decided by child, when it is -1. + + if (len <= 0) { + throw new AnalysisException(name + " size must be > 0: " + len); + } + if (scalarType.getLength() > maxLen) { + throw new AnalysisException(name + " size must be <= " + maxLen + ": " + len); + } + break; + } + case DECIMALV2: { + int precision = scalarType.decimalPrecision(); + int scale = scalarType.decimalScale(); + // precision: [1, 27] + if (precision < 1 || precision > ScalarType.MAX_DECIMALV2_PRECISION) { + throw new AnalysisException("Precision of decimal must between 1 and 27." + + " Precision was set to: " + precision + "."); + } + // scale: [0, 9] + if (scale < 0 || scale > ScalarType.MAX_DECIMALV2_SCALE) { + throw new AnalysisException("Scale of decimal must between 0 and 9." + + " Scale was set to: " + scale + "."); + } + if (precision - scale > ScalarType.MAX_DECIMALV2_PRECISION + - ScalarType.MAX_DECIMALV2_SCALE) { + throw new AnalysisException("Invalid decimal type with precision = " + precision + + ", scale = " + scale); + } + // scale < precision + if (scale > precision) { + throw new AnalysisException("Scale of decimal must be smaller than precision." + + " Scale is " + scale + " and precision is " + precision); + } + break; + } + case DECIMAL32: { + int decimal32Precision = scalarType.decimalPrecision(); + int decimal32Scale = scalarType.decimalScale(); + if (decimal32Precision < 1 + || decimal32Precision > ScalarType.MAX_DECIMAL32_PRECISION) { + throw new AnalysisException("Precision of decimal must between 1 and 9." + + " Precision was set to: " + decimal32Precision + "."); + } + // scale >= 0 + if (decimal32Scale < 0) { + throw new AnalysisException("Scale of decimal must not be less than 0." + + " Scale was set to: " + decimal32Scale + "."); + } + // scale < precision + if (decimal32Scale > decimal32Precision) { + throw new AnalysisException( + "Scale of decimal must be smaller than precision." + " Scale is " + + decimal32Scale + " and precision is " + decimal32Precision); + } + break; + } + case DECIMAL64: { + int decimal64Precision = scalarType.decimalPrecision(); + int decimal64Scale = scalarType.decimalScale(); + if (decimal64Precision < 1 + || decimal64Precision > ScalarType.MAX_DECIMAL64_PRECISION) { + throw new AnalysisException("Precision of decimal64 must between 1 and 18." + + " Precision was set to: " + decimal64Precision + "."); + } + // scale >= 0 + if (decimal64Scale < 0) { + throw new AnalysisException("Scale of decimal must not be less than 0." + + " Scale was set to: " + decimal64Scale + "."); + } + // scale < precision + if (decimal64Scale > decimal64Precision) { + throw new AnalysisException( + "Scale of decimal must be smaller than precision." + " Scale is " + + decimal64Scale + " and precision is " + decimal64Precision); + } + break; + } + case DECIMAL128: { + int decimal128Precision = scalarType.decimalPrecision(); + int decimal128Scale = scalarType.decimalScale(); + if (decimal128Precision < 1 + || decimal128Precision > ScalarType.MAX_DECIMAL128_PRECISION) { + throw new AnalysisException("Precision of decimal128 must between 1 and 38." + + " Precision was set to: " + decimal128Precision + "."); + } + // scale >= 0 + if (decimal128Scale < 0) { + throw new AnalysisException("Scale of decimal must not be less than 0." + + " Scale was set to: " + decimal128Scale + "."); + } + // scale < precision + if (decimal128Scale > decimal128Precision) { + throw new AnalysisException( + "Scale of decimal must be smaller than precision." + " Scale is " + + decimal128Scale + " and precision is " + decimal128Precision); + } + break; + } + case DECIMAL256: { + if (SessionVariable.getEnableDecimal256()) { + int precision = scalarType.decimalPrecision(); + int scale = scalarType.decimalScale(); + if (precision < 1 || precision > ScalarType.MAX_DECIMAL256_PRECISION) { + throw new AnalysisException("Precision of decimal256 must between 1 and 76." + + " Precision was set to: " + precision + "."); + } + // scale >= 0 + if (scale < 0) { + throw new AnalysisException("Scale of decimal must not be less than 0." + + " Scale was set to: " + scale + "."); + } + // scale < precision + if (scale > precision) { + throw new AnalysisException( + "Scale of decimal must be smaller than precision." + " Scale is " + + scale + " and precision is " + precision); + } + break; + } else { + int precision = scalarType.decimalPrecision(); + throw new AnalysisException("Column of type Decimal256 with precision " + + precision + " in not supported."); + } + } + case TIMEV2: + case DATETIMEV2: { + int precision = scalarType.decimalPrecision(); + int scale = scalarType.decimalScale(); + // precision: [1, 27] + if (precision != ScalarType.DATETIME_PRECISION) { + throw new AnalysisException( + "Precision of Datetime/Time must be " + ScalarType.DATETIME_PRECISION + + "." + " Precision was set to: " + precision + "."); + } + // scale: [0, 9] + if (scale < 0 || scale > 6) { + throw new AnalysisException("Scale of Datetime/Time must between 0 and 6." + + " Scale was set to: " + scale + "."); + } + break; + } + case INVALID_TYPE: + throw new AnalysisException("Invalid type."); + default: + break; + } + } } From 654b25fd152bb1cc6d2197c65f225a0b7f67ec49 Mon Sep 17 00:00:00 2001 From: lichi Date: Wed, 25 Dec 2024 18:02:33 +0800 Subject: [PATCH 2/4] fix case --- .../nereids/parser/LogicalPlanBuilder.java | 10 +- .../plans/commands/CreateFunctionCommand.java | 153 +++++++++++++++++- .../plans/commands/DropFunctionCommand.java | 8 + .../commands/info/FunctionArgsDefInfo.java | 4 + .../suites/ddl_p0/test_alias_function.groovy | 4 +- .../sql_functions/test_alias_function.groovy | 6 + 6 files changed, 178 insertions(+), 7 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java index 0ce8459e8db282..dd8f8f8f1ee1a5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java @@ -4161,7 +4161,11 @@ public Command visitCreateUserDefineFunction(CreateUserDefineFunctionContext ctx functionArgsDefInfo = new FunctionArgsDefInfo(new ArrayList<>(), false); } DataType returnType = typedVisit(ctx.returnType); + returnType = returnType.conversion(); DataType intermediateType = ctx.intermediateType != null ? typedVisit(ctx.intermediateType) : null; + if (intermediateType != null) { + intermediateType = intermediateType.conversion(); + } Map properties = ctx.propertyClause() != null ? Maps.newHashMap(visitPropertyClause(ctx.propertyClause())) : Maps.newHashMap(); @@ -4214,9 +4218,11 @@ public Command visitDropFunction(DropFunctionContext ctx) { String functionName = ctx.functionIdentifier().functionNameIdentifier().getText(); String dbName = ctx.functionIdentifier().dbName != null ? ctx.functionIdentifier().dbName.getText() : null; FunctionName function = new FunctionName(dbName, functionName); - FunctionArgsDefInfo functionArgsDefInfo = null; + FunctionArgsDefInfo functionArgsDefInfo; if (ctx.functionArguments() != null) { functionArgsDefInfo = visitFunctionArguments(ctx.functionArguments()); + } else { + functionArgsDefInfo = new FunctionArgsDefInfo(new ArrayList<>(), false); } return new DropFunctionCommand(setType, ifExists, function, functionArgsDefInfo); } @@ -4229,7 +4235,7 @@ public FunctionArgsDefInfo visitFunctionArguments(FunctionArgumentsContext ctx) if (child instanceof FunctionArgumentContext) { DataType dataType = visitFunctionArgument((FunctionArgumentContext) child); if (dataType != null) { - argTypeDefs.add(dataType); + argTypeDefs.add(dataType.conversion()); } else { isVariadic = true; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java index 2c2a803f477fa5..06c6c20661d076 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java @@ -18,8 +18,11 @@ package org.apache.doris.nereids.trees.plans.commands; import org.apache.doris.analysis.Expr; +import org.apache.doris.analysis.FunctionCallExpr; import org.apache.doris.analysis.FunctionName; +import org.apache.doris.analysis.FunctionParams; import org.apache.doris.analysis.SetType; +import org.apache.doris.analysis.SlotRef; import org.apache.doris.analysis.StmtType; import org.apache.doris.catalog.AggregateFunction; import org.apache.doris.catalog.AliasFunction; @@ -43,10 +46,26 @@ import org.apache.doris.common.util.Util; import org.apache.doris.mysql.privilege.PrivPredicate; import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.analyzer.Scope; +import org.apache.doris.nereids.analyzer.UnboundSlot; import org.apache.doris.nereids.glue.translator.ExpressionTranslator; import org.apache.doris.nereids.glue.translator.PlanTranslatorContext; import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.BitAnd; +import org.apache.doris.nereids.trees.expressions.BitNot; +import org.apache.doris.nereids.trees.expressions.BitOr; +import org.apache.doris.nereids.trees.expressions.BitXor; +import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.IntegralDivide; +import org.apache.doris.nereids.trees.expressions.Mod; +import org.apache.doris.nereids.trees.expressions.Multiply; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.Subtract; +import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgsDefInfo; import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; @@ -67,6 +86,7 @@ import io.grpc.ManagedChannel; import io.grpc.netty.NettyChannelBuilder; import org.apache.commons.codec.binary.Hex; +import org.apache.commons.collections.map.CaseInsensitiveMap; import org.apache.commons.lang3.StringUtils; import java.io.IOException; @@ -211,6 +231,7 @@ public void run(ConnectContext ctx, StmtExecutor executor) throws Exception { String dbName = functionName.getDb(); if (dbName == null) { dbName = ctx.getDatabase(); + functionName.setDb(dbName); } Database db = Env.getCurrentInternalCatalog().getDbOrDdlException(dbName); db.addFunction(function, ifNotExists); @@ -891,18 +912,144 @@ private TFunctionBinaryType getFunctionBinaryType(String type) { private void analyzeAliasFunction(ConnectContext ctx) throws AnalysisException { function = AliasFunction.createFunction(functionName, argsDef.getArgTypes(), Type.VARCHAR, argsDef.isVariadic(), parameters, translateToLegacyExpr(originFunction, ctx)); - ((AliasFunction) function).analyze(); } /** * translate to legacy expr, which do not need complex expression and table columns */ - private Expr translateToLegacyExpr(Expression expression, ConnectContext ctx) { + private Expr translateToLegacyExpr(Expression expression, ConnectContext ctx) throws AnalysisException { LogicalEmptyRelation plan = new LogicalEmptyRelation( ConnectContext.get().getStatementContext().getNextRelationId(), new ArrayList<>()); CascadesContext cascadesContext = CascadesContext.initContext(ctx.getStatementContext(), plan, PhysicalProperties.ANY); + Map argTypeMap = new CaseInsensitiveMap(); + List argTypes = argsDef.getArgTypeDefs(); + if (!parameters.isEmpty()) { + if (parameters.size() != argTypes.size()) { + throw new AnalysisException(String.format("arguments' size must be same as parameters' size," + + "arguments : %s, parameters : %s", argTypes.size(), parameters.size())); + } + for (int i = 0; i < parameters.size(); ++i) { + argTypeMap.put(parameters.get(i), argTypes.get(i)); + } + } + ExpressionAnalyzer analyzer = new CustomExpressionAnalyzer(cascadesContext, argTypeMap); + expression = analyzer.analyze(expression); + PlanTranslatorContext translatorContext = new PlanTranslatorContext(cascadesContext); - return ExpressionTranslator.translate(expression, translatorContext); + ExpressionToExpr translator = new ExpressionToExpr(); + return expression.accept(translator, translatorContext); + } + + private static class CustomExpressionAnalyzer extends ExpressionAnalyzer { + private Map argTypeMap; + + public CustomExpressionAnalyzer(CascadesContext cascadesContext, Map argTypeMap) { + super(null, new Scope(ImmutableList.of()), cascadesContext, false, false); + this.argTypeMap = argTypeMap; + } + + @Override + public Expression visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteContext context) { + DataType dataType = argTypeMap.get(unboundSlot.getName()); + if (dataType == null) { + throw new org.apache.doris.nereids.exceptions.AnalysisException( + String.format("param %s's datatype is missed", unboundSlot.getName())); + } + return new SlotReference(unboundSlot.getName(), dataType); + } + } + + private static class ExpressionToExpr extends ExpressionTranslator { + @Override + public Expr visitSlotReference(SlotReference slotReference, PlanTranslatorContext context) { + SlotRef slotRef = new SlotRef(slotReference.getDataType().toCatalogDataType(), slotReference.nullable()); + slotRef.setLabel(slotReference.getName()); + slotRef.setCol(slotReference.getName()); + slotRef.setDisableTableName(true); + return slotRef; + } + + @Override + public Expr visitBoundFunction(BoundFunction function, PlanTranslatorContext context) { + return makeFunctionCallExpr(function, function.getName(), function.hasVarArguments(), context); + } + + @Override + public Expr visitAdd(Add add, PlanTranslatorContext context) { + return makeFunctionCallExpr(add, "add", false, context); + } + + @Override + public Expr visitSubtract(Subtract subtract, PlanTranslatorContext context) { + return makeFunctionCallExpr(subtract, "subtract", false, context); + } + + @Override + public Expr visitMultiply(Multiply multiply, PlanTranslatorContext context) { + return makeFunctionCallExpr(multiply, "multiply", false, context); + } + + @Override + public Expr visitDivide(Divide divide, PlanTranslatorContext context) { + return makeFunctionCallExpr(divide, "divide", false, context); + } + + @Override + public Expr visitIntegralDivide(IntegralDivide integralDivide, PlanTranslatorContext context) { + return makeFunctionCallExpr(integralDivide, "integralDivide", false, context); + } + + @Override + public Expr visitMod(Mod mod, PlanTranslatorContext context) { + return makeFunctionCallExpr(mod, "mod", false, context); + } + + @Override + public Expr visitBitAnd(BitAnd bitAnd, PlanTranslatorContext context) { + return makeFunctionCallExpr(bitAnd, "bitAnd", false, context); + } + + @Override + public Expr visitBitOr(BitOr bitOr, PlanTranslatorContext context) { + return makeFunctionCallExpr(bitOr, "bitOr", false, context); + } + + @Override + public Expr visitBitXor(BitXor bitXor, PlanTranslatorContext context) { + return makeFunctionCallExpr(bitXor, "bitXor", false, context); + } + + @Override + public Expr visitBitNot(BitNot bitNot, PlanTranslatorContext context) { + return makeFunctionCallExpr(bitNot, "bitNot", false, context); + } + + private Expr makeFunctionCallExpr(Expression expression, String name, boolean hasVarArguments, + PlanTranslatorContext context) { + List arguments = expression.getArguments().stream() + .map(arg -> arg.accept(this, context)) + .collect(Collectors.toList()); + + List argTypes = expression.getArguments().stream() + .map(Expression::getDataType) + .map(DataType::toCatalogDataType) + .collect(Collectors.toList()); + + NullableMode nullableMode = expression.nullable() + ? NullableMode.ALWAYS_NULLABLE + : NullableMode.ALWAYS_NOT_NULLABLE; + + org.apache.doris.catalog.ScalarFunction catalogFunction = new org.apache.doris.catalog.ScalarFunction( + new FunctionName(name), argTypes, + expression.getDataType().toCatalogDataType(), hasVarArguments, + "", TFunctionBinaryType.BUILTIN, true, true, nullableMode); + + FunctionCallExpr functionCallExpr; + // create catalog FunctionCallExpr without analyze again + functionCallExpr = new FunctionCallExpr(catalogFunction, new FunctionParams(false, arguments)); + functionCallExpr.setNullableFromNereids(expression.nullable()); + return functionCallExpr; + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java index 75dde73e742cc5..004930193e9fd7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java @@ -23,6 +23,9 @@ import org.apache.doris.catalog.Database; import org.apache.doris.catalog.Env; import org.apache.doris.catalog.FunctionSearchDesc; +import org.apache.doris.common.ErrorCode; +import org.apache.doris.common.ErrorReport; +import org.apache.doris.mysql.privilege.PrivPredicate; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgsDefInfo; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; @@ -64,6 +67,10 @@ public DropFunctionCommand(SetType setType, boolean ifExists, FunctionName funct @Override public void run(ConnectContext ctx, StmtExecutor executor) throws Exception { + // check operation privilege + if (!Env.getCurrentEnv().getAccessManager().checkGlobalPriv(ConnectContext.get(), PrivPredicate.ADMIN)) { + ErrorReport.reportAnalysisException(ErrorCode.ERR_SPECIFIC_ACCESS_DENIED_ERROR, "ADMIN"); + } argsDef.validate(); function = new FunctionSearchDesc(functionName, argsDef.getArgTypes(), argsDef.isVariadic()); if (SetType.GLOBAL.equals(setType)) { @@ -72,6 +79,7 @@ public void run(ConnectContext ctx, StmtExecutor executor) throws Exception { String dbName = functionName.getDb(); if (dbName == null) { dbName = ctx.getDatabase(); + functionName.setDb(dbName); } Database db = Env.getCurrentInternalCatalog().getDbOrDdlException(dbName); db.dropFunction(function, ifExists); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java index b8f93c0e6ca278..e3da48bfb7cbed 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java @@ -42,6 +42,10 @@ public Type[] getArgTypes() { return argTypes; } + public List getArgTypeDefs() { + return argTypeDefs; + } + public boolean isVariadic() { return isVariadic; } diff --git a/regression-test/suites/ddl_p0/test_alias_function.groovy b/regression-test/suites/ddl_p0/test_alias_function.groovy index fa2ced713f8146..7793de925531fb 100644 --- a/regression-test/suites/ddl_p0/test_alias_function.groovy +++ b/regression-test/suites/ddl_p0/test_alias_function.groovy @@ -18,10 +18,10 @@ suite("test_alias_function") { sql """DROP FUNCTION IF EXISTS mesh_udf_test1(INT,INT)""" - sql """CREATE ALIAS FUNCTION IF NOT EXISTS mesh_udf_test1(INT,INT) WITH PARAMETER(n,d) AS ROUND(1+floor(n/d));""" + sql """CREATE ALIAS FUNCTION mesh_udf_test1(INT,INT) WITH PARAMETER(n,d) AS ROUND(1+floor(n/d));""" qt_sql1 """select mesh_udf_test1(1,2);""" sql """DROP FUNCTION IF EXISTS mesh_udf_test2(INT,INT)""" - sql """CREATE ALIAS FUNCTION IF NOT EXISTS mesh_udf_test2(INT,INT) WITH PARAMETER(n,d) AS add(1,floor(divide(n,d)))""" + sql """CREATE ALIAS FUNCTION mesh_udf_test2(INT,INT) WITH PARAMETER(n,d) AS add(1,floor(divide(n,d)))""" qt_sql1 """select mesh_udf_test2(1,2);""" } diff --git a/regression-test/suites/query_p0/sql_functions/test_alias_function.groovy b/regression-test/suites/query_p0/sql_functions/test_alias_function.groovy index 095ec89e220f1b..8b281d6faa0521 100644 --- a/regression-test/suites/query_p0/sql_functions/test_alias_function.groovy +++ b/regression-test/suites/query_p0/sql_functions/test_alias_function.groovy @@ -16,6 +16,12 @@ // under the License. suite('test_alias_function', "arrow_flight_sql") { + sql ''' + DROP FUNCTION IF EXISTS f1() + ''' + sql ''' + DROP FUNCTION IF EXISTS f2() + ''' sql ''' CREATE ALIAS FUNCTION IF NOT EXISTS f1(DATETIMEV2(3), INT) with PARAMETER (datetime1, int1) as date_trunc(days_sub(datetime1, int1), 'day')''' From 0a53743b71818187123fb83aeff3bbf8e140e279 Mon Sep 17 00:00:00 2001 From: lichi Date: Thu, 2 Jan 2025 19:03:46 +0800 Subject: [PATCH 3/4] update code --- .../org/apache/doris/nereids/DorisParser.g4 | 24 +- .../nereids/parser/LogicalPlanBuilder.java | 135 ++++------ .../plans/commands/CreateFunctionCommand.java | 77 ++---- .../plans/commands/DropFunctionCommand.java | 6 +- .../plans/commands/info/ColumnDefinition.java | 242 +---------------- ...DefInfo.java => FunctionArgTypesInfo.java} | 12 +- .../apache/doris/nereids/types/DataType.java | 243 ++++++++++++++++++ .../doris/nereids/util/TypeCoercionUtils.java | 242 ----------------- 8 files changed, 344 insertions(+), 637 deletions(-) rename fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/{FunctionArgsDefInfo.java => FunctionArgTypesInfo.java} (81%) diff --git a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 index 23f08eba90ff47..4a4855ecd7f311 100644 --- a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 +++ b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 @@ -199,12 +199,12 @@ supportedCreateStatement | CREATE SQL_BLOCK_RULE (IF NOT EXISTS)? name=identifier properties=propertyClause? #createSqlBlockRule | CREATE ENCRYPTKEY (IF NOT EXISTS)? multipartIdentifier AS STRING_LITERAL #createEncryptkey - | CREATE (GLOBAL | SESSION | LOCAL)? + | CREATE statementScope? (TABLES | AGGREGATE)? FUNCTION (IF NOT EXISTS)? functionIdentifier LEFT_PAREN functionArguments? RIGHT_PAREN RETURNS returnType=dataType (INTERMEDIATE intermediateType=dataType)? properties=propertyClause? #createUserDefineFunction - | CREATE (GLOBAL | SESSION | LOCAL)? ALIAS FUNCTION (IF NOT EXISTS)? + | CREATE statementScope? ALIAS FUNCTION (IF NOT EXISTS)? functionIdentifier LEFT_PAREN functionArguments? RIGHT_PAREN WITH PARAMETER LEFT_PAREN parameters=identifierSeq? RIGHT_PAREN AS expression #createAliasFunction @@ -248,12 +248,12 @@ supportedDropStatement ((FROM | IN) database=identifier)? properties=propertyClause #dropFile | DROP WORKLOAD POLICY (IF EXISTS)? name=identifierOrText #dropWorkloadPolicy | DROP REPOSITORY name=identifier #dropRepository - | DROP (GLOBAL | SESSION | LOCAL)? FUNCTION (IF EXISTS)? + | DROP statementScope? FUNCTION (IF EXISTS)? functionIdentifier LEFT_PAREN functionArguments? RIGHT_PAREN #dropFunction ; supportedShowStatement - : SHOW (GLOBAL | SESSION | LOCAL)? VARIABLES wildWhere? #showVariables + : SHOW statementScope? VARIABLES wildWhere? #showVariables | SHOW AUTHORS #showAuthors | SHOW CREATE (DATABASE | SCHEMA) name=multipartIdentifier #showCreateDatabase | SHOW BROKER #showBroker @@ -304,7 +304,7 @@ supportedShowStatement | SHOW DATABASE databaseId=INTEGER_VALUE #showDatabaseId | SHOW TABLE tableId=INTEGER_VALUE #showTableId | SHOW TRASH (ON backend=STRING_LITERAL)? #showTrash - | SHOW (GLOBAL | SESSION | LOCAL)? STATUS #showStatus + | SHOW statementScope? STATUS #showStatus | SHOW WHITELIST #showWhitelist | SHOW TABLETS BELONG tabletIds+=INTEGER_VALUE (COMMA tabletIds+=INTEGER_VALUE)* #showTabletsBelong @@ -359,7 +359,7 @@ unsupportedShowStatement | SHOW FULL? TABLES ((FROM | IN) database=multipartIdentifier)? wildWhere? #showTables | SHOW FULL? VIEWS ((FROM | IN) database=multipartIdentifier)? wildWhere? #showViews | SHOW CREATE MATERIALIZED VIEW name=multipartIdentifier #showMaterializedView - | SHOW CREATE (GLOBAL | SESSION | LOCAL)? FUNCTION functionIdentifier + | SHOW CREATE statementScope? FUNCTION functionIdentifier LEFT_PAREN functionArguments? RIGHT_PAREN ((FROM | IN) database=multipartIdentifier)? #showCreateFunction | SHOW (DATABASES | SCHEMAS) (FROM catalog=identifier)? wildWhere? #showDatabases @@ -831,7 +831,7 @@ supportedSetStatement (COMMA (optionWithType | optionWithoutType))* #setOptions | SET identifier AS DEFAULT STORAGE VAULT #setDefaultStorageVault | SET PROPERTY (FOR user=identifierOrText)? propertyItemList #setUserProperties - | SET (GLOBAL | LOCAL | SESSION)? TRANSACTION + | SET statementScope? TRANSACTION ( transactionAccessMode | isolationLevel | transactionAccessMode COMMA isolationLevel @@ -839,7 +839,7 @@ supportedSetStatement ; optionWithType - : (GLOBAL | LOCAL | SESSION) identifier EQ (expression | DEFAULT) #setVariableWithType + : statementScope identifier EQ (expression | DEFAULT) #setVariableWithType ; optionWithoutType @@ -855,7 +855,7 @@ optionWithoutType ; variable - : (DOUBLEATSIGN ((GLOBAL | LOCAL | SESSION) DOT)?)? identifier EQ (expression | DEFAULT) #setSystemVariable + : (DOUBLEATSIGN (statementScope DOT)?)? identifier EQ (expression | DEFAULT) #setSystemVariable | ATSIGN identifier EQ expression #setUserVariable ; @@ -868,7 +868,7 @@ isolationLevel ; supportedUnsetStatement - : UNSET (GLOBAL | SESSION | LOCAL)? VARIABLE (ALL | identifier) + : UNSET statementScope? VARIABLE (ALL | identifier) | UNSET DEFAULT STORAGE VAULT ; @@ -961,6 +961,10 @@ dataDesc ; // -----------------Command accessories----------------- +statementScope + : (GLOBAL | SESSION | LOCAL) + ; + buildMode : BUILD (IMMEDIATE | DEFERRED) ; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java index dd8f8f8f1ee1a5..d99b37c99f218d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java @@ -339,6 +339,7 @@ import org.apache.doris.nereids.DorisParser.SpecifiedPartitionContext; import org.apache.doris.nereids.DorisParser.StarContext; import org.apache.doris.nereids.DorisParser.StatementDefaultContext; +import org.apache.doris.nereids.DorisParser.StatementScopeContext; import org.apache.doris.nereids.DorisParser.StepPartitionDefContext; import org.apache.doris.nereids.DorisParser.StringLiteralContext; import org.apache.doris.nereids.DorisParser.StructLiteralContext; @@ -660,7 +661,7 @@ import org.apache.doris.nereids.trees.plans.commands.info.EnableFeatureOp; import org.apache.doris.nereids.trees.plans.commands.info.FixedRangePartition; import org.apache.doris.nereids.trees.plans.commands.info.FuncNameInfo; -import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgsDefInfo; +import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgTypesInfo; import org.apache.doris.nereids.trees.plans.commands.info.GeneratedColumnDesc; import org.apache.doris.nereids.trees.plans.commands.info.InPartition; import org.apache.doris.nereids.trees.plans.commands.info.IndexDefinition; @@ -4107,16 +4108,11 @@ public LogicalPlan visitSupportedUnsetStatement(SupportedUnsetStatementContext c if (ctx.DEFAULT() != null && ctx.STORAGE() != null && ctx.VAULT() != null) { return new UnsetDefaultStorageVaultCommand(); } - SetType type = SetType.DEFAULT; - if (ctx.GLOBAL() != null) { - type = SetType.GLOBAL; - } else if (ctx.LOCAL() != null || ctx.SESSION() != null) { - type = SetType.SESSION; - } + SetType statementScope = visitStatementScope(ctx.statementScope()); if (ctx.ALL() != null) { - return new UnsetVariableCommand(type, true); + return new UnsetVariableCommand(statementScope, true); } else if (ctx.identifier() != null) { - return new UnsetVariableCommand(type, ctx.identifier().getText()); + return new UnsetVariableCommand(statementScope, ctx.identifier().getText()); } throw new AnalysisException("Should add 'ALL' or variable name"); } @@ -4140,25 +4136,18 @@ public LogicalPlan visitCreateTableLike(CreateTableLikeContext ctx) { @Override public Command visitCreateUserDefineFunction(CreateUserDefineFunctionContext ctx) { - SetType setType; - if (ctx.GLOBAL() != null) { - setType = SetType.GLOBAL; - } else if (ctx.LOCAL() != null || ctx.SESSION() != null) { - setType = SetType.SESSION; - } else { - setType = SetType.DEFAULT; - } + SetType statementScope = visitStatementScope(ctx.statementScope()); boolean ifNotExists = ctx.EXISTS() != null; boolean isAggFunction = ctx.AGGREGATE() != null; boolean isTableFunction = ctx.TABLES() != null; String functionName = ctx.functionIdentifier().functionNameIdentifier().getText(); String dbName = ctx.functionIdentifier().dbName != null ? ctx.functionIdentifier().dbName.getText() : null; FunctionName function = new FunctionName(dbName, functionName); - FunctionArgsDefInfo functionArgsDefInfo; + FunctionArgTypesInfo functionArgTypesInfo; if (ctx.functionArguments() != null) { - functionArgsDefInfo = visitFunctionArguments(ctx.functionArguments()); + functionArgTypesInfo = visitFunctionArguments(ctx.functionArguments()); } else { - functionArgsDefInfo = new FunctionArgsDefInfo(new ArrayList<>(), false); + functionArgTypesInfo = new FunctionArgTypesInfo(new ArrayList<>(), false); } DataType returnType = typedVisit(ctx.returnType); returnType = returnType.conversion(); @@ -4169,66 +4158,49 @@ public Command visitCreateUserDefineFunction(CreateUserDefineFunctionContext ctx Map properties = ctx.propertyClause() != null ? Maps.newHashMap(visitPropertyClause(ctx.propertyClause())) : Maps.newHashMap(); - if (isTableFunction) { - return new CreateFunctionCommand(setType, ifNotExists, function, functionArgsDefInfo, returnType, - intermediateType, properties); - } else { - return new CreateFunctionCommand(setType, ifNotExists, isAggFunction, function, functionArgsDefInfo, - returnType, intermediateType, properties); - } + return new CreateFunctionCommand(statementScope, ifNotExists, isAggFunction, false, isTableFunction, + function, functionArgTypesInfo, returnType, intermediateType, + null, null, properties); } @Override public Command visitCreateAliasFunction(CreateAliasFunctionContext ctx) { - SetType setType; - if (ctx.GLOBAL() != null) { - setType = SetType.GLOBAL; - } else if (ctx.LOCAL() != null || ctx.SESSION() != null) { - setType = SetType.SESSION; - } else { - setType = SetType.DEFAULT; - } + SetType statementScope = visitStatementScope(ctx.statementScope()); boolean ifNotExists = ctx.EXISTS() != null; String functionName = ctx.functionIdentifier().functionNameIdentifier().getText(); String dbName = ctx.functionIdentifier().dbName != null ? ctx.functionIdentifier().dbName.getText() : null; FunctionName function = new FunctionName(dbName, functionName); - FunctionArgsDefInfo functionArgsDefInfo; + FunctionArgTypesInfo functionArgTypesInfo; if (ctx.functionArguments() != null) { - functionArgsDefInfo = visitFunctionArguments(ctx.functionArguments()); + functionArgTypesInfo = visitFunctionArguments(ctx.functionArguments()); } else { - functionArgsDefInfo = new FunctionArgsDefInfo(new ArrayList<>(), false); + functionArgTypesInfo = new FunctionArgTypesInfo(new ArrayList<>(), false); } List parameters = ctx.parameters != null ? visitIdentifierSeq(ctx.parameters) : new ArrayList<>(); Expression originFunction = getExpression(ctx.expression()); - return new CreateFunctionCommand(setType, ifNotExists, function, functionArgsDefInfo, parameters, - originFunction); + return new CreateFunctionCommand(statementScope, ifNotExists, false, true, false, + function, functionArgTypesInfo, VarcharType.MAX_VARCHAR_TYPE, null, + parameters, originFunction, null); } @Override public Command visitDropFunction(DropFunctionContext ctx) { - SetType setType; - if (ctx.GLOBAL() != null) { - setType = SetType.GLOBAL; - } else if (ctx.LOCAL() != null || ctx.SESSION() != null) { - setType = SetType.SESSION; - } else { - setType = SetType.DEFAULT; - } + SetType statementScope = visitStatementScope(ctx.statementScope()); boolean ifExists = ctx.EXISTS() != null; String functionName = ctx.functionIdentifier().functionNameIdentifier().getText(); String dbName = ctx.functionIdentifier().dbName != null ? ctx.functionIdentifier().dbName.getText() : null; FunctionName function = new FunctionName(dbName, functionName); - FunctionArgsDefInfo functionArgsDefInfo; + FunctionArgTypesInfo functionArgTypesInfo; if (ctx.functionArguments() != null) { - functionArgsDefInfo = visitFunctionArguments(ctx.functionArguments()); + functionArgTypesInfo = visitFunctionArguments(ctx.functionArguments()); } else { - functionArgsDefInfo = new FunctionArgsDefInfo(new ArrayList<>(), false); + functionArgTypesInfo = new FunctionArgTypesInfo(new ArrayList<>(), false); } - return new DropFunctionCommand(setType, ifExists, function, functionArgsDefInfo); + return new DropFunctionCommand(statementScope, ifExists, function, functionArgTypesInfo); } @Override - public FunctionArgsDefInfo visitFunctionArguments(FunctionArgumentsContext ctx) { + public FunctionArgTypesInfo visitFunctionArguments(FunctionArgumentsContext ctx) { boolean isVariadic = false; List argTypeDefs = new ArrayList<>(4); for (Object child : ctx.children) { @@ -4241,7 +4213,7 @@ public FunctionArgsDefInfo visitFunctionArguments(FunctionArgumentsContext ctx) } } } - return new FunctionArgsDefInfo(argTypeDefs, isVariadic); + return new FunctionArgTypesInfo(argTypeDefs, isVariadic); } @Override @@ -4292,28 +4264,18 @@ public SetOptionsCommand visitSetOptions(SetOptionsContext ctx) { @Override public SetVarOp visitSetSystemVariable(SetSystemVariableContext ctx) { - SetType type = SetType.DEFAULT; - if (ctx.GLOBAL() != null) { - type = SetType.GLOBAL; - } else if (ctx.LOCAL() != null || ctx.SESSION() != null) { - type = SetType.SESSION; - } + SetType statementScope = visitStatementScope(ctx.statementScope()); String name = stripQuotes(ctx.identifier().getText()); Expression expression = ctx.expression() != null ? typedVisit(ctx.expression()) : null; - return new SetSessionVarOp(type, name, expression); + return new SetSessionVarOp(statementScope, name, expression); } @Override public SetVarOp visitSetVariableWithType(SetVariableWithTypeContext ctx) { - SetType type = SetType.DEFAULT; - if (ctx.GLOBAL() != null) { - type = SetType.GLOBAL; - } else if (ctx.LOCAL() != null || ctx.SESSION() != null) { - type = SetType.SESSION; - } + SetType statementScope = visitStatementScope(ctx.statementScope()); String name = stripQuotes(ctx.identifier().getText()); Expression expression = ctx.expression() != null ? typedVisit(ctx.expression()) : null; - return new SetSessionVarOp(type, name, expression); + return new SetSessionVarOp(statementScope, name, expression); } @Override @@ -4781,15 +4743,11 @@ public AlterTableOp visitDropRollupClause(DorisParser.DropRollupClauseContext ct @Override public LogicalPlan visitShowVariables(ShowVariablesContext ctx) { - SetType type = SetType.DEFAULT; - if (ctx.GLOBAL() != null) { - type = SetType.GLOBAL; - } else if (ctx.LOCAL() != null || ctx.SESSION() != null) { - type = SetType.SESSION; - } + SetType statementScope = visitStatementScope(ctx.statementScope()); if (ctx.wildWhere() != null) { if (ctx.wildWhere().LIKE() != null) { - return new ShowVariablesCommand(type, stripQuotes(ctx.wildWhere().STRING_LITERAL().getText())); + return new ShowVariablesCommand(statementScope, + stripQuotes(ctx.wildWhere().STRING_LITERAL().getText())); } else { StringBuilder sb = new StringBuilder(); sb.append("SELECT `VARIABLE_NAME` AS `Variable_name`, `VARIABLE_VALUE` AS `Value` FROM "); @@ -4797,7 +4755,7 @@ public LogicalPlan visitShowVariables(ShowVariablesContext ctx) { sb.append("."); sb.append("`").append(InfoSchemaDb.DATABASE_NAME).append("`"); sb.append("."); - if (type == SetType.GLOBAL) { + if (statementScope == SetType.GLOBAL) { sb.append("`global_variables` "); } else { sb.append("`session_variables` "); @@ -4806,7 +4764,7 @@ public LogicalPlan visitShowVariables(ShowVariablesContext ctx) { return new NereidsParser().parseSingle(sb.toString()); } } else { - return new ShowVariablesCommand(type, null); + return new ShowVariablesCommand(statementScope, null); } } @@ -5460,15 +5418,7 @@ public LogicalPlan visitShowWarningErrorCount(ShowWarningErrorCountContext ctx) @Override public LogicalPlan visitShowStatus(ShowStatusContext ctx) { - String scope = null; - if (ctx.GLOBAL() != null) { - scope = "GLOBAL"; - } else if (ctx.SESSION() != null) { - scope = "SESSION"; - } else if (ctx.LOCAL() != null) { - scope = "LOCAL"; - } - + String scope = visitStatementScope(ctx.statementScope()).name(); return new ShowStatusCommand(scope); } @@ -5492,6 +5442,19 @@ public LogicalPlan visitShowTableCreation(ShowTableCreationContext ctx) { return new ShowTableCreationCommand(dbName, wild); } + @Override + public SetType visitStatementScope(StatementScopeContext ctx) { + SetType statementScope = SetType.DEFAULT; + if (ctx != null) { + if (ctx.GLOBAL() != null) { + statementScope = SetType.GLOBAL; + } else if (ctx.LOCAL() != null || ctx.SESSION() != null) { + statementScope = SetType.SESSION; + } + } + return statementScope; + } + @Override public LogicalPlan visitAdminShowTabletStorageFormat(AdminShowTabletStorageFormatContext ctx) { return new ShowTabletStorageFormatCommand(ctx.VERBOSE() != null); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java index 06c6c20661d076..2ba6301052965c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java @@ -67,12 +67,10 @@ import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.plans.PlanType; -import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgsDefInfo; +import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgTypesInfo; import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.types.VarcharType; -import org.apache.doris.nereids.util.TypeCoercionUtils; import org.apache.doris.proto.FunctionService; import org.apache.doris.proto.PFunctionServiceGrpc; import org.apache.doris.proto.Types; @@ -148,7 +146,7 @@ public class CreateFunctionCommand extends Command implements ForwardWithSync { private final boolean isAggregate; private final boolean isAlias; private boolean isTableFunction; - private final FunctionArgsDefInfo argsDef; + private final FunctionArgTypesInfo argsDef; private final DataType returnType; private DataType intermediateType; private final Map properties; @@ -169,62 +167,36 @@ public class CreateFunctionCommand extends Command implements ForwardWithSync { /** * CreateFunctionCommand */ - public CreateFunctionCommand(SetType setType, boolean ifNotExists, boolean isAggregate, FunctionName functionName, - FunctionArgsDefInfo argsDef, - DataType returnType, DataType intermediateType, Map properties) { + public CreateFunctionCommand(SetType setType, boolean ifNotExists, boolean isAggregate, boolean isAlias, + boolean isTableFunction, FunctionName functionName, FunctionArgTypesInfo argsDef, + DataType returnType, DataType intermediateType, List parameters, + Expression originFunction, Map properties) { super(PlanType.CREATE_FUNCTION_COMMAND); this.setType = setType; this.ifNotExists = ifNotExists; - this.functionName = functionName; this.isAggregate = isAggregate; + this.isAlias = isAlias; + this.isTableFunction = isTableFunction; + this.functionName = functionName; this.argsDef = argsDef; this.returnType = returnType; this.intermediateType = intermediateType; - if (properties == null) { - this.properties = ImmutableSortedMap.of(); - } else { - this.properties = ImmutableSortedMap.copyOf(properties, String.CASE_INSENSITIVE_ORDER); - } - this.isAlias = false; - this.isTableFunction = false; - this.parameters = ImmutableList.of(); - this.originFunction = null; - } - - public CreateFunctionCommand(SetType setType, boolean ifNotExists, FunctionName functionName, - FunctionArgsDefInfo argsDef, - DataType returnType, DataType intermediateType, Map properties) { - this(setType, ifNotExists, false, functionName, argsDef, returnType, intermediateType, properties); - this.isTableFunction = true; - } - - /** - * CreateFunctionCommand - */ - public CreateFunctionCommand(SetType setType, boolean ifNotExists, FunctionName functionName, - FunctionArgsDefInfo argsDef, - List parameters, Expression originFunction) { - super(PlanType.CREATE_FUNCTION_COMMAND); - this.setType = setType; - this.ifNotExists = ifNotExists; - this.functionName = functionName; - this.isAlias = true; - this.argsDef = argsDef; if (parameters == null) { this.parameters = ImmutableList.of(); } else { this.parameters = ImmutableList.copyOf(parameters); } this.originFunction = originFunction; - this.isAggregate = false; - this.isTableFunction = false; - this.returnType = VarcharType.MAX_VARCHAR_TYPE; - this.properties = ImmutableSortedMap.of(); + if (properties == null) { + this.properties = ImmutableSortedMap.of(); + } else { + this.properties = ImmutableSortedMap.copyOf(properties, String.CASE_INSENSITIVE_ORDER); + } } @Override public void run(ConnectContext ctx, StmtExecutor executor) throws Exception { - validate(ctx); + analyze(ctx); if (SetType.GLOBAL.equals(setType)) { Env.getCurrentEnv().getGlobalFunctionMgr().addFunction(function, ifNotExists); } else { @@ -257,7 +229,7 @@ public StmtType stmtType() { return StmtType.CREATE; } - private void validate(ConnectContext ctx) throws Exception { + private void analyze(ConnectContext ctx) throws Exception { // https://github.com/apache/doris/issues/17810 // this error report in P0 test, so we suspect that it is related to concurrency // add this change to test it. @@ -270,7 +242,7 @@ private void validate(ConnectContext ctx) throws Exception { } else if (isAlias) { analyzeAliasFunction(ctx); } else if (isTableFunction) { - analyzeTableFunction(); + analyzeUdtf(); } else { analyzeUdf(); } @@ -283,7 +255,7 @@ private void validate(ConnectContext ctx) throws Exception { } else if (isAlias) { analyzeAliasFunction(ctx); } else if (isTableFunction) { - analyzeTableFunction(); + analyzeUdtf(); } else { analyzeUdf(); } @@ -310,9 +282,9 @@ private void analyzeCommon(ConnectContext ctx) throws AnalysisException { if (isAlias) { return; } - TypeCoercionUtils.validateDataType(returnType); + returnType.validateDataType(); if (intermediateType != null) { - TypeCoercionUtils.validateDataType(intermediateType); + intermediateType.validateDataType(); } else { intermediateType = returnType; } @@ -354,7 +326,12 @@ private void analyzeCommon(ConnectContext ctx) throws AnalysisException { } String expirationTimeString = properties.get(EXPIRATION_TIME); if (expirationTimeString != null) { - long timeMinutes = Long.parseLong(expirationTimeString); + long timeMinutes = 0; + try { + timeMinutes = Long.parseLong(expirationTimeString); + } catch (NumberFormatException e) { + throw new AnalysisException(e.getMessage()); + } if (timeMinutes <= 0) { throw new AnalysisException("expirationTime should greater than zero: "); } @@ -396,7 +373,7 @@ private void computeObjectChecksum() throws IOException, NoSuchAlgorithmExceptio } } - private void analyzeTableFunction() throws AnalysisException { + private void analyzeUdtf() throws AnalysisException { String symbol = properties.get(SYMBOL_KEY); if (Strings.isNullOrEmpty(symbol)) { throw new AnalysisException("No 'symbol' in properties"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java index 004930193e9fd7..3cf19921b467e6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java @@ -27,7 +27,7 @@ import org.apache.doris.common.ErrorReport; import org.apache.doris.mysql.privilege.PrivPredicate; import org.apache.doris.nereids.trees.plans.PlanType; -import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgsDefInfo; +import org.apache.doris.nereids.trees.plans.commands.info.FunctionArgTypesInfo; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.qe.ConnectContext; import org.apache.doris.qe.StmtExecutor; @@ -49,7 +49,7 @@ public class DropFunctionCommand extends Command implements ForwardWithSync { private SetType setType; private final boolean ifExists; private final FunctionName functionName; - private final FunctionArgsDefInfo argsDef; + private final FunctionArgTypesInfo argsDef; // set after analyzed private FunctionSearchDesc function; @@ -57,7 +57,7 @@ public class DropFunctionCommand extends Command implements ForwardWithSync { * DropFunctionCommand */ public DropFunctionCommand(SetType setType, boolean ifExists, FunctionName functionName, - FunctionArgsDefInfo argsDef) { + FunctionArgTypesInfo argsDef) { super(PlanType.CREATE_FUNCTION_COMMAND); this.setType = setType; this.ifExists = ifExists; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/ColumnDefinition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/ColumnDefinition.java index 375206305e97a3..2b66ad1e768f17 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/ColumnDefinition.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/ColumnDefinition.java @@ -23,9 +23,6 @@ import org.apache.doris.catalog.AggregateType; import org.apache.doris.catalog.Column; import org.apache.doris.catalog.KeysType; -import org.apache.doris.catalog.PrimitiveType; -import org.apache.doris.catalog.ScalarType; -import org.apache.doris.catalog.Type; import org.apache.doris.common.FeNameFormat; import org.apache.doris.common.util.SqlUtils; import org.apache.doris.nereids.exceptions.AnalysisException; @@ -46,7 +43,6 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; -import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Optional; @@ -294,7 +290,7 @@ public void validate(boolean isOlap, Set keysSet, Set clusterKey } catch (Exception e) { throw new AnalysisException(e.getMessage(), e); } - validateDataType(type.toCatalogDataType()); + type.validateDataType(); type = updateCharacterTypeLength(type); if (type.isArrayType()) { int depth = 0; @@ -496,242 +492,6 @@ public void validate(boolean isOlap, Set keysSet, Set clusterKey validateGeneratedColumnInfo(); } - // from TypeDef.java analyze() - private void validateDataType(Type catalogType) { - if (catalogType.exceedsMaxNestingDepth()) { - throw new AnalysisException( - String.format("Type exceeds the maximum nesting depth of %s:\n%s", - Type.MAX_NESTING_DEPTH, catalogType.toSql())); - } - if (!catalogType.isSupported()) { - throw new AnalysisException("Unsupported data type: " + catalogType.toSql()); - } - - if (catalogType.isScalarType()) { - validateScalarType((ScalarType) catalogType); - } else if (catalogType.isComplexType()) { - // now we not support array / map / struct nesting complex type - if (catalogType.isArrayType()) { - Type itemType = ((org.apache.doris.catalog.ArrayType) catalogType).getItemType(); - if (itemType instanceof ScalarType) { - validateNestedType(catalogType, (ScalarType) itemType); - } - } - if (catalogType.isMapType()) { - org.apache.doris.catalog.MapType mt = - (org.apache.doris.catalog.MapType) catalogType; - if (mt.getKeyType() instanceof ScalarType) { - validateNestedType(catalogType, (ScalarType) mt.getKeyType()); - } - if (mt.getValueType() instanceof ScalarType) { - validateNestedType(catalogType, (ScalarType) mt.getValueType()); - } - } - if (catalogType.isStructType()) { - ArrayList fields = - ((org.apache.doris.catalog.StructType) catalogType).getFields(); - Set fieldNames = new HashSet<>(); - for (org.apache.doris.catalog.StructField field : fields) { - Type fieldType = field.getType(); - if (fieldType instanceof ScalarType) { - validateNestedType(catalogType, (ScalarType) fieldType); - if (!fieldNames.add(field.getName())) { - throw new AnalysisException("Duplicate field name " + field.getName() - + " in struct " + catalogType.toSql()); - } - } - } - } - } - } - - private void validateScalarType(ScalarType scalarType) { - PrimitiveType type = scalarType.getPrimitiveType(); - // When string type length is not assigned, it needs to be assigned to 1. - if (scalarType.getPrimitiveType().isStringType() && !scalarType.isLengthSet()) { - if (scalarType.getPrimitiveType() == PrimitiveType.VARCHAR) { - // always set varchar length MAX_VARCHAR_LENGTH - scalarType.setLength(ScalarType.MAX_VARCHAR_LENGTH); - } else if (scalarType.getPrimitiveType() == PrimitiveType.STRING) { - // always set text length MAX_STRING_LENGTH - scalarType.setLength(ScalarType.MAX_STRING_LENGTH); - } else { - scalarType.setLength(1); - } - } - switch (type) { - case CHAR: - case VARCHAR: { - String name; - int maxLen; - if (type == PrimitiveType.VARCHAR) { - name = "VARCHAR"; - maxLen = ScalarType.MAX_VARCHAR_LENGTH; - } else { - name = "CHAR"; - maxLen = ScalarType.MAX_CHAR_LENGTH; - } - int len = scalarType.getLength(); - // len is decided by child, when it is -1. - - if (len <= 0) { - throw new AnalysisException(name + " size must be > 0: " + len); - } - if (scalarType.getLength() > maxLen) { - throw new AnalysisException(name + " size must be <= " + maxLen + ": " + len); - } - break; - } - case DECIMALV2: { - int precision = scalarType.decimalPrecision(); - int scale = scalarType.decimalScale(); - // precision: [1, 27] - if (precision < 1 || precision > ScalarType.MAX_DECIMALV2_PRECISION) { - throw new AnalysisException("Precision of decimal must between 1 and 27." - + " Precision was set to: " + precision + "."); - } - // scale: [0, 9] - if (scale < 0 || scale > ScalarType.MAX_DECIMALV2_SCALE) { - throw new AnalysisException("Scale of decimal must between 0 and 9." - + " Scale was set to: " + scale + "."); - } - if (precision - scale > ScalarType.MAX_DECIMALV2_PRECISION - - ScalarType.MAX_DECIMALV2_SCALE) { - throw new AnalysisException("Invalid decimal type with precision = " + precision - + ", scale = " + scale); - } - // scale < precision - if (scale > precision) { - throw new AnalysisException("Scale of decimal must be smaller than precision." - + " Scale is " + scale + " and precision is " + precision); - } - break; - } - case DECIMAL32: { - int decimal32Precision = scalarType.decimalPrecision(); - int decimal32Scale = scalarType.decimalScale(); - if (decimal32Precision < 1 - || decimal32Precision > ScalarType.MAX_DECIMAL32_PRECISION) { - throw new AnalysisException("Precision of decimal must between 1 and 9." - + " Precision was set to: " + decimal32Precision + "."); - } - // scale >= 0 - if (decimal32Scale < 0) { - throw new AnalysisException("Scale of decimal must not be less than 0." - + " Scale was set to: " + decimal32Scale + "."); - } - // scale < precision - if (decimal32Scale > decimal32Precision) { - throw new AnalysisException( - "Scale of decimal must be smaller than precision." + " Scale is " - + decimal32Scale + " and precision is " + decimal32Precision); - } - break; - } - case DECIMAL64: { - int decimal64Precision = scalarType.decimalPrecision(); - int decimal64Scale = scalarType.decimalScale(); - if (decimal64Precision < 1 - || decimal64Precision > ScalarType.MAX_DECIMAL64_PRECISION) { - throw new AnalysisException("Precision of decimal64 must between 1 and 18." - + " Precision was set to: " + decimal64Precision + "."); - } - // scale >= 0 - if (decimal64Scale < 0) { - throw new AnalysisException("Scale of decimal must not be less than 0." - + " Scale was set to: " + decimal64Scale + "."); - } - // scale < precision - if (decimal64Scale > decimal64Precision) { - throw new AnalysisException( - "Scale of decimal must be smaller than precision." + " Scale is " - + decimal64Scale + " and precision is " + decimal64Precision); - } - break; - } - case DECIMAL128: { - int decimal128Precision = scalarType.decimalPrecision(); - int decimal128Scale = scalarType.decimalScale(); - if (decimal128Precision < 1 - || decimal128Precision > ScalarType.MAX_DECIMAL128_PRECISION) { - throw new AnalysisException("Precision of decimal128 must between 1 and 38." - + " Precision was set to: " + decimal128Precision + "."); - } - // scale >= 0 - if (decimal128Scale < 0) { - throw new AnalysisException("Scale of decimal must not be less than 0." - + " Scale was set to: " + decimal128Scale + "."); - } - // scale < precision - if (decimal128Scale > decimal128Precision) { - throw new AnalysisException( - "Scale of decimal must be smaller than precision." + " Scale is " - + decimal128Scale + " and precision is " + decimal128Precision); - } - break; - } - case DECIMAL256: { - if (SessionVariable.getEnableDecimal256()) { - int precision = scalarType.decimalPrecision(); - int scale = scalarType.decimalScale(); - if (precision < 1 || precision > ScalarType.MAX_DECIMAL256_PRECISION) { - throw new AnalysisException("Precision of decimal256 must between 1 and 76." - + " Precision was set to: " + precision + "."); - } - // scale >= 0 - if (scale < 0) { - throw new AnalysisException("Scale of decimal must not be less than 0." - + " Scale was set to: " + scale + "."); - } - // scale < precision - if (scale > precision) { - throw new AnalysisException( - "Scale of decimal must be smaller than precision." + " Scale is " - + scale + " and precision is " + precision); - } - break; - } else { - int precision = scalarType.decimalPrecision(); - throw new AnalysisException("Column of type Decimal256 with precision " - + precision + " in not supported."); - } - } - case TIMEV2: - case DATETIMEV2: { - int precision = scalarType.decimalPrecision(); - int scale = scalarType.decimalScale(); - // precision: [1, 27] - if (precision != ScalarType.DATETIME_PRECISION) { - throw new AnalysisException( - "Precision of Datetime/Time must be " + ScalarType.DATETIME_PRECISION - + "." + " Precision was set to: " + precision + "."); - } - // scale: [0, 9] - if (scale < 0 || scale > 6) { - throw new AnalysisException("Scale of Datetime/Time must between 0 and 6." - + " Scale was set to: " + scale + "."); - } - break; - } - case INVALID_TYPE: - throw new AnalysisException("Invalid type."); - default: - break; - } - } - - private void validateNestedType(Type parent, Type child) throws AnalysisException { - if (child.isNull()) { - throw new AnalysisException("Unsupported data type: " + child.toSql()); - } - // check whether the sub-type is supported - if (!parent.supportSubType(child)) { - throw new AnalysisException( - parent.getPrimitiveType() + " unsupported sub-type: " + child.toSql()); - } - validateDataType(child); - } - /** * translate to catalog create table stmt */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgTypesInfo.java similarity index 81% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgTypesInfo.java index e3da48bfb7cbed..5b82f5015dff3b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgsDefInfo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgTypesInfo.java @@ -19,22 +19,24 @@ import org.apache.doris.catalog.Type; import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.util.TypeCoercionUtils; +import org.apache.doris.nereids.util.Utils; import java.util.List; +import java.util.Objects; /** * represent function arguments */ -public class FunctionArgsDefInfo { +public class FunctionArgTypesInfo { private final List argTypeDefs; private final boolean isVariadic; // set after analyze private Type[] argTypes; - public FunctionArgsDefInfo(List argTypeDefs, boolean isVariadic) { - this.argTypeDefs = argTypeDefs; + public FunctionArgTypesInfo(List argTypeDefs, boolean isVariadic) { + this.argTypeDefs = Utils.fastToImmutableList(Objects.requireNonNull(argTypeDefs, + "argTypeDefs should not be null")); this.isVariadic = isVariadic; } @@ -57,7 +59,7 @@ public void validate() { argTypes = new Type[argTypeDefs.size()]; int i = 0; for (DataType dataType : argTypeDefs) { - TypeCoercionUtils.validateDataType(dataType); + dataType.validateDataType(); argTypes[i++] = dataType.toCatalogDataType(); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java index 9b77017f6de8cb..581edecb85ffa3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java @@ -33,14 +33,18 @@ import org.apache.doris.nereids.types.coercion.IntegralType; import org.apache.doris.nereids.types.coercion.NumericType; import org.apache.doris.nereids.types.coercion.PrimitiveType; +import org.apache.doris.qe.SessionVariable; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Set; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -763,4 +767,243 @@ public boolean isAssignableFrom(DataType targetDataType) { } return false; } + + public void validateDataType() { + validateCatalogDataType(toCatalogDataType()); + } + + private static void validateCatalogDataType(Type catalogType) { + if (catalogType.exceedsMaxNestingDepth()) { + throw new AnalysisException( + String.format("Type exceeds the maximum nesting depth of %s:\n%s", + Type.MAX_NESTING_DEPTH, catalogType.toSql())); + } + if (!catalogType.isSupported()) { + throw new AnalysisException("Unsupported data type: " + catalogType.toSql()); + } + + if (catalogType.isScalarType()) { + validateScalarType((ScalarType) catalogType); + } else if (catalogType.isComplexType()) { + // now we not support array / map / struct nesting complex type + if (catalogType.isArrayType()) { + Type itemType = ((org.apache.doris.catalog.ArrayType) catalogType).getItemType(); + if (itemType instanceof ScalarType) { + validateNestedType(catalogType, (ScalarType) itemType); + } + } + if (catalogType.isMapType()) { + org.apache.doris.catalog.MapType mt = + (org.apache.doris.catalog.MapType) catalogType; + if (mt.getKeyType() instanceof ScalarType) { + validateNestedType(catalogType, (ScalarType) mt.getKeyType()); + } + if (mt.getValueType() instanceof ScalarType) { + validateNestedType(catalogType, (ScalarType) mt.getValueType()); + } + } + if (catalogType.isStructType()) { + ArrayList fields = + ((org.apache.doris.catalog.StructType) catalogType).getFields(); + Set fieldNames = new HashSet<>(); + for (org.apache.doris.catalog.StructField field : fields) { + Type fieldType = field.getType(); + if (fieldType instanceof ScalarType) { + validateNestedType(catalogType, (ScalarType) fieldType); + if (!fieldNames.add(field.getName())) { + throw new AnalysisException("Duplicate field name " + field.getName() + + " in struct " + catalogType.toSql()); + } + } + } + } + } + } + + private static void validateNestedType(Type parent, Type child) throws AnalysisException { + if (child.isNull()) { + throw new AnalysisException("Unsupported data type: " + child.toSql()); + } + // check whether the sub-type is supported + if (!parent.supportSubType(child)) { + throw new AnalysisException( + parent.getPrimitiveType() + " unsupported sub-type: " + child.toSql()); + } + validateCatalogDataType(child); + } + + private static void validateScalarType(ScalarType scalarType) { + org.apache.doris.catalog.PrimitiveType type = scalarType.getPrimitiveType(); + // When string type length is not assigned, it needs to be assigned to 1. + if (scalarType.getPrimitiveType().isStringType() && !scalarType.isLengthSet()) { + if (scalarType.getPrimitiveType() == org.apache.doris.catalog.PrimitiveType.VARCHAR) { + // always set varchar length MAX_VARCHAR_LENGTH + scalarType.setLength(ScalarType.MAX_VARCHAR_LENGTH); + } else if (scalarType.getPrimitiveType() == org.apache.doris.catalog.PrimitiveType.STRING) { + // always set text length MAX_STRING_LENGTH + scalarType.setLength(ScalarType.MAX_STRING_LENGTH); + } else { + scalarType.setLength(1); + } + } + switch (type) { + case CHAR: + case VARCHAR: { + String name; + int maxLen; + if (type == org.apache.doris.catalog.PrimitiveType.VARCHAR) { + name = "VARCHAR"; + maxLen = ScalarType.MAX_VARCHAR_LENGTH; + } else { + name = "CHAR"; + maxLen = ScalarType.MAX_CHAR_LENGTH; + } + int len = scalarType.getLength(); + // len is decided by child, when it is -1. + + if (len <= 0) { + throw new AnalysisException(name + " size must be > 0: " + len); + } + if (scalarType.getLength() > maxLen) { + throw new AnalysisException(name + " size must be <= " + maxLen + ": " + len); + } + break; + } + case DECIMALV2: { + int precision = scalarType.decimalPrecision(); + int scale = scalarType.decimalScale(); + // precision: [1, 27] + if (precision < 1 || precision > ScalarType.MAX_DECIMALV2_PRECISION) { + throw new AnalysisException("Precision of decimal must between 1 and 27." + + " Precision was set to: " + precision + "."); + } + // scale: [0, 9] + if (scale < 0 || scale > ScalarType.MAX_DECIMALV2_SCALE) { + throw new AnalysisException("Scale of decimal must between 0 and 9." + + " Scale was set to: " + scale + "."); + } + if (precision - scale > ScalarType.MAX_DECIMALV2_PRECISION + - ScalarType.MAX_DECIMALV2_SCALE) { + throw new AnalysisException("Invalid decimal type with precision = " + precision + + ", scale = " + scale); + } + // scale < precision + if (scale > precision) { + throw new AnalysisException("Scale of decimal must be smaller than precision." + + " Scale is " + scale + " and precision is " + precision); + } + break; + } + case DECIMAL32: { + int decimal32Precision = scalarType.decimalPrecision(); + int decimal32Scale = scalarType.decimalScale(); + if (decimal32Precision < 1 + || decimal32Precision > ScalarType.MAX_DECIMAL32_PRECISION) { + throw new AnalysisException("Precision of decimal must between 1 and 9." + + " Precision was set to: " + decimal32Precision + "."); + } + // scale >= 0 + if (decimal32Scale < 0) { + throw new AnalysisException("Scale of decimal must not be less than 0." + + " Scale was set to: " + decimal32Scale + "."); + } + // scale < precision + if (decimal32Scale > decimal32Precision) { + throw new AnalysisException( + "Scale of decimal must be smaller than precision." + " Scale is " + + decimal32Scale + " and precision is " + decimal32Precision); + } + break; + } + case DECIMAL64: { + int decimal64Precision = scalarType.decimalPrecision(); + int decimal64Scale = scalarType.decimalScale(); + if (decimal64Precision < 1 + || decimal64Precision > ScalarType.MAX_DECIMAL64_PRECISION) { + throw new AnalysisException("Precision of decimal64 must between 1 and 18." + + " Precision was set to: " + decimal64Precision + "."); + } + // scale >= 0 + if (decimal64Scale < 0) { + throw new AnalysisException("Scale of decimal must not be less than 0." + + " Scale was set to: " + decimal64Scale + "."); + } + // scale < precision + if (decimal64Scale > decimal64Precision) { + throw new AnalysisException( + "Scale of decimal must be smaller than precision." + " Scale is " + + decimal64Scale + " and precision is " + decimal64Precision); + } + break; + } + case DECIMAL128: { + int decimal128Precision = scalarType.decimalPrecision(); + int decimal128Scale = scalarType.decimalScale(); + if (decimal128Precision < 1 + || decimal128Precision > ScalarType.MAX_DECIMAL128_PRECISION) { + throw new AnalysisException("Precision of decimal128 must between 1 and 38." + + " Precision was set to: " + decimal128Precision + "."); + } + // scale >= 0 + if (decimal128Scale < 0) { + throw new AnalysisException("Scale of decimal must not be less than 0." + + " Scale was set to: " + decimal128Scale + "."); + } + // scale < precision + if (decimal128Scale > decimal128Precision) { + throw new AnalysisException( + "Scale of decimal must be smaller than precision." + " Scale is " + + decimal128Scale + " and precision is " + decimal128Precision); + } + break; + } + case DECIMAL256: { + if (SessionVariable.getEnableDecimal256()) { + int precision = scalarType.decimalPrecision(); + int scale = scalarType.decimalScale(); + if (precision < 1 || precision > ScalarType.MAX_DECIMAL256_PRECISION) { + throw new AnalysisException("Precision of decimal256 must between 1 and 76." + + " Precision was set to: " + precision + "."); + } + // scale >= 0 + if (scale < 0) { + throw new AnalysisException("Scale of decimal must not be less than 0." + + " Scale was set to: " + scale + "."); + } + // scale < precision + if (scale > precision) { + throw new AnalysisException( + "Scale of decimal must be smaller than precision." + " Scale is " + + scale + " and precision is " + precision); + } + break; + } else { + int precision = scalarType.decimalPrecision(); + throw new AnalysisException("Column of type Decimal256 with precision " + + precision + " in not supported."); + } + } + case TIMEV2: + case DATETIMEV2: { + int precision = scalarType.decimalPrecision(); + int scale = scalarType.decimalScale(); + // precision: [1, 27] + if (precision != ScalarType.DATETIME_PRECISION) { + throw new AnalysisException( + "Precision of Datetime/Time must be " + ScalarType.DATETIME_PRECISION + + "." + " Precision was set to: " + precision + "."); + } + // scale: [0, 9] + if (scale < 0 || scale > 6) { + throw new AnalysisException("Scale of Datetime/Time must between 0 and 6." + + " Scale was set to: " + scale + "."); + } + break; + } + case INVALID_TYPE: + throw new AnalysisException("Invalid type."); + default: + break; + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index 15142a77e80e01..5f8be613552c78 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -112,7 +112,6 @@ import org.apache.doris.nereids.types.coercion.IntegralType; import org.apache.doris.nereids.types.coercion.NumericType; import org.apache.doris.nereids.types.coercion.PrimitiveType; -import org.apache.doris.qe.SessionVariable; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; @@ -126,12 +125,10 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.util.ArrayList; -import java.util.HashSet; import java.util.List; import java.util.ListIterator; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; @@ -1850,243 +1847,4 @@ private static boolean supportCompare(DataType dataType) { } return true; } - - public static void validateDataType(DataType dataType) { - validateCatalogDataType(dataType.toCatalogDataType()); - } - - private static void validateCatalogDataType(Type catalogType) { - if (catalogType.exceedsMaxNestingDepth()) { - throw new AnalysisException( - String.format("Type exceeds the maximum nesting depth of %s:\n%s", - Type.MAX_NESTING_DEPTH, catalogType.toSql())); - } - if (!catalogType.isSupported()) { - throw new AnalysisException("Unsupported data type: " + catalogType.toSql()); - } - - if (catalogType.isScalarType()) { - validateScalarType((ScalarType) catalogType); - } else if (catalogType.isComplexType()) { - // now we not support array / map / struct nesting complex type - if (catalogType.isArrayType()) { - Type itemType = ((org.apache.doris.catalog.ArrayType) catalogType).getItemType(); - if (itemType instanceof ScalarType) { - validateNestedType(catalogType, (ScalarType) itemType); - } - } - if (catalogType.isMapType()) { - org.apache.doris.catalog.MapType mt = - (org.apache.doris.catalog.MapType) catalogType; - if (mt.getKeyType() instanceof ScalarType) { - validateNestedType(catalogType, (ScalarType) mt.getKeyType()); - } - if (mt.getValueType() instanceof ScalarType) { - validateNestedType(catalogType, (ScalarType) mt.getValueType()); - } - } - if (catalogType.isStructType()) { - ArrayList fields = - ((org.apache.doris.catalog.StructType) catalogType).getFields(); - Set fieldNames = new HashSet<>(); - for (org.apache.doris.catalog.StructField field : fields) { - Type fieldType = field.getType(); - if (fieldType instanceof ScalarType) { - validateNestedType(catalogType, (ScalarType) fieldType); - if (!fieldNames.add(field.getName())) { - throw new AnalysisException("Duplicate field name " + field.getName() - + " in struct " + catalogType.toSql()); - } - } - } - } - } - } - - private static void validateNestedType(Type parent, Type child) throws AnalysisException { - if (child.isNull()) { - throw new AnalysisException("Unsupported data type: " + child.toSql()); - } - // check whether the sub-type is supported - if (!parent.supportSubType(child)) { - throw new AnalysisException( - parent.getPrimitiveType() + " unsupported sub-type: " + child.toSql()); - } - validateCatalogDataType(child); - } - - private static void validateScalarType(ScalarType scalarType) { - org.apache.doris.catalog.PrimitiveType type = scalarType.getPrimitiveType(); - // When string type length is not assigned, it needs to be assigned to 1. - if (scalarType.getPrimitiveType().isStringType() && !scalarType.isLengthSet()) { - if (scalarType.getPrimitiveType() == org.apache.doris.catalog.PrimitiveType.VARCHAR) { - // always set varchar length MAX_VARCHAR_LENGTH - scalarType.setLength(ScalarType.MAX_VARCHAR_LENGTH); - } else if (scalarType.getPrimitiveType() == org.apache.doris.catalog.PrimitiveType.STRING) { - // always set text length MAX_STRING_LENGTH - scalarType.setLength(ScalarType.MAX_STRING_LENGTH); - } else { - scalarType.setLength(1); - } - } - switch (type) { - case CHAR: - case VARCHAR: { - String name; - int maxLen; - if (type == org.apache.doris.catalog.PrimitiveType.VARCHAR) { - name = "VARCHAR"; - maxLen = ScalarType.MAX_VARCHAR_LENGTH; - } else { - name = "CHAR"; - maxLen = ScalarType.MAX_CHAR_LENGTH; - } - int len = scalarType.getLength(); - // len is decided by child, when it is -1. - - if (len <= 0) { - throw new AnalysisException(name + " size must be > 0: " + len); - } - if (scalarType.getLength() > maxLen) { - throw new AnalysisException(name + " size must be <= " + maxLen + ": " + len); - } - break; - } - case DECIMALV2: { - int precision = scalarType.decimalPrecision(); - int scale = scalarType.decimalScale(); - // precision: [1, 27] - if (precision < 1 || precision > ScalarType.MAX_DECIMALV2_PRECISION) { - throw new AnalysisException("Precision of decimal must between 1 and 27." - + " Precision was set to: " + precision + "."); - } - // scale: [0, 9] - if (scale < 0 || scale > ScalarType.MAX_DECIMALV2_SCALE) { - throw new AnalysisException("Scale of decimal must between 0 and 9." - + " Scale was set to: " + scale + "."); - } - if (precision - scale > ScalarType.MAX_DECIMALV2_PRECISION - - ScalarType.MAX_DECIMALV2_SCALE) { - throw new AnalysisException("Invalid decimal type with precision = " + precision - + ", scale = " + scale); - } - // scale < precision - if (scale > precision) { - throw new AnalysisException("Scale of decimal must be smaller than precision." - + " Scale is " + scale + " and precision is " + precision); - } - break; - } - case DECIMAL32: { - int decimal32Precision = scalarType.decimalPrecision(); - int decimal32Scale = scalarType.decimalScale(); - if (decimal32Precision < 1 - || decimal32Precision > ScalarType.MAX_DECIMAL32_PRECISION) { - throw new AnalysisException("Precision of decimal must between 1 and 9." - + " Precision was set to: " + decimal32Precision + "."); - } - // scale >= 0 - if (decimal32Scale < 0) { - throw new AnalysisException("Scale of decimal must not be less than 0." - + " Scale was set to: " + decimal32Scale + "."); - } - // scale < precision - if (decimal32Scale > decimal32Precision) { - throw new AnalysisException( - "Scale of decimal must be smaller than precision." + " Scale is " - + decimal32Scale + " and precision is " + decimal32Precision); - } - break; - } - case DECIMAL64: { - int decimal64Precision = scalarType.decimalPrecision(); - int decimal64Scale = scalarType.decimalScale(); - if (decimal64Precision < 1 - || decimal64Precision > ScalarType.MAX_DECIMAL64_PRECISION) { - throw new AnalysisException("Precision of decimal64 must between 1 and 18." - + " Precision was set to: " + decimal64Precision + "."); - } - // scale >= 0 - if (decimal64Scale < 0) { - throw new AnalysisException("Scale of decimal must not be less than 0." - + " Scale was set to: " + decimal64Scale + "."); - } - // scale < precision - if (decimal64Scale > decimal64Precision) { - throw new AnalysisException( - "Scale of decimal must be smaller than precision." + " Scale is " - + decimal64Scale + " and precision is " + decimal64Precision); - } - break; - } - case DECIMAL128: { - int decimal128Precision = scalarType.decimalPrecision(); - int decimal128Scale = scalarType.decimalScale(); - if (decimal128Precision < 1 - || decimal128Precision > ScalarType.MAX_DECIMAL128_PRECISION) { - throw new AnalysisException("Precision of decimal128 must between 1 and 38." - + " Precision was set to: " + decimal128Precision + "."); - } - // scale >= 0 - if (decimal128Scale < 0) { - throw new AnalysisException("Scale of decimal must not be less than 0." - + " Scale was set to: " + decimal128Scale + "."); - } - // scale < precision - if (decimal128Scale > decimal128Precision) { - throw new AnalysisException( - "Scale of decimal must be smaller than precision." + " Scale is " - + decimal128Scale + " and precision is " + decimal128Precision); - } - break; - } - case DECIMAL256: { - if (SessionVariable.getEnableDecimal256()) { - int precision = scalarType.decimalPrecision(); - int scale = scalarType.decimalScale(); - if (precision < 1 || precision > ScalarType.MAX_DECIMAL256_PRECISION) { - throw new AnalysisException("Precision of decimal256 must between 1 and 76." - + " Precision was set to: " + precision + "."); - } - // scale >= 0 - if (scale < 0) { - throw new AnalysisException("Scale of decimal must not be less than 0." - + " Scale was set to: " + scale + "."); - } - // scale < precision - if (scale > precision) { - throw new AnalysisException( - "Scale of decimal must be smaller than precision." + " Scale is " - + scale + " and precision is " + precision); - } - break; - } else { - int precision = scalarType.decimalPrecision(); - throw new AnalysisException("Column of type Decimal256 with precision " - + precision + " in not supported."); - } - } - case TIMEV2: - case DATETIMEV2: { - int precision = scalarType.decimalPrecision(); - int scale = scalarType.decimalScale(); - // precision: [1, 27] - if (precision != ScalarType.DATETIME_PRECISION) { - throw new AnalysisException( - "Precision of Datetime/Time must be " + ScalarType.DATETIME_PRECISION - + "." + " Precision was set to: " + precision + "."); - } - // scale: [0, 9] - if (scale < 0 || scale > 6) { - throw new AnalysisException("Scale of Datetime/Time must between 0 and 6." - + " Scale was set to: " + scale + "."); - } - break; - } - case INVALID_TYPE: - throw new AnalysisException("Invalid type."); - default: - break; - } - } } From d2a0fdfc1877415c658294e71ffef84cbe5805f9 Mon Sep 17 00:00:00 2001 From: lichi Date: Thu, 9 Jan 2025 18:06:05 +0800 Subject: [PATCH 4/4] update code --- .../org/apache/doris/nereids/DorisParser.g4 | 9 +-- .../nereids/parser/LogicalPlanBuilder.java | 48 +++++++------- .../plans/commands/CreateFunctionCommand.java | 62 +++++++++---------- .../plans/commands/DropFunctionCommand.java | 7 +-- .../commands/info/FunctionArgTypesInfo.java | 2 +- 5 files changed, 64 insertions(+), 64 deletions(-) diff --git a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 index 4a4855ecd7f311..455b75325effc9 100644 --- a/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 +++ b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisParser.g4 @@ -818,12 +818,13 @@ passwordOption ; functionArguments - : functionArgument (COMMA functionArgument)* + : DOTDOTDOT + | dataTypeList + | dataTypeList COMMA DOTDOTDOT ; -functionArgument - : DOTDOTDOT - | dataType +dataTypeList + : dataType (COMMA dataType)* ; supportedSetStatement diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java index d99b37c99f218d..56b259b367a5a5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java @@ -126,6 +126,7 @@ import org.apache.doris.nereids.DorisParser.CreateViewContext; import org.apache.doris.nereids.DorisParser.CreateWorkloadGroupContext; import org.apache.doris.nereids.DorisParser.CteContext; +import org.apache.doris.nereids.DorisParser.DataTypeListContext; import org.apache.doris.nereids.DorisParser.DataTypeWithNullableContext; import org.apache.doris.nereids.DorisParser.DecimalLiteralContext; import org.apache.doris.nereids.DorisParser.DeleteContext; @@ -158,8 +159,8 @@ import org.apache.doris.nereids.DorisParser.ExportContext; import org.apache.doris.nereids.DorisParser.FixedPartitionDefContext; import org.apache.doris.nereids.DorisParser.FromClauseContext; -import org.apache.doris.nereids.DorisParser.FunctionArgumentContext; import org.apache.doris.nereids.DorisParser.FunctionArgumentsContext; +import org.apache.doris.nereids.DorisParser.FunctionIdentifierContext; import org.apache.doris.nereids.DorisParser.GroupingElementContext; import org.apache.doris.nereids.DorisParser.GroupingSetContext; import org.apache.doris.nereids.DorisParser.HavingClauseContext; @@ -4140,9 +4141,7 @@ public Command visitCreateUserDefineFunction(CreateUserDefineFunctionContext ctx boolean ifNotExists = ctx.EXISTS() != null; boolean isAggFunction = ctx.AGGREGATE() != null; boolean isTableFunction = ctx.TABLES() != null; - String functionName = ctx.functionIdentifier().functionNameIdentifier().getText(); - String dbName = ctx.functionIdentifier().dbName != null ? ctx.functionIdentifier().dbName.getText() : null; - FunctionName function = new FunctionName(dbName, functionName); + FunctionName function = visitFunctionIdentifier(ctx.functionIdentifier()); FunctionArgTypesInfo functionArgTypesInfo; if (ctx.functionArguments() != null) { functionArgTypesInfo = visitFunctionArguments(ctx.functionArguments()); @@ -4167,9 +4166,7 @@ public Command visitCreateUserDefineFunction(CreateUserDefineFunctionContext ctx public Command visitCreateAliasFunction(CreateAliasFunctionContext ctx) { SetType statementScope = visitStatementScope(ctx.statementScope()); boolean ifNotExists = ctx.EXISTS() != null; - String functionName = ctx.functionIdentifier().functionNameIdentifier().getText(); - String dbName = ctx.functionIdentifier().dbName != null ? ctx.functionIdentifier().dbName.getText() : null; - FunctionName function = new FunctionName(dbName, functionName); + FunctionName function = visitFunctionIdentifier(ctx.functionIdentifier()); FunctionArgTypesInfo functionArgTypesInfo; if (ctx.functionArguments() != null) { functionArgTypesInfo = visitFunctionArguments(ctx.functionArguments()); @@ -4187,9 +4184,7 @@ public Command visitCreateAliasFunction(CreateAliasFunctionContext ctx) { public Command visitDropFunction(DropFunctionContext ctx) { SetType statementScope = visitStatementScope(ctx.statementScope()); boolean ifExists = ctx.EXISTS() != null; - String functionName = ctx.functionIdentifier().functionNameIdentifier().getText(); - String dbName = ctx.functionIdentifier().dbName != null ? ctx.functionIdentifier().dbName.getText() : null; - FunctionName function = new FunctionName(dbName, functionName); + FunctionName function = visitFunctionIdentifier(ctx.functionIdentifier()); FunctionArgTypesInfo functionArgTypesInfo; if (ctx.functionArguments() != null) { functionArgTypesInfo = visitFunctionArguments(ctx.functionArguments()); @@ -4201,24 +4196,31 @@ public Command visitDropFunction(DropFunctionContext ctx) { @Override public FunctionArgTypesInfo visitFunctionArguments(FunctionArgumentsContext ctx) { - boolean isVariadic = false; - List argTypeDefs = new ArrayList<>(4); - for (Object child : ctx.children) { - if (child instanceof FunctionArgumentContext) { - DataType dataType = visitFunctionArgument((FunctionArgumentContext) child); - if (dataType != null) { - argTypeDefs.add(dataType.conversion()); - } else { - isVariadic = true; - } - } + boolean isVariadic = ctx.DOTDOTDOT() != null; + List argTypeDefs; + if (ctx.dataTypeList() != null) { + argTypeDefs = visitDataTypeList(ctx.dataTypeList()); + } else { + argTypeDefs = new ArrayList<>(); } return new FunctionArgTypesInfo(argTypeDefs, isVariadic); } @Override - public DataType visitFunctionArgument(FunctionArgumentContext ctx) { - return ctx.dataType() != null ? typedVisit(ctx.dataType()) : null; + public FunctionName visitFunctionIdentifier(FunctionIdentifierContext ctx) { + String functionName = ctx.functionNameIdentifier().getText(); + String dbName = ctx.dbName != null ? ctx.dbName.getText() : null; + return new FunctionName(dbName, functionName); + } + + @Override + public List visitDataTypeList(DataTypeListContext ctx) { + List dataTypeList = new ArrayList<>(ctx.getChildCount()); + for (DorisParser.DataTypeContext dataTypeContext : ctx.dataType()) { + DataType dataType = typedVisit(dataTypeContext); + dataTypeList.add(dataType.conversion()); + } + return dataTypeList; } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java index 2ba6301052965c..8be22c7330dbbb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java @@ -110,42 +110,42 @@ */ public class CreateFunctionCommand extends Command implements ForwardWithSync { @Deprecated - public static final String OBJECT_FILE_KEY = "object_file"; - public static final String FILE_KEY = "file"; - public static final String SYMBOL_KEY = "symbol"; - public static final String PREPARE_SYMBOL_KEY = "prepare_fn"; - public static final String CLOSE_SYMBOL_KEY = "close_fn"; - public static final String MD5_CHECKSUM = "md5"; - public static final String INIT_KEY = "init_fn"; - public static final String UPDATE_KEY = "update_fn"; - public static final String MERGE_KEY = "merge_fn"; - public static final String SERIALIZE_KEY = "serialize_fn"; - public static final String FINALIZE_KEY = "finalize_fn"; - public static final String GET_VALUE_KEY = "get_value_fn"; - public static final String REMOVE_KEY = "remove_fn"; - public static final String BINARY_TYPE = "type"; - public static final String EVAL_METHOD_KEY = "evaluate"; - public static final String CREATE_METHOD_NAME = "create"; - public static final String DESTROY_METHOD_NAME = "destroy"; - public static final String ADD_METHOD_NAME = "add"; - public static final String SERIALIZE_METHOD_NAME = "serialize"; - public static final String MERGE_METHOD_NAME = "merge"; - public static final String GETVALUE_METHOD_NAME = "getValue"; - public static final String STATE_CLASS_NAME = "State"; + private static final String OBJECT_FILE_KEY = "object_file"; + private static final String FILE_KEY = "file"; + private static final String SYMBOL_KEY = "symbol"; + private static final String PREPARE_SYMBOL_KEY = "prepare_fn"; + private static final String CLOSE_SYMBOL_KEY = "close_fn"; + private static final String MD5_CHECKSUM = "md5"; + private static final String INIT_KEY = "init_fn"; + private static final String UPDATE_KEY = "update_fn"; + private static final String MERGE_KEY = "merge_fn"; + private static final String SERIALIZE_KEY = "serialize_fn"; + private static final String FINALIZE_KEY = "finalize_fn"; + private static final String GET_VALUE_KEY = "get_value_fn"; + private static final String REMOVE_KEY = "remove_fn"; + private static final String BINARY_TYPE = "type"; + private static final String EVAL_METHOD_KEY = "evaluate"; + private static final String CREATE_METHOD_NAME = "create"; + private static final String DESTROY_METHOD_NAME = "destroy"; + private static final String ADD_METHOD_NAME = "add"; + private static final String SERIALIZE_METHOD_NAME = "serialize"; + private static final String MERGE_METHOD_NAME = "merge"; + private static final String GETVALUE_METHOD_NAME = "getValue"; + private static final String STATE_CLASS_NAME = "State"; // add for java udf check return type nullable mode, always_nullable or always_not_nullable - public static final String IS_RETURN_NULL = "always_nullable"; + private static final String IS_RETURN_NULL = "always_nullable"; // iff is static load, BE will be cache the udf class load, so only need load once - public static final String IS_STATIC_LOAD = "static_load"; - public static final String EXPIRATION_TIME = "expiration_time"; + private static final String IS_STATIC_LOAD = "static_load"; + private static final String EXPIRATION_TIME = "expiration_time"; // timeout for both connection and read. 10 seconds is long enough. private static final int HTTP_TIMEOUT_MS = 10000; - private SetType setType = SetType.DEFAULT; + private final SetType setType; private final boolean ifNotExists; private final FunctionName functionName; private final boolean isAggregate; private final boolean isAlias; - private boolean isTableFunction; + private final boolean isTableFunction; private final FunctionArgTypesInfo argsDef; private final DataType returnType; private DataType intermediateType; @@ -238,7 +238,7 @@ private void analyze(ConnectContext ctx) throws Exception { analyzeCommon(ctx); // check if (isAggregate) { - analyzeUda(); + analyzeUdaf(); } else if (isAlias) { analyzeAliasFunction(ctx); } else if (isTableFunction) { @@ -251,7 +251,7 @@ private void analyze(ConnectContext ctx) throws Exception { analyzeCommon(ctx); // check if (isAggregate) { - analyzeUda(); + analyzeUdaf(); } else if (isAlias) { analyzeAliasFunction(ctx); } else if (isTableFunction) { @@ -276,7 +276,7 @@ private void analyzeCommon(ConnectContext ctx) throws AnalysisException { ErrorReport.reportAnalysisException(ErrorCode.ERR_SPECIFIC_ACCESS_DENIED_ERROR, "ADMIN"); } // check argument - argsDef.validate(); + argsDef.analyze(); // alias function does not need analyze following params if (isAlias) { @@ -399,7 +399,7 @@ private void analyzeUdtf() throws AnalysisException { // normal and one is outer as those have different result when result is NULL. } - private void analyzeUda() throws AnalysisException { + private void analyzeUdaf() throws AnalysisException { AggregateFunction.AggregateFunctionBuilder builder = AggregateFunction.AggregateFunctionBuilder .createUdfBuilder(); URI location; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java index 3cf19921b467e6..8135e85a74cd86 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DropFunctionCommand.java @@ -50,8 +50,6 @@ public class DropFunctionCommand extends Command implements ForwardWithSync { private final boolean ifExists; private final FunctionName functionName; private final FunctionArgTypesInfo argsDef; - // set after analyzed - private FunctionSearchDesc function; /** * DropFunctionCommand @@ -71,8 +69,8 @@ public void run(ConnectContext ctx, StmtExecutor executor) throws Exception { if (!Env.getCurrentEnv().getAccessManager().checkGlobalPriv(ConnectContext.get(), PrivPredicate.ADMIN)) { ErrorReport.reportAnalysisException(ErrorCode.ERR_SPECIFIC_ACCESS_DENIED_ERROR, "ADMIN"); } - argsDef.validate(); - function = new FunctionSearchDesc(functionName, argsDef.getArgTypes(), argsDef.isVariadic()); + argsDef.analyze(); + FunctionSearchDesc function = new FunctionSearchDesc(functionName, argsDef.getArgTypes(), argsDef.isVariadic()); if (SetType.GLOBAL.equals(setType)) { Env.getCurrentEnv().getGlobalFunctionMgr().dropFunction(function, ifExists); } else { @@ -94,7 +92,6 @@ public void run(ConnectContext ctx, StmtExecutor executor) throws Exception { LOG.info("clean udf cache in be {}, beId {}", backend.getHost(), backend.getId()); } AgentTaskExecutor.submit(batchTask); - } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgTypesInfo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgTypesInfo.java index 5b82f5015dff3b..f02c3d1b2ec314 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgTypesInfo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/FunctionArgTypesInfo.java @@ -55,7 +55,7 @@ public boolean isVariadic() { /** * validate */ - public void validate() { + public void analyze() { argTypes = new Type[argTypeDefs.size()]; int i = 0; for (DataType dataType : argTypeDefs) {