diff --git a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java index b095e81c71e4..3b7cd2b62cea 100644 --- a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java +++ b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java @@ -197,7 +197,16 @@ public enum TSStatusCode { CQ_ALREADY_EXIST(1402), CQ_UPDATE_LAST_EXEC_TIME_ERROR(1403), - // code 1500-1599 are used by IoTDB-ML + // AI + CREATE_MODEL_ERROR(1500), + DROP_MODEL_ERROR(1501), + MODEL_EXIST_ERROR(1502), + GET_MODEL_INFO_ERROR(1503), + NO_REGISTERED_AI_NODE_ERROR(1504), + MODEL_NOT_FOUND_ERROR(1505), + REGISTER_AI_NODE_ERROR(1506), + AI_NODE_INTERNAL_ERROR(1510), + REMOVE_AI_NODE_ERROR(1511), // Pipe Plugin CREATE_PIPE_PLUGIN_ERROR(1600), diff --git a/iotdb-core/ainode/.gitignore b/iotdb-core/ainode/.gitignore new file mode 100644 index 000000000000..b7ad350dc979 --- /dev/null +++ b/iotdb-core/ainode/.gitignore @@ -0,0 +1,11 @@ +# generated by Thrift +/iotdb/thrift/ + +# generated by maven +/iotdb/conf/ + +# .whl of ainode, generated by Poetry +/dist/ + +# the config to build ainode, it will be generated automatically +pyproject.toml diff --git a/iotdb-core/ainode/README.md b/iotdb-core/ainode/README.md new file mode 100644 index 000000000000..150ad93e499b --- /dev/null +++ b/iotdb-core/ainode/README.md @@ -0,0 +1,22 @@ + + +# Apache IoTDB AINode \ No newline at end of file diff --git a/iotdb-core/ainode/README_ZH.md b/iotdb-core/ainode/README_ZH.md new file mode 100644 index 000000000000..150ad93e499b --- /dev/null +++ b/iotdb-core/ainode/README_ZH.md @@ -0,0 +1,22 @@ + + +# Apache IoTDB AINode \ No newline at end of file diff --git a/iotdb-core/ainode/ainode.xml b/iotdb-core/ainode/ainode.xml new file mode 100644 index 000000000000..480c3e7221e6 --- /dev/null +++ b/iotdb-core/ainode/ainode.xml @@ -0,0 +1,62 @@ + + + + ainode-assembly + + dir + zip + + + + README.md + + + README_ZH.md + + + ${maven.multiModuleProjectDirectory}/LICENSE-binary + LICENSE + + + ${maven.multiModuleProjectDirectory}/NOTICE-binary + NOTICE + + + + + resources/conf + conf + + + resources/sbin + sbin + 0755 + + + dist + lib + + *.whl + + + + diff --git a/iotdb-core/ainode/iotdb/__init__.py b/iotdb-core/ainode/iotdb/__init__.py new file mode 100644 index 000000000000..2a1e720805f2 --- /dev/null +++ b/iotdb-core/ainode/iotdb/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
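The code block 1500-1599, previously reserved for IoTDB-ML, is now populated with AINode model-management status codes. Below is a minimal sketch (not part of the patch) of how a caller might map these numeric codes back to their names; the values are taken from the `TSStatusCode` additions above, while the lookup helper itself is hypothetical.

```python
# Sketch only: numeric values mirror the TSStatusCode additions above;
# the lookup table and helper are illustrative, not part of the patch.
AI_STATUS_CODES = {
    1500: "CREATE_MODEL_ERROR",
    1501: "DROP_MODEL_ERROR",
    1502: "MODEL_EXIST_ERROR",
    1503: "GET_MODEL_INFO_ERROR",
    1504: "NO_REGISTERED_AI_NODE_ERROR",
    1505: "MODEL_NOT_FOUND_ERROR",
    1506: "REGISTER_AI_NODE_ERROR",
    1510: "AI_NODE_INTERNAL_ERROR",
    1511: "REMOVE_AI_NODE_ERROR",
}

def describe_ai_status(code: int) -> str:
    # Codes outside the AI range fall back to a generic label.
    return AI_STATUS_CODES.get(code, "NON_AI_STATUS")
```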
+# diff --git a/iotdb-core/ainode/iotdb/ainode/__init__.py b/iotdb-core/ainode/iotdb/ainode/__init__.py new file mode 100644 index 000000000000..2a1e720805f2 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/attribute.py b/iotdb-core/ainode/iotdb/ainode/attribute.py new file mode 100644 index 000000000000..a91ae436cd10 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/attribute.py @@ -0,0 +1,669 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +from abc import abstractmethod +from typing import List, Dict + +from iotdb.ainode.constant import AttributeName, BuiltInModelType +from iotdb.ainode.exception import WrongAttributeTypeError, NumericalRangeException, StringRangeException, \ + ListRangeException, BuiltInModelNotSupportError + + +class Attribute(object): + def __init__(self, name: str): + """ + Args: + name: the name of the attribute + """ + self.name = name + + @abstractmethod + def get_default_value(self): + raise NotImplementedError + + @abstractmethod + def validate_value(self, value): + raise NotImplementedError + + @abstractmethod + def parse(self, string_value: str): + raise NotImplementedError + + +class IntAttribute(Attribute): + def __init__(self, name: str, + default_value: int, + default_low: int, + default_high: int, + ): + super(IntAttribute, self).__init__(name) + self.__default_value = default_value + self.__default_low = default_low + self.__default_high = default_high + + def get_default_value(self): + return self.__default_value + + def validate_value(self, value): + if self.__default_low <= value <= self.__default_high: + return True + raise NumericalRangeException(self._name, value, self.__default_low, self.__default_high) + + def parse(self, string_value: str): + try: + int_value = int(string_value) + except: + raise WrongAttributeTypeError(self.name, "int") + return int_value + + +class FloatAttribute(Attribute): + def __init__(self, name: str, + default_value: float, + default_low: float, + default_high: float, + ): + super(FloatAttribute, self).__init__(name) + self.__default_value = default_value + self.__default_low = default_low + self.__default_high = default_high + + def get_default_value(self): + return self.__default_value + + def validate_value(self, value): + if self.__default_low <= value <= self.__default_high: + return True + raise NumericalRangeException(self._name, value, self.__default_low, self.__default_high) + + def parse(self, string_value: str): + try: + float_value = float(string_value) + except: + raise WrongAttributeTypeError(self.name, "float") + return float_value + + +class StringAttribute(Attribute): + def __init__(self, name: str, default_value: str, value_choices: List[str]): + super(StringAttribute, self).__init__(name) + self.__default_value = default_value + self.__value_choices = value_choices + + def get_default_value(self): + return self.__default_value + + def validate_value(self, value): + if value in self.__value_choices: + return True + raise StringRangeException(self._name, value, self.__value_choices) + + def parse(self, string_value: str): + return string_value + + +class BooleanAttribute(Attribute): + def __init__(self, name: str, default_value: bool): + super(BooleanAttribute, self).__init__(name) + self.__default_value = default_value + + def get_default_value(self): + return self.__default_value + + def validate_value(self, value): + if isinstance(value, bool): + return True + raise WrongAttributeTypeError(self._name, "bool") + + def parse(self, string_value: str): + if string_value.lower() == "true": + return True + elif string_value.lower() == "false": + return False + else: + raise WrongAttributeTypeError(self.name, "bool") + + +class ListAttribute(Attribute): + def __init__(self, name: str, default_value: List, value_type): + """ + value_type is the type of the elements in the list, e.g. 
int, float, str + """ + super(ListAttribute, self).__init__(name) + self.__default_value = default_value + self.__value_type = value_type + self.__type_to_str = {str: "str", int: "int", float: "float"} + + def get_default_value(self): + return self.__default_value + + def validate_value(self, value): + if not isinstance(value, list): + raise WrongAttributeTypeError(self._name, "list") + for value_item in value: + if not isinstance(value_item, self.__value_type): + raise WrongAttributeTypeError(self._name, self.__value_type) + return True + + def parse(self, string_value: str): + try: + list_value = eval(string_value) + except: + raise WrongAttributeTypeError(self.name, "list") + if not isinstance(list_value, list): + raise WrongAttributeTypeError(self.name, "list") + for i in range(len(list_value)): + try: + list_value[i] = self.__value_type(list_value[i]) + except: + raise ListRangeException(self.name, list_value, self.__type_to_str[self.__value_type]) + return list_value + + +class TupleAttribute(Attribute): + def __init__(self, name: str, default_value: tuple, value_type): + """ + value_type is the type of the elements in the list, e.g. int, float, str + """ + super(TupleAttribute, self).__init__(name) + self.__default_value = default_value + self.__value_type = value_type + self.__type_to_str = {str: "str", int: "int", float: "float"} + + def get_default_value(self): + return self.__default_value + + def validate_value(self, value): + if not isinstance(value, tuple): + raise WrongAttributeTypeError(self._name, "tuple") + for value_item in value: + if not isinstance(value_item, self.__value_type): + raise WrongAttributeTypeError(self._name, self.__value_type) + return True + + def parse(self, string_value: str): + try: + tuple_value = eval(string_value) + except: + raise WrongAttributeTypeError(self.name, "tuple") + if not isinstance(tuple_value, tuple): + raise WrongAttributeTypeError(self.name, "tuple") + list_value = list(tuple_value) + for i in range(len(list_value)): + try: + list_value[i] = self.__value_type(list_value[i]) + except: + raise ListRangeException(self.name, list_value, self.__type_to_str[self.__value_type]) + tuple_value = tuple(list_value) + return tuple_value + + +def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute]): + """ + Args: + input_attributes: a dict of attributes, where the key is the attribute name, the value is the string value of + the attribute + attribute_map: a dict of hyperparameters, where the key is the attribute name, the value is the Attribute + object + Returns: + a dict of attributes, where the key is the attribute name, the value is the parsed value of the attribute + """ + attributes = {} + for attribute_name in attribute_map: + # user specified the attribute + if attribute_name in input_attributes: + attribute = attribute_map[attribute_name] + value = attribute.parse(input_attributes[attribute_name]) + attribute.validate_value(value) + attributes[attribute_name] = value + # user did not specify the attribute, use the default value + else: + try: + attributes[attribute_name] = attribute_map[attribute_name].get_default_value() + except: + print("attribute_name: ", attribute_name) + return attributes + + +# built-in sktime model attributes +# NaiveForecaster +naive_forecaster_attribute_map = { + AttributeName.PREDICT_LENGTH.value: IntAttribute( + name=AttributeName.PREDICT_LENGTH.value, + default_value=1, + default_low=1, + default_high=5000 + ), + AttributeName.STRATEGY.value: StringAttribute( + 
name=AttributeName.STRATEGY.value, + default_value="last", + value_choices=["last", "mean"], + ), + AttributeName.SP.value: IntAttribute( + name=AttributeName.SP.value, + default_value=1, + default_low=1, + default_high=5000 + ), +} +# ExponentialSmoothing +exponential_smoothing_attribute_map = { + AttributeName.PREDICT_LENGTH.value: IntAttribute( + name=AttributeName.PREDICT_LENGTH.value, + default_value=1, + default_low=1, + default_high=5000 + ), + AttributeName.DAMPED_TREND.value: BooleanAttribute( + name=AttributeName.DAMPED_TREND.value, + default_value=False, + ), + AttributeName.INITIALIZATION_METHOD.value: StringAttribute( + name=AttributeName.INITIALIZATION_METHOD.value, + default_value="estimated", + value_choices=["estimated", "heuristic", "legacy-heuristic", "known"], + ), + AttributeName.OPTIMIZED.value: BooleanAttribute( + name=AttributeName.OPTIMIZED.value, + default_value=True, + ), + AttributeName.REMOVE_BIAS.value: BooleanAttribute( + name=AttributeName.REMOVE_BIAS.value, + default_value=False, + ), + AttributeName.USE_BRUTE.value: BooleanAttribute( + name=AttributeName.USE_BRUTE.value, + default_value=False, + ) +} +# Arima +arima_attribute_map = { + AttributeName.PREDICT_LENGTH.value: IntAttribute( + name=AttributeName.PREDICT_LENGTH.value, + default_value=1, + default_low=1, + default_high=5000 + ), + AttributeName.ORDER.value: TupleAttribute( + name=AttributeName.ORDER.value, + default_value=(1, 0, 0), + value_type=int + ), + AttributeName.SEASONAL_ORDER.value: TupleAttribute( + name=AttributeName.SEASONAL_ORDER.value, + default_value=(0, 0, 0, 0), + value_type=int + ), + AttributeName.METHOD.value: StringAttribute( + name=AttributeName.METHOD.value, + default_value="lbfgs", + value_choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"], + ), + AttributeName.MAXITER.value: IntAttribute( + name=AttributeName.MAXITER.value, + default_value=1, + default_low=1, + default_high=5000 + ), + AttributeName.SUPPRESS_WARNINGS.value: BooleanAttribute( + name=AttributeName.SUPPRESS_WARNINGS.value, + default_value=True, + ), + AttributeName.OUT_OF_SAMPLE_SIZE.value: IntAttribute( + name=AttributeName.OUT_OF_SAMPLE_SIZE.value, + default_value=0, + default_low=0, + default_high=5000 + ), + AttributeName.SCORING.value: StringAttribute( + name=AttributeName.SCORING.value, + default_value="mse", + value_choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"], + ), + AttributeName.WITH_INTERCEPT.value: BooleanAttribute( + name=AttributeName.WITH_INTERCEPT.value, + default_value=True, + ), + AttributeName.TIME_VARYING_REGRESSION.value: BooleanAttribute( + name=AttributeName.TIME_VARYING_REGRESSION.value, + default_value=False, + ), + AttributeName.ENFORCE_STATIONARITY.value: BooleanAttribute( + name=AttributeName.ENFORCE_STATIONARITY.value, + default_value=True, + ), + AttributeName.ENFORCE_INVERTIBILITY.value: BooleanAttribute( + name=AttributeName.ENFORCE_INVERTIBILITY.value, + default_value=True, + ), + AttributeName.SIMPLE_DIFFERENCING.value: BooleanAttribute( + name=AttributeName.SIMPLE_DIFFERENCING.value, + default_value=False, + ), + AttributeName.MEASUREMENT_ERROR.value: BooleanAttribute( + name=AttributeName.MEASUREMENT_ERROR.value, + default_value=False, + ), + AttributeName.MLE_REGRESSION.value: BooleanAttribute( + name=AttributeName.MLE_REGRESSION.value, + default_value=True, + ), + AttributeName.HAMILTON_REPRESENTATION.value: BooleanAttribute( + name=AttributeName.HAMILTON_REPRESENTATION.value, + default_value=False, + ), + 
AttributeName.CONCENTRATE_SCALE.value: BooleanAttribute( + name=AttributeName.CONCENTRATE_SCALE.value, + default_value=False, + ) +} +# STLForecaster +stl_forecaster_attribute_map = { + AttributeName.PREDICT_LENGTH.value: IntAttribute( + name=AttributeName.PREDICT_LENGTH.value, + default_value=1, + default_low=1, + default_high=5000 + ), + AttributeName.SP.value: IntAttribute( + name=AttributeName.SP.value, + default_value=2, + default_low=1, + default_high=5000 + ), + AttributeName.SEASONAL.value: IntAttribute( + name=AttributeName.SEASONAL.value, + default_value=7, + default_low=1, + default_high=5000 + ), + AttributeName.SEASONAL_DEG.value: IntAttribute( + name=AttributeName.SEASONAL_DEG.value, + default_value=1, + default_low=0, + default_high=5000 + ), + AttributeName.TREND_DEG.value: IntAttribute( + name=AttributeName.TREND_DEG.value, + default_value=1, + default_low=0, + default_high=5000 + ), + AttributeName.LOW_PASS_DEG.value: IntAttribute( + name=AttributeName.LOW_PASS_DEG.value, + default_value=1, + default_low=0, + default_high=5000 + ), + AttributeName.SEASONAL_JUMP.value: IntAttribute( + name=AttributeName.SEASONAL_JUMP.value, + default_value=1, + default_low=0, + default_high=5000 + ), + AttributeName.TREND_JUMP.value: IntAttribute( + name=AttributeName.TREND_JUMP.value, + default_value=1, + default_low=0, + default_high=5000 + ), + AttributeName.LOSS_PASS_JUMP.value: IntAttribute( + name=AttributeName.LOSS_PASS_JUMP.value, + default_value=1, + default_low=0, + default_high=5000 + ), +} + +# GAUSSIAN_HMM +gaussian_hmm_attribute_map = { + AttributeName.N_COMPONENTS.value: IntAttribute( + name=AttributeName.N_COMPONENTS.value, + default_value=1, + default_low=1, + default_high=5000 + ), + AttributeName.COVARIANCE_TYPE.value: StringAttribute( + name=AttributeName.COVARIANCE_TYPE.value, + default_value="diag", + value_choices=["spherical", "diag", "full", "tied"], + ), + AttributeName.MIN_COVAR.value: FloatAttribute( + name=AttributeName.MIN_COVAR.value, + default_value=1e-3, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.STARTPROB_PRIOR.value: FloatAttribute( + name=AttributeName.STARTPROB_PRIOR.value, + default_value=1.0, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.TRANSMAT_PRIOR.value: FloatAttribute( + name=AttributeName.TRANSMAT_PRIOR.value, + default_value=1.0, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.MEANS_PRIOR.value: FloatAttribute( + name=AttributeName.MEANS_PRIOR.value, + default_value=0.0, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.MEANS_WEIGHT.value: FloatAttribute( + name=AttributeName.MEANS_WEIGHT.value, + default_value=0.0, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.COVARS_PRIOR.value: FloatAttribute( + name=AttributeName.COVARS_PRIOR.value, + default_value=1e-2, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.COVARS_WEIGHT.value: FloatAttribute( + name=AttributeName.COVARS_WEIGHT.value, + default_value=1.0, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.ALGORITHM.value: StringAttribute( + name=AttributeName.ALGORITHM.value, + default_value="viterbi", + value_choices=["viterbi", "map"], + ), + AttributeName.N_ITER.value: IntAttribute( + name=AttributeName.N_ITER.value, + default_value=10, + default_low=1, + default_high=5000 + ), + AttributeName.TOL.value: FloatAttribute( + name=AttributeName.TOL.value, + default_value=1e-2, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.PARAMS.value: StringAttribute( + 
name=AttributeName.PARAMS.value, + default_value="stmc", + value_choices=["stmc", "stm"], + ), + AttributeName.INIT_PARAMS.value: StringAttribute( + name=AttributeName.INIT_PARAMS.value, + default_value="stmc", + value_choices=["stmc", "stm"], + ), + AttributeName.IMPLEMENTATION.value: StringAttribute( + name=AttributeName.IMPLEMENTATION.value, + default_value="log", + value_choices=["log", "scaling"], + ) +} +# GMMHMM +gmmhmm_attribute_map = { + AttributeName.N_COMPONENTS.value: IntAttribute( + name=AttributeName.N_COMPONENTS.value, + default_value=1, + default_low=1, + default_high=5000 + ), + AttributeName.N_MIX.value: IntAttribute( + name=AttributeName.N_MIX.value, + default_value=1, + default_low=1, + default_high=5000 + ), + AttributeName.MIN_COVAR.value: FloatAttribute( + name=AttributeName.MIN_COVAR.value, + default_value=1e-3, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.STARTPROB_PRIOR.value: FloatAttribute( + name=AttributeName.STARTPROB_PRIOR.value, + default_value=1.0, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.TRANSMAT_PRIOR.value: FloatAttribute( + name=AttributeName.TRANSMAT_PRIOR.value, + default_value=1.0, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.WEIGHTS_PRIOR.value: FloatAttribute( + name=AttributeName.WEIGHTS_PRIOR.value, + default_value=1.0, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.MEANS_PRIOR.value: FloatAttribute( + name=AttributeName.MEANS_PRIOR.value, + default_value=0.0, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.MEANS_WEIGHT.value: FloatAttribute( + name=AttributeName.MEANS_WEIGHT.value, + default_value=0.0, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.ALGORITHM.value: StringAttribute( + name=AttributeName.ALGORITHM.value, + default_value="viterbi", + value_choices=["viterbi", "map"], + ), + AttributeName.COVARIANCE_TYPE.value: StringAttribute( + name=AttributeName.COVARIANCE_TYPE.value, + default_value="diag", + value_choices=["sperical", "diag", "full", "tied"], + ), + AttributeName.N_ITER.value: IntAttribute( + name=AttributeName.N_ITER.value, + default_value=10, + default_low=1, + default_high=5000 + ), + AttributeName.TOL.value: FloatAttribute( + name=AttributeName.TOL.value, + default_value=1e-2, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.INIT_PARAMS.value: StringAttribute( + name=AttributeName.INIT_PARAMS.value, + default_value="stmcw", + value_choices=["s", "t", "m", "c", "w", "st", "sm", "sc", "sw", "tm", "tc", "tw", "mc", "mw", "cw", "stm", + "stc", "stw", "smc", "smw", "scw", "tmc", "tmw", "tcw", "mcw", "stmc", "stmw", "stcw", "smcw", + "tmcw", "stmcw"] + ), + AttributeName.PARAMS.value: StringAttribute( + name=AttributeName.PARAMS.value, + default_value="stmcw", + value_choices=["s", "t", "m", "c", "w", "st", "sm", "sc", "sw", "tm", "tc", "tw", "mc", "mw", "cw", "stm", + "stc", "stw", "smc", "smw", "scw", "tmc", "tmw", "tcw", "mcw", "stmc", "stmw", "stcw", "smcw", + "tmcw", "stmcw"] + ), + AttributeName.IMPLEMENTATION.value: StringAttribute( + name=AttributeName.IMPLEMENTATION.value, + default_value="log", + value_choices=["log", "scaling"], + ) +} + +# STRAY +stray_attribute_map = { + AttributeName.ALPHA.value: FloatAttribute( + name=AttributeName.ALPHA.value, + default_value=0.01, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.K.value: IntAttribute( + name=AttributeName.K.value, + default_value=10, + default_low=1, + default_high=5000 + ), + AttributeName.KNN_ALGORITHM.value: StringAttribute( + 
name=AttributeName.KNN_ALGORITHM.value, + default_value="brute", + value_choices=["brute", "kd_tree", "ball_tree", "auto"], + ), + AttributeName.P.value: FloatAttribute( + name=AttributeName.P.value, + default_value=0.5, + default_low=-1e10, + default_high=1e10, + ), + AttributeName.SIZE_THRESHOLD.value: IntAttribute( + name=AttributeName.SIZE_THRESHOLD.value, + default_value=50, + default_low=1, + default_high=5000 + ), + AttributeName.OUTLIER_TAIL.value: StringAttribute( + name=AttributeName.OUTLIER_TAIL.value, + default_value="max", + value_choices=["min", "max"], + ), +} + + +def get_model_attributes(model_id: str): + if model_id == BuiltInModelType.ARIMA.value: + attribute_map = arima_attribute_map + elif model_id == BuiltInModelType.NAIVE_FORECASTER.value: + attribute_map = naive_forecaster_attribute_map + elif model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value: + attribute_map = exponential_smoothing_attribute_map + elif model_id == BuiltInModelType.STL_FORECASTER.value: + attribute_map = stl_forecaster_attribute_map + elif model_id == BuiltInModelType.GMM_HMM.value: + attribute_map = gmmhmm_attribute_map + elif model_id == BuiltInModelType.GAUSSIAN_HMM.value: + attribute_map = gaussian_hmm_attribute_map + elif model_id == BuiltInModelType.STRAY.value: + attribute_map = stray_attribute_map + else: + raise BuiltInModelNotSupportError(model_id) + return attribute_map diff --git a/iotdb-core/ainode/iotdb/ainode/client.py b/iotdb-core/ainode/iotdb/ainode/client.py new file mode 100644 index 000000000000..ecc5bb40f89e --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/client.py @@ -0,0 +1,285 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
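The attribute framework in `attribute.py` turns user-supplied string attributes into typed, range-checked values and falls back to per-model defaults. Below is a minimal usage sketch, assuming the `iotdb.ainode` package is importable; the input values are illustrative.

```python
# Usage sketch for the attribute-parsing pipeline defined above (illustrative inputs).
from iotdb.ainode.attribute import get_model_attributes, parse_attribute
from iotdb.ainode.constant import BuiltInModelType

# Resolve the attribute map of the built-in NaiveForecaster model.
attribute_map = get_model_attributes(BuiltInModelType.NAIVE_FORECASTER.value)

# Raw string attributes as they would arrive from the SQL layer.
user_input = {"predict_length": "24", "strategy": "mean"}

attributes = parse_attribute(user_input, attribute_map)
# {'predict_length': 24, 'strategy': 'mean', 'sp': 1} -- unspecified keys use defaults.
print(attributes)
```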
+# +import time + +import pandas as pd +from thrift.Thrift import TException +from thrift.protocol import TCompactProtocol, TBinaryProtocol +from thrift.transport import TSocket, TTransport + +from iotdb.ainode import serde +from iotdb.ainode.config import descriptor +from iotdb.ainode.constant import TSStatusCode +from iotdb.ainode.log import logger +from iotdb.ainode.util import verify_success +from iotdb.thrift.common.ttypes import TEndPoint, TSStatus, TAINodeLocation, TAINodeConfiguration +from iotdb.thrift.confignode import IConfigNodeRPCService +from iotdb.thrift.confignode.ttypes import (TAINodeRemoveReq, TNodeVersionInfo, + TAINodeRegisterReq, TAINodeRestartReq) +from iotdb.thrift.datanode import IAINodeInternalRPCService +from iotdb.thrift.datanode.ttypes import (TFetchMoreDataReq, + TFetchTimeseriesReq) + + +class ClientManager(object): + def __init__(self): + self.__data_node_endpoint = descriptor.get_config().get_mn_target_data_node() + self.__config_node_endpoint = descriptor.get_config().get_ain_target_config_node_list() + + def borrow_data_node_client(self): + return DataNodeClient(host=self.__data_node_endpoint.ip, + port=self.__data_node_endpoint.port) + + def borrow_config_node_client(self): + return ConfigNodeClient(config_leader=self.__config_node_endpoint) + + +class DataNodeClient(object): + DEFAULT_FETCH_SIZE = 10000 + DEFAULT_TIMEOUT = 60000 + + def __init__(self, host, port): + self.__host = host + self.__port = port + + transport = TTransport.TFramedTransport( + TSocket.TSocket(self.__host, self.__port) + ) + if not transport.isOpen(): + try: + transport.open() + except TTransport.TTransportException as e: + logger.error("TTransportException: {}".format(e)) + raise e + + if descriptor.get_config().get_ain_thrift_compression_enabled(): + protocol = TCompactProtocol.TCompactProtocol(transport) + else: + protocol = TBinaryProtocol.TBinaryProtocol(transport) + self.__client = IAINodeInternalRPCService.Client(protocol) + + def fetch_timeseries(self, + query_body: str, + fetch_size: int = DEFAULT_FETCH_SIZE, + timeout: int = DEFAULT_TIMEOUT) -> pd.DataFrame: + req = TFetchTimeseriesReq( + queryBody=query_body, + fetchSize=fetch_size, + timeout=timeout + ) + try: + resp = self.__client.fetchTimeseries(req) + verify_success(resp.status, "An error occurs when calling fetch_timeseries()") + + if len(resp.tsDataset) == 0: + raise RuntimeError(f'No data fetched with sql: {query_body}') + data = serde.convert_to_df(resp.columnNameList, + resp.columnTypeList, + resp.columnNameIndexMap, + resp.tsDataset) + if data.empty: + raise RuntimeError( + f'Fetched empty data with sql: {query_body}') + except Exception as e: + logger.warning( + f'Fail to fetch data with sql: {query_body}') + raise e + query_id = resp.queryId + column_name_list = resp.columnNameList + column_type_list = resp.columnTypeList + column_name_index_map = resp.columnNameIndexMap + has_more_data = resp.hasMoreData + while has_more_data: + req = TFetchMoreDataReq(queryId=query_id, fetchSize=fetch_size) + try: + resp = self.__client.fetchMoreData(req) + verify_success(resp.status, "An error occurs when calling fetch_more_data()") + data = data.append(serde.convert_to_df(column_name_list, + column_type_list, + column_name_index_map, + resp.tsDataset)) + has_more_data = resp.hasMoreData + except Exception as e: + logger.warning( + f'Fail to fetch more data with query id: {query_id}') + raise e + return data + + +class ConfigNodeClient(object): + def __init__(self, config_leader: TEndPoint): + self.__config_leader = 
config_leader + self.__config_nodes = [] + self.__cursor = 0 + self.__transport = None + self.__client = None + + self.__MSG_RECONNECTION_FAIL = "Fail to connect to any config node. Please check status of ConfigNodes" + self.__RETRY_NUM = 5 + self.__RETRY_INTERVAL_MS = 1 + + self.__try_to_connect() + + def __try_to_connect(self) -> None: + if self.__config_leader is not None: + try: + self.__connect(self.__config_leader) + return + except TException: + logger.warning("The current node {} may have been down, try next node", self.__config_leader) + self.__config_leader = None + + if self.__transport is not None: + self.__transport.close() + + try_host_num = 0 + while try_host_num < len(self.__config_nodes): + self.__cursor = (self.__cursor + 1) % len(self.__config_nodes) + + try_endpoint = self.__config_nodes[self.__cursor] + try: + self.__connect(try_endpoint) + return + except TException: + logger.warning("The current node {} may have been down, try next node", try_endpoint) + + try_host_num = try_host_num + 1 + + raise TException(self.__MSG_RECONNECTION_FAIL) + + def __connect(self, target_config_node: TEndPoint) -> None: + transport = TTransport.TFramedTransport( + TSocket.TSocket(target_config_node.ip, target_config_node.port) + ) + if not transport.isOpen(): + try: + transport.open() + except TTransport.TTransportException as e: + logger.error("TTransportException: {}".format(e)) + raise e + + if descriptor.get_config().get_ain_thrift_compression_enabled(): + protocol = TCompactProtocol.TCompactProtocol(transport) + else: + protocol = TBinaryProtocol.TBinaryProtocol(transport) + self.__client = IConfigNodeRPCService.Client(protocol) + + def __wait_and_reconnect(self) -> None: + # wait to start the next try + time.sleep(self.__RETRY_INTERVAL_MS) + + try: + self.__try_to_connect() + except TException: + # can not connect to each config node + self.__sync_latest_config_node_list() + self.__try_to_connect() + + def __sync_latest_config_node_list(self) -> None: + # TODO + pass + + def __update_config_node_leader(self, status: TSStatus) -> bool: + if status.code == TSStatusCode.REDIRECTION_RECOMMEND.get_status_code(): + if status.redirectNode is not None: + self.__config_leader = status.redirectNode + else: + self.__config_leader = None + return True + return False + + def node_register(self, cluster_name: str, configuration: TAINodeConfiguration, + version_info: TNodeVersionInfo) -> int: + req = TAINodeRegisterReq( + clusterName=cluster_name, + aiNodeConfiguration=configuration, + versionInfo=version_info + ) + + for _ in range(0, self.__RETRY_NUM): + try: + resp = self.__client.registerAINode(req) + if not self.__update_config_node_leader(resp.status): + verify_success(resp.status, "An error occurs when calling node_register()") + self.__config_nodes = resp.configNodeList + return resp.aiNodeId + except TTransport.TException: + logger.warning("Failed to connect to ConfigNode {} from AINode when executing node_register()", + self.__config_leader) + self.__config_leader = None + self.__wait_and_reconnect() + + raise TException(self.__MSG_RECONNECTION_FAIL) + + def node_restart(self, cluster_name: str, configuration: TAINodeConfiguration, + version_info: TNodeVersionInfo) -> None: + req = TAINodeRestartReq( + clusterName=cluster_name, + aiNodeConfiguration=configuration, + versionInfo=version_info + ) + + for _ in range(0, self.__RETRY_NUM): + try: + resp = self.__client.restartAINode(req) + if not self.__update_config_node_leader(resp.status): + verify_success(resp.status, "An error occurs 
when calling node_restart()") + self.__config_nodes = resp.configNodeList + return resp.status + except TTransport.TException: + logger.warning("Failed to connect to ConfigNode {} from AINode when executing node_restart()", + self.__config_leader) + self.__config_leader = None + self.__wait_and_reconnect() + + raise TException(self.__MSG_RECONNECTION_FAIL) + + def node_remove(self, location: TAINodeLocation): + req = TAINodeRemoveReq( + aiNodeLocation=location + ) + for _ in range(0, self.__RETRY_NUM): + try: + status = self.__client.removeAINode(req) + if not self.__update_config_node_leader(status): + verify_success(status, "An error occurs when calling node_restart()") + return status + except TTransport.TException: + logger.warning("Failed to connect to ConfigNode {} from AINode when executing node_remove()", + self.__config_leader) + self.__config_leader = None + self.__wait_and_reconnect() + raise TException(self.__MSG_RECONNECTION_FAIL) + + def get_ainode_configuration(self, node_id: int) -> map: + for _ in range(0, self.__RETRY_NUM): + try: + resp = self.__client.getAINodeConfiguration(node_id) + if not self.__update_config_node_leader(resp.status): + verify_success(resp.status, "An error occurs when calling get_ainode_configuration()") + return resp.aiNodeConfigurationMap + except TTransport.TException: + logger.warning("Failed to connect to ConfigNode {} from AINode when executing " + "get_ainode_configuration()", + self.__config_leader) + self.__config_leader = None + self.__wait_and_reconnect() + raise TException(self.__MSG_RECONNECTION_FAIL) + + +client_manager = ClientManager() diff --git a/iotdb-core/ainode/iotdb/ainode/config.py b/iotdb-core/ainode/iotdb/ainode/config.py new file mode 100644 index 000000000000..8caa5cc1f081 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/config.py @@ -0,0 +1,276 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
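`client.py` exposes a module-level `client_manager` whose borrowed clients wrap the Thrift RPC services. Below is a minimal sketch of fetching query results into a DataFrame, assuming a DataNode is reachable at the endpoint configured in `iotdb-ainode.properties`; the SQL text and series path are illustrative.

```python
# Sketch: pull timeseries data into a pandas DataFrame through the DataNode client above.
from iotdb.ainode.client import client_manager

data_node_client = client_manager.borrow_data_node_client()
df = data_node_client.fetch_timeseries(
    query_body="select s0 from root.sg.d1",  # illustrative query
    fetch_size=10000,
)
print(df.head())
```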
+# +import os + +from iotdb.ainode.constant import (AINODE_CONF_DIRECTORY_NAME, + AINODE_CONF_FILE_NAME, + AINODE_MODELS_DIR, AINODE_LOG_DIR, AINODE_SYSTEM_DIR, AINODE_INFERENCE_RPC_ADDRESS, + AINODE_INFERENCE_RPC_PORT, AINODE_THRIFT_COMPRESSION_ENABLED, + AINODE_SYSTEM_FILE_NAME, AINODE_CLUSTER_NAME, AINODE_VERSION_INFO, AINODE_BUILD_INFO, + AINODE_CONF_GIT_FILE_NAME, AINODE_CONF_POM_FILE_NAME, AINODE_ROOT_DIR, + AINODE_ROOT_CONF_DIRECTORY_NAME) +from iotdb.ainode.exception import BadNodeUrlError +from iotdb.ainode.log import logger, set_logger +from iotdb.ainode.util import parse_endpoint_url +from iotdb.thrift.common.ttypes import TEndPoint + + +class AINodeConfig(object): + def __init__(self): + # Used for connection of DataNode/ConfigNode clients + self.__ain_inference_rpc_address: str = AINODE_INFERENCE_RPC_ADDRESS + self.__ain_inference_rpc_port: int = AINODE_INFERENCE_RPC_PORT + + # log directory + self.__ain_logs_dir: str = AINODE_LOG_DIR + + # Directory to save models + self.__ain_models_dir = AINODE_MODELS_DIR + + self.__ain_system_dir = AINODE_SYSTEM_DIR + + # Whether to enable compression for thrift + self.__ain_thrift_compression_enabled = AINODE_THRIFT_COMPRESSION_ENABLED + + # Cache number of model storage to avoid repeated loading + self.__mn_model_storage_cache_size = 30 + + # Maximum number of training model tasks, otherwise the task is pending + self.__mn_task_pool_size = 1 + + # Maximum number of trials to be explored in a tuning task + self.__mn_tuning_trial_num = 20 + + # Concurrency of trials in a tuning task + self.__mn_tuning_trial_concurrency = 4 + + # Target ConfigNode to be connected by AINode + self.__ain_target_config_node_list: TEndPoint = TEndPoint("127.0.0.1", 10710) + + # Target DataNode to be connected by AINode + self.__mn_target_data_node: TEndPoint = TEndPoint("127.0.0.1", 10780) + + # use for node management + self.__ainode_id = 0 + self.__cluster_name = AINODE_CLUSTER_NAME + + self.__version_info = AINODE_VERSION_INFO + self.__build_info = AINODE_BUILD_INFO + + def get_cluster_name(self) -> str: + return self.__cluster_name + + def set_cluster_name(self, cluster_name: str) -> None: + self.__cluster_name = cluster_name + + def get_version_info(self) -> str: + return self.__version_info + + def get_ainode_id(self) -> int: + return self.__ainode_id + + def set_ainode_id(self, id: int) -> None: + self.__ainode_id = id + + def get_build_info(self) -> str: + return self.__build_info + + def set_build_info(self, build_info: str) -> None: + self.__build_info = build_info + + def set_version_info(self, version_info: str) -> None: + self.__version_info = version_info + + def get_ain_inference_rpc_address(self) -> str: + return self.__ain_inference_rpc_address + + def set_ain_inference_rpc_address(self, ain_inference_rpc_address: str) -> None: + self.__ain_inference_rpc_address = ain_inference_rpc_address + + def get_ain_inference_rpc_port(self) -> int: + return self.__ain_inference_rpc_port + + def set_ain_inference_rpc_port(self, ain_inference_rpc_port: int) -> None: + self.__ain_inference_rpc_port = ain_inference_rpc_port + + def get_ain_logs_dir(self) -> str: + return self.__ain_logs_dir + + def set_ain_logs_dir(self, ain_logs_dir: str) -> None: + self.__ain_logs_dir = ain_logs_dir + + def get_ain_models_dir(self) -> str: + return self.__ain_models_dir + + def set_ain_models_dir(self, ain_models_dir: str) -> None: + self.__ain_models_dir = ain_models_dir + + def get_ain_system_dir(self) -> str: + return self.__ain_system_dir + + def 
set_ain_system_dir(self, ain_system_dir: str) -> None: + self.__ain_system_dir = ain_system_dir + + def get_ain_thrift_compression_enabled(self) -> bool: + return self.__ain_thrift_compression_enabled + + def set_ain_thrift_compression_enabled(self, ain_thrift_compression_enabled: int) -> None: + self.__ain_thrift_compression_enabled = ain_thrift_compression_enabled + + def get_mn_model_storage_cache_size(self) -> int: + return self.__mn_model_storage_cache_size + + def set_mn_model_storage_cache_size(self, mn_model_storage_cache_size: int) -> None: + self.__mn_model_storage_cache_size = mn_model_storage_cache_size + + def get_mn_mn_task_pool_size(self) -> int: + return self.__mn_task_pool_size + + def set_mn_task_pool_size(self, mn_task_pool_size: int) -> None: + self.__mn_task_pool_size = mn_task_pool_size + + def get_mn_tuning_trial_num(self) -> int: + return self.__mn_tuning_trial_num + + def set_mn_tuning_trial_num(self, mn_tuning_trial_num: int) -> None: + self.__mn_tuning_trial_num = mn_tuning_trial_num + + def get_mn_tuning_trial_concurrency(self) -> int: + return self.__mn_tuning_trial_concurrency + + def set_mn_tuning_trial_concurrency(self, mn_tuning_trial_concurrency: int) -> None: + self.__mn_tuning_trial_concurrency = mn_tuning_trial_concurrency + + def get_ain_target_config_node_list(self) -> TEndPoint: + return self.__ain_target_config_node_list + + def set_ain_target_config_node_list(self, ain_target_config_node_list: str) -> None: + self.__ain_target_config_node_list = parse_endpoint_url(ain_target_config_node_list) + + def get_mn_target_data_node(self) -> TEndPoint: + return self.__mn_target_data_node + + def set_mn_target_data_node(self, mn_target_data_node: str) -> None: + self.__mn_target_data_node = parse_endpoint_url(mn_target_data_node) + + +class AINodeDescriptor(object): + _instance = None + _first_init = False + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if not self._first_init: + self.__config = AINodeConfig() + self.load_config_from_file() + self._first_init = True + + def load_properties(self, filepath, sep='=', comment_char='#'): + """ + Read the file passed as parameter as a properties file. 
+ """ + props = {} + with open(filepath, "rt") as f: + for line in f: + l = line.strip() + if l and not l.startswith(comment_char): + key_value = l.split(sep) + key = key_value[0].strip() + value = sep.join(key_value[1:]).strip().strip('"') + props[key] = value + return props + + def load_config_from_file(self) -> None: + system_properties_file = os.path.join(self.__config.get_ain_system_dir(), AINODE_SYSTEM_FILE_NAME) + if os.path.exists(system_properties_file): + system_configs = self.load_properties(system_properties_file) + if 'ainode_id' in system_configs: + self.__config.set_ainode_id(int(system_configs['ainode_id'])) + + git_file = os.path.join(AINODE_ROOT_DIR, AINODE_ROOT_CONF_DIRECTORY_NAME, AINODE_CONF_GIT_FILE_NAME) + if os.path.exists(git_file): + git_configs = self.load_properties(git_file) + if 'git.commit.id.abbrev' in git_configs: + build_info = git_configs['git.commit.id.abbrev'] + if 'git.dirty' in git_configs: + if git_configs['git.dirty'] == "true": + build_info += "-dev" + self.__config.set_build_info(build_info) + + pom_file = os.path.join(AINODE_ROOT_DIR, AINODE_ROOT_CONF_DIRECTORY_NAME, AINODE_CONF_POM_FILE_NAME) + if os.path.exists(pom_file): + pom_configs = self.load_properties(pom_file) + if 'version' in pom_configs: + self.__config.set_version_info(pom_configs['version']) + + conf_file = os.path.join(AINODE_CONF_DIRECTORY_NAME, AINODE_CONF_FILE_NAME) + if not os.path.exists(conf_file): + logger.info("Cannot find AINode config file '{}', use default configuration.".format(conf_file)) + return + + logger.info("Start to read AINode config file '{}'".format(conf_file)) + + # noinspection PyBroadException + try: + file_configs = self.load_properties(conf_file) + + config_keys = file_configs.keys() + + if 'ain_inference_rpc_address' in config_keys: + self.__config.set_ain_inference_rpc_address(file_configs['ain_inference_rpc_address']) + + if 'ain_inference_rpc_port' in config_keys: + self.__config.set_ain_inference_rpc_port(int(file_configs['ain_inference_rpc_port'])) + + if 'ain_logs_dir' in config_keys: + self.__config.set_ain_logs_dir(file_configs['ain_logs_dir']) + + set_logger(self.__config.get_ain_logs_dir()) + + if 'ain_models_dir' in config_keys: + self.__config.set_ain_models_dir(file_configs['ain_models_dir']) + + if 'ain_system_dir' in config_keys: + self.__config.set_ain_system_dir(file_configs['ain_system_dir']) + + if 'ain_seed_config_node' in config_keys: + self.__config.set_ain_target_config_node_list(file_configs['ain_seed_config_node']) + + if 'cluster_name' in config_keys: + self.__config.set_cluster_name(file_configs['cluster_name']) + + # AINODE_THRIFT_COMPRESSION_ENABLED + if 'ain_thrift_compression_enabled' in config_keys: + self.__config.set_ain_thrift_compression_enabled(int(file_configs['ain_thrift_compression_enabled'])) + + except BadNodeUrlError: + logger.warning("Cannot load AINode conf file, use default configuration.") + + except Exception as e: + logger.warning("Cannot load AINode conf file caused by: {}, use default configuration. ".format(e)) + + def get_config(self) -> AINodeConfig: + return self.__config + + +# initialize a singleton +descriptor = AINodeDescriptor() diff --git a/iotdb-core/ainode/iotdb/ainode/constant.py b/iotdb-core/ainode/iotdb/ainode/constant.py new file mode 100644 index 000000000000..5112c98d67d7 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/constant.py @@ -0,0 +1,247 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import inspect +import logging +import os +from enum import Enum +from typing import List + +AINODE_CONF_DIRECTORY_NAME = "conf" +AINODE_ROOT_CONF_DIRECTORY_NAME = "conf" +AINODE_CONF_FILE_NAME = "iotdb-ainode.properties" +AINODE_CONF_GIT_FILE_NAME = "git.properties" +AINODE_CONF_POM_FILE_NAME = "pom.properties" +AINODE_SYSTEM_FILE_NAME = "system.properties" +# inference_rpc_address +AINODE_INFERENCE_RPC_ADDRESS = "127.0.0.1" +AINODE_INFERENCE_RPC_PORT = 10810 +AINODE_MODELS_DIR = "data/ainode/models" +AINODE_SYSTEM_DIR = "data/ainode/system" +AINODE_LOG_DIR = "logs/ainode" +AINODE_THRIFT_COMPRESSION_ENABLED = False +# use for node management +AINODE_CLUSTER_NAME = "defaultCluster" +AINODE_VERSION_INFO = "UNKNOWN" +AINODE_BUILD_INFO = "UNKNOWN" +AINODE_ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))) + +# AINode log +AINODE_LOG_FILE_NAMES = ['log_ainode_all.log', + 'log_ainode_info.log', + 'log_ainode_warning.log', + 'log_ainode_error.log'] +AINODE_LOG_FILE_LEVELS = [ + logging.DEBUG, + logging.INFO, + logging.WARNING, + logging.ERROR] + +TRIAL_ID_PREFIX = "__trial_" +DEFAULT_TRIAL_ID = TRIAL_ID_PREFIX + "0" +DEFAULT_MODEL_FILE_NAME = "model.pt" +DEFAULT_CONFIG_FILE_NAME = "config.yaml" +DEFAULT_CHUNK_SIZE = 8192 + +DEFAULT_RECONNECT_TIMEOUT = 20 +DEFAULT_RECONNECT_TIMES = 3 + +STD_LEVEL = logging.INFO + + +class TSStatusCode(Enum): + SUCCESS_STATUS = 200 + REDIRECTION_RECOMMEND = 400 + AINODE_INTERNAL_ERROR = 1510 + INVALID_URI_ERROR = 1511 + INVALID_INFERENCE_CONFIG = 1512 + INFERENCE_INTERNAL_ERROR = 1520 + + def get_status_code(self) -> int: + return self.value + + +class TaskType(Enum): + FORECAST = "forecast" + + +class OptionsKey(Enum): + # common + TASK_TYPE = "task_type" + MODEL_TYPE = "model_type" + AUTO_TUNING = "auto_tuning" + INPUT_VARS = "input_vars" + + # forecast + INPUT_LENGTH = "input_length" + PREDICT_LENGTH = "predict_length" + PREDICT_INDEX_LIST = "predict_index_list" + INPUT_TYPE_LIST = "input_type_list" + + def name(self) -> str: + return self.value + + +class HyperparameterName(Enum): + # Training hyperparameter + LEARNING_RATE = "learning_rate" + EPOCHS = "epochs" + BATCH_SIZE = "batch_size" + USE_GPU = "use_gpu" + NUM_WORKERS = "num_workers" + + # Structure hyperparameter + KERNEL_SIZE = "kernel_size" + INPUT_VARS = "input_vars" + BLOCK_TYPE = "block_type" + D_MODEL = "d_model" + INNER_LAYERS = "inner_layer" + OUTER_LAYERS = "outer_layer" + + def name(self): + return self.value + + +class ForecastModelType(Enum): + DLINEAR = "dlinear" + DLINEAR_INDIVIDUAL = "dlinear_individual" + NBEATS = "nbeats" + + @classmethod + def values(cls) -> List[str]: + values = [] + for item in list(cls): + values.append(item.value) + return values + + +class ModelInputName(Enum): + DATA_X = "data_x" + TIME_STAMP_X = "time_stamp_x" + TIME_STAMP_Y = 
"time_stamp_y" + DEC_INP = "dec_inp" + + +class BuiltInModelType(Enum): + # forecast models + ARIMA = "_arima" + EXPONENTIAL_SMOOTHING = "_exponentialsmoothing" + NAIVE_FORECASTER = "_naiveforecaster" + STL_FORECASTER = "_stlforecaster" + + # anomaly detection models + GAUSSIAN_HMM = "_gaussianhmm" + GMM_HMM = "_gmmhmm" + STRAY = "_stray" + + @classmethod + def values(cls) -> List[str]: + values = [] + for item in list(cls): + values.append(item.value) + return values + + +class AttributeName(Enum): + # forecast Attribute + PREDICT_LENGTH = "predict_length" + + # NaiveForecaster + STRATEGY = 'strategy' + SP = 'sp' + + # STLForecaster + # SP = 'sp' + SEASONAL = 'seasonal' + SEASONAL_DEG = 'seasonal_deg' + TREND_DEG = 'trend_deg' + LOW_PASS_DEG = 'low_pass_deg' + SEASONAL_JUMP = 'seasonal_jump' + TREND_JUMP = 'trend_jump' + LOSS_PASS_JUMP = 'low_pass_jump' + + # ExponentialSmoothing + DAMPED_TREND = 'damped_trend' + INITIALIZATION_METHOD = 'initialization_method' + OPTIMIZED = 'optimized' + REMOVE_BIAS = 'remove_bias' + USE_BRUTE = 'use_brute' + + # Arima + ORDER = "order" + SEASONAL_ORDER = "seasonal_order" + METHOD = "method" + MAXITER = "maxiter" + SUPPRESS_WARNINGS = "suppress_warnings" + OUT_OF_SAMPLE_SIZE = "out_of_sample_size" + SCORING = "scoring" + WITH_INTERCEPT = "with_intercept" + TIME_VARYING_REGRESSION = "time_varying_regression" + ENFORCE_STATIONARITY = "enforce_stationarity" + ENFORCE_INVERTIBILITY = "enforce_invertibility" + SIMPLE_DIFFERENCING = "simple_differencing" + MEASUREMENT_ERROR = "measurement_error" + MLE_REGRESSION = "mle_regression" + HAMILTON_REPRESENTATION = "hamilton_representation" + CONCENTRATE_SCALE = "concentrate_scale" + + # GAUSSIAN_HMM + N_COMPONENTS = "n_components" + COVARIANCE_TYPE = "covariance_type" + MIN_COVAR = "min_covar" + STARTPROB_PRIOR = "startprob_prior" + TRANSMAT_PRIOR = "transmat_prior" + MEANS_PRIOR = "means_prior" + MEANS_WEIGHT = "means_weight" + COVARS_PRIOR = "covars_prior" + COVARS_WEIGHT = "covars_weight" + ALGORITHM = "algorithm" + N_ITER = "n_iter" + TOL = "tol" + PARAMS = "params" + INIT_PARAMS = "init_params" + IMPLEMENTATION = "implementation" + + # GMMHMM + # N_COMPONENTS = "n_components" + N_MIX = "n_mix" + # MIN_COVAR = "min_covar" + # STARTPROB_PRIOR = "startprob_prior" + # TRANSMAT_PRIOR = "transmat_prior" + WEIGHTS_PRIOR = "weights_prior" + + # MEANS_PRIOR = "means_prior" + # MEANS_WEIGHT = "means_weight" + # ALGORITHM = "algorithm" + # COVARIANCE_TYPE = "covariance_type" + # N_ITER = "n_iter" + # TOL = "tol" + # INIT_PARAMS = "init_params" + # PARAMS = "params" + # IMPLEMENTATION = "implementation" + + # STRAY + ALPHA = "alpha" + K = "k" + KNN_ALGORITHM = "knn_algorithm" + P = "p" + SIZE_THRESHOLD = "size_threshold" + OUTLIER_TAIL = "outlier_tail" + + + def name(self) -> str: + return self.value diff --git a/iotdb-core/ainode/iotdb/ainode/encryption.py b/iotdb-core/ainode/iotdb/ainode/encryption.py new file mode 100644 index 000000000000..2a1e720805f2 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/encryption.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/exception.py b/iotdb-core/ainode/iotdb/ainode/exception.py new file mode 100644 index 000000000000..133c9741a6cb --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/exception.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from iotdb.ainode.constant import DEFAULT_MODEL_FILE_NAME, DEFAULT_CONFIG_FILE_NAME + + +class _BaseError(Exception): + """Base class for exceptions in this module.""" + + def __init__(self): + self.message = None + + def __str__(self) -> str: + return self.message + + +class BadNodeUrlError(_BaseError): + def __init__(self, node_url: str): + self.message = "Bad node url: {}".format(node_url) + + +class ModelNotExistError(_BaseError): + def __init__(self, file_path: str): + self.message = "Model path is not exists: {} ".format(file_path) + + +class BadConfigValueError(_BaseError): + def __init__(self, config_name: str, config_value, hint: str = ''): + self.message = "Bad value [{0}] for config {1}. 
{2}".format(config_value, config_name, hint) + + +class MissingConfigError(_BaseError): + def __init__(self, config_name: str): + self.message = "Missing config: {}".format(config_name) + + +class MissingOptionError(_BaseError): + def __init__(self, config_name: str): + self.message = "Missing task option: {}".format(config_name) + + +class RedundantOptionError(_BaseError): + def __init__(self, option_name: str): + self.message = "Redundant task option: {}".format(option_name) + + +class WrongTypeConfigError(_BaseError): + def __init__(self, config_name: str, expected_type: str): + self.message = "Wrong type for config: {0}, expected: {1}".format(config_name, expected_type) + + +class UnsupportedError(_BaseError): + def __init__(self, msg: str): + self.message = "{0} is not supported in current version".format(msg) + + +class InvaildUriError(_BaseError): + def __init__(self, uri: str): + self.message = "Invalid uri: {}, there are no {} or {} under this uri.".format(uri, DEFAULT_MODEL_FILE_NAME, + DEFAULT_CONFIG_FILE_NAME) + + +class InvalidWindowArgumentError(_BaseError): + def __init__( + self, + window_interval: int, + window_step: int, + dataset_length: int): + self.message = "Invalid inference input: window_interval {0}, window_step {1}, dataset_length {2}".format( + window_interval, window_step, dataset_length) + + +class InferenceModelInternalError(_BaseError): + def __init__(self, msg: str): + self.message = "Inference model internal error: {0}".format(msg) + + +class BuiltInModelNotSupportError(_BaseError): + def __init__(self, msg: str): + self.message = "Built-in model not support: {0}".format(msg) + + +class WrongAttributeTypeError(_BaseError): + def __init__(self, attribute_name: str, expected_type: str): + self.message = "Wrong type for attribute: {0}, expected: {1}".format(attribute_name, expected_type) + + +class NumericalRangeException(_BaseError): + def __init__(self, attribute_name: str, value, min_value, max_value): + self.message = "Attribute {0} expect value between {1} and {2}, got {3} instead." \ + .format(attribute_name, min_value, max_value, value) + + +class StringRangeException(_BaseError): + def __init__(self, attribute_name: str, value: str, expect_value): + self.message = "Attribute {0} expect value in {1}, got {2} instead." \ + .format(attribute_name, expect_value, value) + + +class ListRangeException(_BaseError): + def __init__(self, attribute_name: str, value: list, expected_type: str): + self.message = "Attribute {0} expect value type list[{1}], got {2} instead." \ + .format(attribute_name, expected_type, value) + + +class AttributeNotSupportError(_BaseError): + def __init__(self, model_name: str, attribute_name: str): + self.message = "Attribute {0} is not supported in model {1}".format(attribute_name, model_name) diff --git a/iotdb-core/ainode/iotdb/ainode/factory.py b/iotdb-core/ainode/iotdb/ainode/factory.py new file mode 100644 index 000000000000..158eeb624675 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/factory.py @@ -0,0 +1,272 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from abc import abstractmethod + +import numpy as np +from sklearn.preprocessing import MinMaxScaler +from sktime.annotation.hmm_learn import GaussianHMM, GMMHMM +from sktime.annotation.stray import STRAY +from sktime.forecasting.arima import ARIMA +from sktime.forecasting.exp_smoothing import ExponentialSmoothing +from sktime.forecasting.naive import NaiveForecaster +from sktime.forecasting.trend import STLForecaster + +from iotdb.ainode.attribute import get_model_attributes, parse_attribute +from iotdb.ainode.constant import BuiltInModelType +from iotdb.ainode.exception import AttributeNotSupportError, BuiltInModelNotSupportError, InferenceModelInternalError + + +class BuiltInModel(object): + def __init__(self, attributes): + self._attributes = attributes + self._model = None + + @abstractmethod + def inference(self, data): + raise NotImplementedError + + +class ArimaModel(BuiltInModel): + def __init__(self, attributes): + super(ArimaModel, self).__init__(attributes) + self._model = ARIMA( + order=attributes['order'], + seasonal_order=attributes['seasonal_order'], + method=attributes['method'], + suppress_warnings=attributes['suppress_warnings'], + maxiter=attributes['maxiter'], + out_of_sample_size=attributes['out_of_sample_size'], + scoring=attributes['scoring'], + with_intercept=attributes['with_intercept'], + time_varying_regression=attributes['time_varying_regression'], + enforce_stationarity=attributes['enforce_stationarity'], + enforce_invertibility=attributes['enforce_invertibility'], + simple_differencing=attributes['simple_differencing'], + measurement_error=attributes['measurement_error'], + mle_regression=attributes['mle_regression'], + hamilton_representation=attributes['hamilton_representation'], + concentrate_scale=attributes['concentrate_scale'] + ) + + def inference(self, data): + try: + predict_length = self._attributes['predict_length'] + self._model.fit(data) + output = self._model.predict(fh=range(predict_length)) + output = np.array(output, dtype=np.float64) + return output + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class ExponentialSmoothingModel(BuiltInModel): + def __init__(self, attributes): + super(ExponentialSmoothingModel, self).__init__(attributes) + self._model = ExponentialSmoothing( + damped_trend=attributes['damped_trend'], + initialization_method=attributes['initialization_method'], + optimized=attributes['optimized'], + remove_bias=attributes['remove_bias'], + use_brute=attributes['use_brute'] + ) + + def inference(self, data): + try: + predict_length = self._attributes['predict_length'] + self._model.fit(data) + output = self._model.predict(fh=range(predict_length)) + output = np.array(output, dtype=np.float64) + return output + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class NaiveForecasterModel(BuiltInModel): + def __init__(self, attributes): + super(NaiveForecasterModel, self).__init__(attributes) + self._model = NaiveForecaster( + strategy=attributes['strategy'], + sp=attributes['sp'] + ) + + def inference(self, data): + try: + predict_length = 
self._attributes['predict_length'] + self._model.fit(data) + output = self._model.predict(fh=range(predict_length)) + output = np.array(output, dtype=np.float64) + return output + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class STLForecasterModel(BuiltInModel): + def __init__(self, attributes): + super(STLForecasterModel, self).__init__(attributes) + self._model = STLForecaster( + sp=attributes['sp'], + seasonal=attributes['seasonal'], + seasonal_deg=attributes['seasonal_deg'], + trend_deg=attributes['trend_deg'], + low_pass_deg=attributes['low_pass_deg'], + seasonal_jump=attributes['seasonal_jump'], + trend_jump=attributes['trend_jump'], + low_pass_jump=attributes['low_pass_jump'] + ) + + def inference(self, data): + try: + predict_length = self._attributes['predict_length'] + self._model.fit(data) + output = self._model.predict(fh=range(predict_length)) + output = np.array(output, dtype=np.float64) + return output + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class GMMHMMModel(BuiltInModel): + def __init__(self, attributes): + super(GMMHMMModel, self).__init__(attributes) + self._model = GMMHMM( + n_components=attributes['n_components'], + n_mix=attributes['n_mix'], + min_covar=attributes['min_covar'], + startprob_prior=attributes['startprob_prior'], + transmat_prior=attributes['transmat_prior'], + means_prior=attributes['means_prior'], + means_weight=attributes['means_weight'], + weights_prior=attributes['weights_prior'], + algorithm=attributes['algorithm'], + covariance_type=attributes['covariance_type'], + n_iter=attributes['n_iter'], + tol=attributes['tol'], + params=attributes['params'], + init_params=attributes['init_params'], + implementation=attributes['implementation'] + ) + + def inference(self, data): + try: + self._model.fit(data) + output = self._model.predict(data) + output = np.array(output, dtype=np.int32) + return output + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class GaussianHmmModel(BuiltInModel): + def __init__(self, attributes): + super(GaussianHmmModel, self).__init__(attributes) + self._model = GaussianHMM( + n_components=attributes['n_components'], + covariance_type=attributes['covariance_type'], + min_covar=attributes['min_covar'], + startprob_prior=attributes['startprob_prior'], + transmat_prior=attributes['transmat_prior'], + means_prior=attributes['means_prior'], + means_weight=attributes['means_weight'], + covars_prior=attributes['covars_prior'], + covars_weight=attributes['covars_weight'], + algorithm=attributes['algorithm'], + n_iter=attributes['n_iter'], + tol=attributes['tol'], + params=attributes['params'], + init_params=attributes['init_params'], + implementation=attributes['implementation'] + ) + + def inference(self, data): + try: + self._model.fit(data) + output = self._model.predict(data) + output = np.array(output, dtype=np.int32) + return output + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class STRAYModel(BuiltInModel): + def __init__(self, attributes): + super(STRAYModel, self).__init__(attributes) + self._model = STRAY( + alpha=attributes['alpha'], + k=attributes['k'], + knn_algorithm=attributes['knn_algorithm'], + p=attributes['p'], + size_threshold=attributes['size_threshold'], + outlier_tail=attributes['outlier_tail'] + ) + + def inference(self, data): + try: + data = MinMaxScaler().fit_transform(data) + output = self._model.fit_transform(data) + # change the output to int + output = np.array(output, dtype=np.int32) + return 
output + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +def create_built_in_model(model_id, inference_attributes): + """ + Args: + model_id: the unique id of the model + inference_attributes: a list of attributes to be inferred, in this function, the attributes will include some + parameters of the built-in model. Some parameters are optional, and if the parameters are not + specified, the default value will be used. + Returns: + model: the built-in model from sktime + attributes: a dict of attributes, where the key is the attribute name, the value is the parsed value of the + attribute + Description: + the create_built_in_model function will create the built-in model from sktime, which does not require user + registration. This module will parse the inference attributes and create the built-in model. + """ + attribute_map = get_model_attributes(model_id) + + # validate the inference attributes + for attribute_name in inference_attributes: + if attribute_name not in attribute_map: + raise AttributeNotSupportError(model_id, attribute_name) + + # parse the inference attributes, attributes is a Dict[str, Any] + attributes = parse_attribute(inference_attributes, attribute_map) + + # build the built-in model + model = None + if model_id == BuiltInModelType.ARIMA.value: + model = ArimaModel(attributes) + elif model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value: + model = ExponentialSmoothingModel(attributes) + elif model_id == BuiltInModelType.NAIVE_FORECASTER.value: + model = NaiveForecasterModel(attributes) + elif model_id == BuiltInModelType.STL_FORECASTER.value: + model = STLForecasterModel(attributes) + elif model_id == BuiltInModelType.GMM_HMM.value: + model = GMMHMMModel(attributes) + elif model_id == BuiltInModelType.GAUSSIAN_HMM.value: + model = GaussianHmmModel(attributes) + elif model_id == BuiltInModelType.STRAY.value: + model = STRAYModel(attributes) + else: + raise BuiltInModelNotSupportError(model_id) + + return model, attributes diff --git a/iotdb-core/ainode/iotdb/ainode/handler.py b/iotdb-core/ainode/iotdb/ainode/handler.py new file mode 100644 index 000000000000..ea5bb8eb65b9 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/handler.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
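# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of this patch. The sktime-backed
# wrappers in factory.py above all follow the same pattern: fit the estimator on
# the incoming series, then predict `predict_length` steps ahead and return a
# float64 numpy array. The helper below reproduces that flow directly with
# sktime's NaiveForecaster; the function name and the 1..h horizon are the
# editor's choice, and the real attribute names/defaults live in attribute.py.
# ---------------------------------------------------------------------------
import numpy as np
import pandas as pd
from sktime.forecasting.naive import NaiveForecaster


def _naive_forecast_sketch(y: pd.Series, predict_length: int) -> np.ndarray:
    forecaster = NaiveForecaster(strategy="last", sp=1)
    forecaster.fit(y)
    prediction = forecaster.predict(fh=np.arange(1, predict_length + 1))
    return np.array(prediction, dtype=np.float64)

# e.g. _naive_forecast_sketch(pd.Series(np.sin(np.linspace(0, 10, 200))), 24)
# returns a (24,) array that repeats the last observed value ("last" strategy).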
+# + +import psutil +from yaml import YAMLError + +from iotdb.ainode.constant import TSStatusCode +from iotdb.ainode.exception import InvaildUriError, BadConfigValueError +from iotdb.ainode.inference import inference_with_registered_model, inference_with_built_in_model +from iotdb.ainode.log import logger +from iotdb.ainode.parser import (parse_inference_request) +from iotdb.ainode.serde import convert_to_binary +from iotdb.ainode.storage import model_storage +from iotdb.ainode.util import get_status +from iotdb.thrift.ainode import IAINodeRPCService +from iotdb.thrift.ainode.ttypes import (TDeleteModelReq, TRegisterModelReq, + TRegisterModelResp, TAIHeartbeatReq, TAIHeartbeatResp, + TInferenceReq, TInferenceResp) +from iotdb.thrift.common.ttypes import TLoadSample + + +class AINodeRPCServiceHandler(IAINodeRPCService.Iface): + def __init__(self): + # for training, it's not open now. + self.__task_manager = None + + def registerModel(self, req: TRegisterModelReq): + logger.debug(f"register model {req.modelId} from {req.uri}") + try: + configs, attributes = model_storage.register_model(req.modelId, req.uri) + return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS), configs, attributes) + except InvaildUriError as e: + logger.warning(e) + model_storage.delete_model(req.modelId) + return TRegisterModelResp(get_status(TSStatusCode.INVALID_URI_ERROR, e.message)) + except BadConfigValueError as e: + logger.warning(e) + model_storage.delete_model(req.modelId) + return TRegisterModelResp(get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message)) + except YAMLError as e: + logger.warning(e) + model_storage.delete_model(req.modelId) + if hasattr(e, 'problem_mark'): + mark = e.problem_mark + return TRegisterModelResp(get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, + f"An error occurred while parsing the yaml file, " + f"at line {mark.line + 1} column {mark.column + 1}.")) + return TRegisterModelResp( + get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, f"An error occurred while parsing the yaml file")) + except Exception as e: + logger.warning(e) + model_storage.delete_model(req.modelId) + return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) + + def deleteModel(self, req: TDeleteModelReq): + logger.debug(f"delete model {req.modelId}") + try: + model_storage.delete_model(req.modelId) + return get_status(TSStatusCode.SUCCESS_STATUS) + except Exception as e: + logger.warning(e) + return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) + + def inference(self, req: TInferenceReq): + logger.info(f"infer {req.modelId}") + model_id, full_data, window_interval, window_step, inference_attributes = parse_inference_request( + req) + try: + if model_id.startswith('_'): + # built-in models + inference_results = inference_with_built_in_model( + model_id, full_data, inference_attributes) + else: + # user-registered models + inference_results = inference_with_registered_model( + model_id, full_data, window_interval, window_step, inference_attributes) + for i in range(len(inference_results)): + inference_results[i] = convert_to_binary(inference_results[i]) + return TInferenceResp( + get_status( + TSStatusCode.SUCCESS_STATUS), + inference_results) + except Exception as e: + logger.warning(e) + inference_results = [] + return TInferenceResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)), inference_results) + + def getAIHeartbeat(self, req: TAIHeartbeatReq): + if req.needSamplingLoad: + cpu_percent = psutil.cpu_percent(interval=1) + memory_percent = 
psutil.virtual_memory().percent + disk_usage = psutil.disk_usage('/') + disk_free = disk_usage.free + load_sample = TLoadSample(cpuUsageRate=cpu_percent, + memoryUsageRate=memory_percent, + diskUsageRate=disk_usage.percent, + freeDiskSpace=disk_free / 1024 / 1024 / 1024) + return TAIHeartbeatResp(heartbeatTimestamp=req.heartbeatTimestamp, + status="Running", + loadSample=load_sample) + else: + return TAIHeartbeatResp(heartbeatTimestamp=req.heartbeatTimestamp, + status="Running") diff --git a/iotdb-core/ainode/iotdb/ainode/inference.py b/iotdb-core/ainode/iotdb/ainode/inference.py new file mode 100644 index 000000000000..284b4c783a0f --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/inference.py @@ -0,0 +1,165 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import pandas as pd +from torch import tensor + +from iotdb.ainode.constant import BuiltInModelType +from iotdb.ainode.exception import InvalidWindowArgumentError, InferenceModelInternalError, \ + BuiltInModelNotSupportError +from iotdb.ainode.factory import create_built_in_model +from iotdb.ainode.log import logger +from iotdb.ainode.parser import runtime_error_extractor +from iotdb.ainode.storage import model_storage + + +def inference_with_registered_model(model_id, full_data, window_interval, window_step, inference_attributes): + """ + Args: + model_id: the unique id of the model + full_data: a tuple of (data, time_stamp, type_list, column_name_list), where the data is a DataFrame with shape + (L, C), time_stamp is a DataFrame with shape(L, 1), type_list is a list of data types with length C, + column_name_list is a list of column names with length C, where L is the number of data points, C is the + number of variables, the data and time_stamp are aligned by index + window_interval: the length of each sliding window + window_step: the step between two adjacent sliding windows + inference_attributes: a list of attributes to be inferred. In this function, the attributes will include the + acceleration, which indicates whether the model is accelerated by the torch. Compile + Returns: + outputs: a list of output DataFrames, where each DataFrame has shape (H', C'), where H' is the output window + interval, C' is the number of variables in the output DataFrame + Description: + the inference_with_registered_model function will inference with deep learning model, which is registered in + user register process. This module will split the input data into several sliding windows which has the same + shape (1, H, C), where H is the window interval, and then feed each sliding window into the model to get the + output, the output is a DataFrame with shape (H', C'), where H' is the output window interval, C' is the number + of variables in the output DataFrame. 
Then the inference module will concatenate all the output DataFrames into + a list. + """ + logger.info(f"start inference registered model {model_id}") + + # parse the inference attributes + acceleration = False + if inference_attributes is None or 'acceleration' not in inference_attributes: + # if the acceleration is not specified, then the acceleration will be set to default value False + acceleration = False + else: + # if the acceleration is specified, then the acceleration will be set to the specified value + acceleration = (inference_attributes['acceleration'].lower() == 'true') + + model = model_storage.load_model_from_id(model_id, acceleration) + dataset, dataset_length = process_data(full_data) + + # check the validity of window_interval and window_step, the two arguments must be positive integers, and the + # window_interval should not be larger than the dataset length + if window_interval is None or window_step is None \ + or window_interval > dataset_length \ + or window_interval <= 0 or \ + window_step <= 0: + raise InvalidWindowArgumentError(window_interval, window_step, dataset_length) + + sliding_times = int((dataset_length - window_interval) // window_step + 1) + outputs = [] + try: + # split the input data into several sliding windows + for sliding_time in range(sliding_times): + if window_step == float('inf'): + start_index = 0 + else: + start_index = sliding_time * window_step + end_index = start_index + window_interval + # input_data: tensor, shape: (1, H, C), where H is input window interval + input_data = dataset[:, start_index:end_index, :] + # output: tensor, shape: (1, H', C'), where H' is the output window interval + output = model(input_data) + # output: DataFrame, shape: (H', C') + output = pd.DataFrame(output.squeeze(0).detach().numpy()) + outputs.append(output) + except Exception as e: + error_msg = runtime_error_extractor(str(e)) + if error_msg != "": + raise InferenceModelInternalError(error_msg) + raise InferenceModelInternalError(str(e)) + + return outputs + + +def process_data(full_data): + """ + Args: + full_data: a tuple of (data, time_stamp, type_list, column_name_list), where the data is a DataFrame with shape + (L, C), time_stamp is a DataFrame with shape(L, 1), type_list is a list of data types with length C, + column_name_list is a list of column names with length C, where L is the number of data points, C is the + number of variables, the data and time_stamp are aligned by index + Returns: + data: a tensor with shape (1, L, C) + data_length: the number of data points + Description: + the process_data module will convert the input data into a tensor with shape (1, L, C), where L is the number of + data points, C is the number of variables, the data and time_stamp are aligned by index. The module will also + convert the data type of each column to the corresponding type. 
+ """ + data, time_stamp, type_list, _ = full_data + data_length = time_stamp.shape[0] + data = data.fillna(0) + for i in range(len(type_list)): + if type_list[i] == "TEXT": + data[data.columns[i]] = 0 + elif type_list[i] == "BOOLEAN": + data[data.columns[i]] = data[data.columns[i]].astype("int") + data = tensor(data.values).unsqueeze(0) + return data, data_length + + +def inference_with_built_in_model(model_id, full_data, inference_attributes): + """ + Args: + model_id: the unique id of the model + full_data: a tuple of (data, time_stamp, type_list, column_name_list), where the data is a DataFrame with shape + (L, C), time_stamp is a DataFrame with shape(L, 1), type_list is a list of data types with length C, + column_name_list is a list of column names with length C, where L is the number of data points, C is the + number of variables, the data and time_stamp are aligned by index + inference_attributes: a list of attributes to be inferred, in this function, the attributes will include some + parameters of the built-in model. Some parameters are optional, and if the parameters are not + specified, the default value will be used. + Returns: + outputs: a list of output DataFrames, where each DataFrame has shape (H', C'), where H' is the output window + interval, C' is the number of variables in the output DataFrame + Description: + the inference_with_built_in_model function will inference with built-in model from sktime, which does not + require user registration. This module will parse the inference attributes and create the built-in model, then + feed the input data into the model to get the output, the output is a DataFrame with shape (H', C'), where H' + is the output window interval, C' is the number of variables in the output DataFrame. Then the inference module + will concatenate all the output DataFrames into a list. + """ + model_id = model_id.lower() + if model_id not in BuiltInModelType.values(): + raise BuiltInModelNotSupportError(model_id) + + logger.info(f"start inference built-in model {model_id}") + + # parse the inference attributes and create the built-in model + model, attributes = create_built_in_model(model_id, inference_attributes) + + data, _, _, _ = full_data + + output = model.inference(data) + + # output: DataFrame, shape: (H', C') + output = pd.DataFrame(output) + outputs = [output] + return outputs diff --git a/iotdb-core/ainode/iotdb/ainode/log.py b/iotdb-core/ainode/iotdb/ainode/log.py new file mode 100644 index 000000000000..5745b6fd2abb --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/log.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +import inspect +import logging +import multiprocessing +import os +import random +import sys + +from iotdb.ainode.constant import STD_LEVEL, AINODE_LOG_FILE_NAMES, AINODE_LOG_FILE_LEVELS + + +class LoggerFilter(logging.Filter): + def filter(self, record): + record.msg = f"{self.custom_log_info()}: {record.msg}" + return True + + @staticmethod + def custom_log_info(): + frame = inspect.currentframe() + stack_trace = inspect.getouterframes(frame) + + pid = os.getpid() + process_name = multiprocessing.current_process().name + + stack_info = "" + frame_info = stack_trace[7] + file_name = frame_info.filename + # if file_name is not in current working directory, find the first "iotdb" in the path + for l in range(len(file_name)): + i = len(file_name) - l - 1 + if file_name[i:].startswith("iotdb/") or file_name[i:].startswith("iotdb\\"): + file_name = file_name[i:] + break + + stack_info += f"{file_name}:{frame_info.lineno}-{frame_info.function}" + + return f"[{pid}:{process_name}] {stack_info}" + + +class Logger: + """ + Args: + log_dir: log directory + + logger_format: log format of global logger + logger: global logger with custom format and level + file_handlers: file handlers for different levels + console_handler: console handler for stdout + __lock: lock for logger + """ + + def __init__(self, log_dir=None): + + self.logger_format = logging.Formatter(fmt='%(asctime)s %(levelname)s %(' + 'message)s', + datefmt='%Y-%m-%d %H:%M:%S') + + self.logger = logging.getLogger(str(random.random())) + self.logger.handlers.clear() + self.logger.setLevel(logging.DEBUG) + self.console_handler = logging.StreamHandler(sys.stdout) + self.console_handler.setLevel(STD_LEVEL) + self.console_handler.setFormatter(self.logger_format) + + self.logger.addHandler(self.console_handler) + + if log_dir is not None: + file_names = AINODE_LOG_FILE_NAMES + file_levels = AINODE_LOG_FILE_LEVELS + if not os.path.exists(log_dir): + os.makedirs(log_dir) + os.chmod(log_dir, 0o777) + for file_name in file_names: + log_path = log_dir + "/" + file_name + if not os.path.exists(log_path): + f = open(log_path, mode='w', encoding='utf-8') + f.close() + os.chmod(log_path, 0o777) + self.file_handlers = [] + for l in range(len(file_names)): + self.file_handlers.append(logging.FileHandler(log_dir + "/" + file_names[l], mode='a')) + self.file_handlers[l].setLevel(file_levels[l]) + self.file_handlers[l].setFormatter(self.logger_format) + + for filehandler in self.file_handlers: + self.logger.addHandler(filehandler) + + + self.logger.addFilter(LoggerFilter()) + self.__lock = multiprocessing.Lock() + + def debug(self, *args) -> None: + self.__lock.acquire() + self.logger.debug(' '.join(map(str, args))) + self.__lock.release() + + def info(self, *args) -> None: + self.__lock.acquire() + self.logger.info(' '.join(map(str, args))) + self.__lock.release() + + def warning(self, *args) -> None: + self.__lock.acquire() + self.logger.warning(' '.join(map(str, args))) + self.__lock.release() + + def error(self, *args) -> None: + self.__lock.acquire() + self.logger.error(' '.join(map(str, args))) + self.__lock.release() + + +logger = Logger() + + +def set_logger(ain_logs_dir): + global logger + logger = Logger(ain_logs_dir) diff --git a/iotdb-core/ainode/iotdb/ainode/parser.py b/iotdb-core/ainode/iotdb/ainode/parser.py new file mode 100644 index 000000000000..f3f8002333ef --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/parser.py @@ -0,0 +1,191 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license 
agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import argparse +import re +from abc import abstractmethod +from typing import Dict + +from iotdb.ainode.constant import OptionsKey, TaskType, ForecastModelType +from iotdb.ainode.exception import (MissingOptionError, RedundantOptionError, + UnsupportedError, BadConfigValueError) +from iotdb.ainode.serde import convert_to_df, get_data_type_byte_from_str +from iotdb.thrift.ainode.ttypes import TInferenceReq, TConfigs + + +class TaskOptions(object): + def __init__(self, options: Dict): + self._raw_options = options + + if OptionsKey.MODEL_TYPE.name() not in self._raw_options: + raise MissingOptionError(OptionsKey.MODEL_TYPE.name()) + model_name = self._raw_options.pop(OptionsKey.MODEL_TYPE.name()) + self.model_type = getattr(ForecastModelType, model_name.upper(), None) + if not self.model_type: + raise UnsupportedError(f"model_type {model_name}") + + # training with auto-tuning as default + self.auto_tuning = str2bool(self._raw_options.pop(OptionsKey.AUTO_TUNING.name(), "false")) + + @abstractmethod + def get_task_type(self) -> TaskType: + raise NotImplementedError("Subclasses must implement the validate() method.") + + def _check_redundant_options(self) -> None: + if len(self._raw_options): + raise RedundantOptionError(str(self._raw_options)) + + +class ForecastTaskOptions(TaskOptions): + def __init__(self, options: Dict): + super().__init__(options) + self.input_length = self._raw_options.pop(OptionsKey.INPUT_LENGTH.name(), 96) + self.predict_length = self._raw_options.pop(OptionsKey.PREDICT_LENGTH.name(), 96) + self.predict_index_list = self._raw_options.pop(OptionsKey.PREDICT_INDEX_LIST.name(), None) + self.input_type_list = self._raw_options.pop(OptionsKey.INPUT_TYPE_LIST.name(), None) + super()._check_redundant_options() + + def get_task_type(self) -> TaskType: + return TaskType.FORECAST + + +def parse_task_type(options: Dict) -> TaskType: + if OptionsKey.TASK_TYPE.name() not in options: + raise MissingOptionError(OptionsKey.TASK_TYPE.name()) + task_name = options.pop(OptionsKey.TASK_TYPE.name()) + task_type = getattr(TaskType, task_name.upper(), None) + if not task_type: + raise UnsupportedError(f"task_type {task_name}") + return task_type + + +def parse_task_options(options) -> TaskOptions: + task_type = parse_task_type(options) + if task_type == TaskType.FORECAST: + return ForecastTaskOptions(options) + else: + raise UnsupportedError(f"task type {task_type}") + + +def parse_inference_config(config_dict): + """ + Args: + config_dict: dict + - configs: dict + - input_shape (list): input shape of the model and needs to be two-dimensional array like [96, 2] + - output_shape (list): output shape of the model and needs to be two-dimensional array like [96, 2] + - input_type (list): input type of the model and each element needs to be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], 
default float64 + - output_type (list): output type of the model and each element needs to be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float64 + - attributes: dict + Returns: + configs: TConfigs + attributes: str + """ + configs = config_dict['configs'] + + # check if input_shape and output_shape are two-dimensional array + if not (isinstance(configs['input_shape'], list) and len(configs['input_shape']) == 2): + raise BadConfigValueError('input_shape', configs['input_shape'], + 'input_shape should be a two-dimensional array.') + if not (isinstance(configs['output_shape'], list) and len(configs['output_shape']) == 2): + raise BadConfigValueError('output_shape', configs['output_shape'], + 'output_shape should be a two-dimensional array.') + + # check if input_shape and output_shape are positive integer + input_shape_is_positive_number = isinstance(configs['input_shape'][0], int) and isinstance( + configs['input_shape'][1], int) and configs['input_shape'][0] > 0 and configs['input_shape'][1] > 0 + if not input_shape_is_positive_number: + raise BadConfigValueError('input_shape', configs['input_shape'], + 'element in input_shape should be positive integer.') + + output_shape_is_positive_number = isinstance(configs['output_shape'][0], int) and isinstance( + configs['output_shape'][1], int) and configs['output_shape'][0] > 0 and configs['output_shape'][1] > 0 + if not output_shape_is_positive_number: + raise BadConfigValueError('output_shape', configs['output_shape'], + 'element in output_shape should be positive integer.') + + # check if input_type and output_type are one-dimensional array with right length + if 'input_type' in configs and not ( + isinstance(configs['input_type'], list) and len(configs['input_type']) == configs['input_shape'][1]): + raise BadConfigValueError('input_type', configs['input_type'], + 'input_type should be a one-dimensional array and length of it should be equal to input_shape[1].') + + if 'output_type' in configs and not ( + isinstance(configs['output_type'], list) and len(configs['output_type']) == configs['output_shape'][1]): + raise BadConfigValueError('output_type', configs['output_type'], + 'output_type should be a one-dimensional array and length of it should be equal to output_shape[1].') + + # parse input_type and output_type to byte + if 'input_type' in configs: + input_type = [get_data_type_byte_from_str(x) for x in configs['input_type']] + else: + input_type = [get_data_type_byte_from_str('float32')] * configs['input_shape'][1] + + if 'output_type' in configs: + output_type = [get_data_type_byte_from_str(x) for x in configs['output_type']] + else: + output_type = [get_data_type_byte_from_str('float32')] * configs['output_shape'][1] + + # parse attributes + attributes = "" + if 'attributes' in config_dict: + attributes = str(config_dict['attributes']) + + return TConfigs(configs['input_shape'], configs['output_shape'], input_type, output_type), attributes + + +def parse_inference_request(req: TInferenceReq): + binary_dataset = req.dataset + type_list = req.typeList + column_name_list = req.columnNameList + column_name_index = req.columnNameIndexMap + data = convert_to_df(column_name_list, type_list, column_name_index, [binary_dataset]) + time_stamp, data = data[data.columns[0:1]], data[data.columns[1:]] + full_data = (data, time_stamp, type_list, column_name_list) + inference_attributes = req.inferenceAttributes + if inference_attributes is None: + inference_attributes = {} + + window_params = req.windowParams + if window_params 
is None: + # set default window_step to infinity and window_interval to dataset length + window_step = float('inf') + window_interval = data.shape[0] + else: + window_step = window_params.windowStep + window_interval = window_params.windowInterval + return req.modelId, full_data, window_interval, window_step, inference_attributes + + +def str2bool(value): + if value.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif value.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +# This is used to extract the key message in RuntimeError instead of the traceback message +def runtime_error_extractor(error_message): + pattern = re.compile(r"RuntimeError: (.+)") + match = pattern.search(error_message) + + if match: + return match.group(1) + else: + return "" diff --git a/iotdb-core/ainode/iotdb/ainode/script.py b/iotdb-core/ainode/iotdb/ainode/script.py new file mode 100644 index 000000000000..2db29bb62bc4 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/script.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
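# ---------------------------------------------------------------------------
# Editor's note: illustrative example, not part of this patch. This is the kind
# of config dict that parse_inference_config() in parser.py above accepts (the
# parsed form of the YAML shipped with a registered model): both shapes are
# two-element lists of positive ints [window_length, column_count], the optional
# type lists must match the second shape dimension, and `attributes` is
# free-form and simply stringified. All concrete values below are made up.
# ---------------------------------------------------------------------------
EXAMPLE_MODEL_CONFIG = {
    "configs": {
        "input_shape": [96, 2],             # 96-point window, 2 input columns
        "output_shape": [48, 2],            # 48-point forecast, 2 output columns
        "input_type": ["float32", "float32"],
        "output_type": ["float32", "float32"],
    },
    "attributes": {"acceleration": "true"},
}
# parse_inference_config(EXAMPLE_MODEL_CONFIG) would return a TConfigs carrying
# the two shapes plus one type byte per column, and "{'acceleration': 'true'}"
# as the attribute string.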
+# +import os +import shutil +import sys + +from iotdb.ainode.client import client_manager +from iotdb.ainode.config import descriptor +from iotdb.ainode.constant import TSStatusCode +from iotdb.ainode.exception import MissingConfigError +from iotdb.ainode.log import logger +from iotdb.ainode.service import AINode +from iotdb.thrift.common.ttypes import TAINodeLocation, TEndPoint + +server: AINode = None +POINT_COLON = ":" + +def main(): + global server + arguments = sys.argv + if len(arguments) == 1: + logger.info("Command line argument must be specified.") + return + command = arguments[1] + if command == 'start': + try: + server = AINode() + server.start() + except Exception as e: + logger.error("Start AINode failed, because of: {}".format(e)) + sys.exit(1) + elif command == 'remove': + try: + logger.info("Removing AINode...") + if len(arguments) >= 3: + target_ainode = arguments[2] + # parameter pattern: or : + ainode_info = target_ainode.split(POINT_COLON) + target_ainode_id = -1 + + # ainode id + if len(ainode_info) == 1: + target_ainode_id = int(ainode_info[0]) + + ainode_configuration_map = client_manager.borrow_config_node_client().get_ainode_configuration( + target_ainode_id) + + end_point = ainode_configuration_map[target_ainode_id].location.internalEndPoint + target_rpc_address = end_point.ip + target_rpc_port = end_point.port + elif len(ainode_info) == 2: + target_rpc_address = ainode_info[0] + target_rpc_port = int(ainode_info[1]) + + ainode_configuration_map = client_manager.borrow_config_node_client().get_ainode_configuration(-1) + + for cur_ainode_id, cur_configuration in ainode_configuration_map.items(): + cur_end_point = cur_configuration.location.internalEndPoint + if cur_end_point.ip == target_rpc_address and cur_end_point.port == target_rpc_port: + target_ainode_id = cur_ainode_id + break + if target_ainode_id == -1: + raise MissingConfigError( + "Can't find ainode through {}:{}".format(target_rpc_port, target_rpc_address)) + else: + raise MissingConfigError("NodeId or IP:Port should be provided to remove AINode") + + logger.info('Got target AINode id: {}, address: {}, port: {}' + .format(target_ainode_id, target_rpc_address, target_rpc_port)) + else: + target_ainode_id = descriptor.get_config().get_ainode_id() + target_rpc_address = descriptor.get_config().get_ain_inference_rpc_address() + target_rpc_port = descriptor.get_config().get_ain_inference_rpc_port() + + location = TAINodeLocation(target_ainode_id, TEndPoint(target_rpc_address, target_rpc_port)) + status = client_manager.borrow_config_node_client().node_remove(location) + + if status.code == TSStatusCode.SUCCESS_STATUS.get_status_code(): + logger.info('IoTDB-AINode has successfully removed.') + if os.path.exists(descriptor.get_config().get_ain_models_dir()): + shutil.rmtree(descriptor.get_config().get_ain_models_dir()) + + except Exception as e: + logger.error("Remove AINode failed, because of: {}".format(e)) + sys.exit(1) + else: + logger.warning("Unknown argument: {}.".format(command)) diff --git a/iotdb-core/ainode/iotdb/ainode/serde.py b/iotdb-core/ainode/iotdb/ainode/serde.py new file mode 100644 index 000000000000..4338dcdfefc8 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/serde.py @@ -0,0 +1,564 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import struct +from enum import Enum + +import numpy as np +import pandas as pd + +from iotdb.ainode.exception import BadConfigValueError + + +class TSDataType(Enum): + BOOLEAN = 0 + INT32 = 1 + INT64 = 2 + FLOAT = 3 + DOUBLE = 4 + TEXT = 5 + + # this method is implemented to avoid the issue reported by: + # https://bugs.python.org/issue30545 + def __eq__(self, other) -> bool: + return self.value == other.value + + def __hash__(self): + return self.value + + def np_dtype(self): + return { + TSDataType.BOOLEAN: np.dtype(">?"), + TSDataType.FLOAT: np.dtype(">f4"), + TSDataType.DOUBLE: np.dtype(">f8"), + TSDataType.INT32: np.dtype(">i4"), + TSDataType.INT64: np.dtype(">i8"), + TSDataType.TEXT: np.dtype("str"), + }[self] + + +TIMESTAMP_STR = "Time" +START_INDEX = 2 + + +# convert dataFrame to tsBlock in binary +# input shouldn't contain time column +def convert_to_binary(data_frame: pd.DataFrame): + data_shape = data_frame.shape + value_column_size = data_shape[1] + position_count = data_shape[0] + keys = data_frame.keys() + + binary = value_column_size.to_bytes(4, byteorder="big") + + for data_type in data_frame.dtypes: + binary += _get_type_in_byte(data_type) + + # position count + binary += position_count.to_bytes(4, byteorder="big") + + # column encoding + binary += b'\x02' + for data_type in data_frame.dtypes: + binary += _get_encoder(data_type) + + # write columns, the column in index 0 must be timeColumn + binary += bool.to_bytes(False, 1, byteorder="big") + for i in range(position_count): + value = 0 + v = struct.pack(">i", value) + binary += v + binary += v + + for i in range(value_column_size): + # the value can't be null + binary += bool.to_bytes(False, 1, byteorder="big") + col = data_frame[keys[i]] + for j in range(position_count): + value = col[j] + if value.dtype.byteorder != '>': + value = value.byteswap() + binary += value.tobytes() + + return binary + + +# convert tsBlock in binary to dataFrame +def convert_to_df(name_list, type_list, name_index, binary_list): + column_name_list = [TIMESTAMP_STR] + column_type_list = [TSDataType.INT64] + column_ordinal_dict = {TIMESTAMP_STR: 1} + + if name_index is not None: + column_type_deduplicated_list = [ + None for _ in range(len(name_index)) + ] + for i in range(len(name_list)): + name = name_list[i] + column_name_list.append(name) + column_type_list.append(TSDataType[type_list[i]]) + if name not in column_ordinal_dict: + index = name_index[name] + column_ordinal_dict[name] = index + START_INDEX + column_type_deduplicated_list[index] = TSDataType[type_list[i]] + else: + index = START_INDEX + column_type_deduplicated_list = [] + for i in range(len(name_list)): + name = name_list[i] + column_name_list.append(name) + column_type_list.append(TSDataType[type_list[i]]) + if name not in column_ordinal_dict: + column_ordinal_dict[name] = index + index += 1 + column_type_deduplicated_list.append( + TSDataType[type_list[i]] + ) + + binary_size = len(binary_list) + 
binary_index = 0 + result = {} + for column_name in column_name_list: + result[column_name] = None + + while binary_index < binary_size: + buffer = binary_list[binary_index] + binary_index += 1 + time_column_values, column_values, null_indicators, _ = deserialize(buffer) + time_array = np.frombuffer( + time_column_values, np.dtype(np.longlong).newbyteorder(">") + ) + if time_array.dtype.byteorder == ">": + time_array = time_array.byteswap().newbyteorder("<") + + if result[TIMESTAMP_STR] is None: + result[TIMESTAMP_STR] = time_array + else: + result[TIMESTAMP_STR] = np.concatenate( + (result[TIMESTAMP_STR], time_array), axis=0 + ) + total_length = len(time_array) + + for i in range(len(column_values)): + column_name = column_name_list[i + 1] + + location = column_ordinal_dict[column_name] - START_INDEX + if location < 0: + continue + + data_type = column_type_deduplicated_list[location] + value_buffer = column_values[location] + value_buffer_len = len(value_buffer) + + if data_type == TSDataType.DOUBLE: + data_array = np.frombuffer( + value_buffer, np.dtype(np.double).newbyteorder(">") + ) + elif data_type == TSDataType.FLOAT: + data_array = np.frombuffer( + value_buffer, np.dtype(np.float32).newbyteorder(">") + ) + elif data_type == TSDataType.BOOLEAN: + data_array = [] + for index in range(len(value_buffer)): + data_array.append(value_buffer[index]) + data_array = np.array(data_array).astype("bool") + elif data_type == TSDataType.INT32: + data_array = np.frombuffer( + value_buffer, np.dtype(np.int32).newbyteorder(">") + ) + elif data_type == TSDataType.INT64: + data_array = np.frombuffer( + value_buffer, np.dtype(np.int64).newbyteorder(">") + ) + elif data_type == TSDataType.TEXT: + index = 0 + data_array = [] + while index < value_buffer_len: + value_bytes = value_buffer[index] + value = value_bytes.decode("utf-8") + data_array.append(value) + index += 1 + data_array = np.array(data_array, dtype=object) + else: + raise RuntimeError("unsupported data type {}.".format(data_type)) + + if data_array.dtype.byteorder == ">": + data_array = data_array.byteswap().newbyteorder("<") + + null_indicator = null_indicators[location] + if len(data_array) < total_length or (data_type == TSDataType.BOOLEAN and null_indicator is not None): + if data_type == TSDataType.INT32 or data_type == TSDataType.INT64: + tmp_array = np.full(total_length, np.nan, np.float32) + elif data_type == TSDataType.FLOAT or data_type == TSDataType.DOUBLE: + tmp_array = np.full(total_length, np.nan, data_array.dtype) + elif data_type == TSDataType.BOOLEAN: + tmp_array = np.full(total_length, np.nan, np.float32) + elif data_type == TSDataType.TEXT: + tmp_array = np.full(total_length, np.nan, dtype=data_array.dtype) + else: + raise Exception("Unsupported dataType in deserialization") + + if null_indicator is not None: + indexes = [not v for v in null_indicator] + if data_type == TSDataType.BOOLEAN: + tmp_array[indexes] = data_array[indexes] + else: + tmp_array[indexes] = data_array + + if data_type == TSDataType.INT32: + tmp_array = pd.Series(tmp_array).astype("Int32") + elif data_type == TSDataType.INT64: + tmp_array = pd.Series(tmp_array).astype("Int64") + elif data_type == TSDataType.BOOLEAN: + tmp_array = pd.Series(tmp_array).astype("boolean") + + data_array = tmp_array + + if result[column_name] is None: + result[column_name] = data_array + else: + if isinstance(result[column_name], pd.Series): + if not isinstance(data_array, pd.Series): + if data_type == TSDataType.INT32: + data_array = pd.Series(data_array).astype("Int32") 
+ elif data_type == TSDataType.INT64: + data_array = pd.Series(data_array).astype("Int64") + elif data_type == TSDataType.BOOLEAN: + data_array = pd.Series(data_array).astype("boolean") + else: + raise RuntimeError("Series Error") + result[column_name] = result[column_name].append(data_array) + else: + result[column_name] = np.concatenate( + (result[column_name], data_array), axis=0 + ) + for k, v in result.items(): + if v is None: + result[k] = [] + df = pd.DataFrame(result) + df = df.reset_index(drop=True) + return df + + +def _get_encoder(data_type: pd.Series): + if data_type == "bool": + return b'\x00' + elif data_type == "int32" or data_type == "float32": + return b'\x01' + elif data_type == "int64" or data_type == "float64": + return b'\x02' + elif data_type == "texr": + return b'\x03' + + +def _get_type_in_byte(data_type: pd.Series): + if data_type == 'bool': + return b'\x00' + elif data_type == 'int32': + return b'\x01' + elif data_type == 'int64': + return b'\x02' + elif data_type == 'float32': + return b'\x03' + elif data_type == 'float64': + return b'\x04' + elif data_type == 'text': + return b'\x05' + else: + raise BadConfigValueError('data_type', data_type, + "data_type should be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text']") + + +# Serialized tsBlock: +# +-------------+---------------+---------+------------+-----------+----------+ +# | val col cnt | val col types | pos cnt | encodings | time col | val col | +# +-------------+---------------+---------+------------+-----------+----------+ +# | int32 | list[byte] | int32 | list[byte] | bytes | byte | +# +-------------+---------------+---------+------------+-----------+----------+ + +def deserialize(buffer): + value_column_count, buffer = read_int_from_buffer(buffer) + data_types, buffer = read_column_types(buffer, value_column_count) + + position_count, buffer = read_int_from_buffer(buffer) + column_encodings, buffer = read_column_encoding(buffer, value_column_count + 1) + + time_column_values, buffer = read_time_column(buffer, position_count) + column_values = [None] * value_column_count + null_indicators = [None] * value_column_count + for i in range(value_column_count): + column_value, null_indicator, buffer = read_column(column_encodings[i + 1], buffer, data_types[i], + position_count) + column_values[i] = column_value + null_indicators[i] = null_indicator + + return time_column_values, column_values, null_indicators, position_count + + +# General Methods + +def read_int_from_buffer(buffer): + res, buffer = read_from_buffer(buffer, 4) + return int.from_bytes(res, "big"), buffer + + +def read_byte_from_buffer(buffer): + return read_from_buffer(buffer, 1) + + +def read_from_buffer(buffer, size): + res = buffer[:size] + buffer = buffer[size:] + return res, buffer + + +# Read ColumnType + +def read_column_types(buffer, value_column_count): + data_types = [] + for _ in range(value_column_count): + res, buffer = read_byte_from_buffer(buffer) + data_types.append(get_data_type(res)) + return data_types, buffer + + +def get_data_type(value): + if value == b'\x00': + return TSDataType.BOOLEAN + elif value == b'\x01': + return TSDataType.INT32 + elif value == b'\x02': + return TSDataType.INT64 + elif value == b'\x03': + return TSDataType.FLOAT + elif value == b'\x04': + return TSDataType.DOUBLE + elif value == b'\x05': + return TSDataType.TEXT + + +def get_data_type_byte_from_str(value): + ''' + Args: + value (str): data type in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'] + Returns: + byte: corresponding 
data type in [b'\x00', b'\x01', b'\x02', b'\x03', b'\x04', b'\x05'] + ''' + if value not in ['bool', 'int32', 'int64', 'float32', 'float64', 'text']: + raise BadConfigValueError('data_type', value, + "data_type should be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text']") + if value == "bool": + return TSDataType.BOOLEAN.value + elif value == "int32": + return TSDataType.INT32.value + elif value == "int64": + return TSDataType.INT64.value + elif value == "float32": + return TSDataType.FLOAT.value + elif value == "float64": + return TSDataType.DOUBLE.value + elif value == "text": + return TSDataType.TEXT.value + + +# Read ColumnEncodings + +def read_column_encoding(buffer, size): + encodings = [] + for _ in range(size): + res, buffer = read_byte_from_buffer(buffer) + encodings.append(res) + return encodings, buffer + + +# Read Column + +def deserialize_null_indicators(buffer, size): + may_have_null, buffer = read_byte_from_buffer(buffer) + if may_have_null != b'\x00': + return deserialize_from_boolean_array(buffer, size) + return None, buffer + + +# Serialized data layout: +# +---------------+-----------------+-------------+ +# | may have null | null indicators | values | +# +---------------+-----------------+-------------+ +# | byte | list[byte] | list[int64] | +# +---------------+-----------------+-------------+ + +def read_time_column(buffer, size): + null_indicators, buffer = deserialize_null_indicators(buffer, size) + if null_indicators is None: + values, buffer = read_from_buffer( + buffer, size * 8 + ) + else: + raise Exception("TimeColumn should not contains null value") + return values, buffer + + +def read_int64_column(buffer, data_type, position_count): + null_indicators, buffer = deserialize_null_indicators(buffer, position_count) + if null_indicators is None: + size = position_count + else: + size = null_indicators.count(False) + + if TSDataType.INT64 == data_type or TSDataType.DOUBLE == data_type: + values, buffer = read_from_buffer(buffer, size * 8) + return values, null_indicators, buffer + else: + raise Exception("Invalid data type: " + data_type) + + +# Serialized data layout: +# +---------------+-----------------+-------------+ +# | may have null | null indicators | values | +# +---------------+-----------------+-------------+ +# | byte | list[byte] | list[int32] | +# +---------------+-----------------+-------------+ + +def read_int32_column(buffer, data_type, position_count): + null_indicators, buffer = deserialize_null_indicators(buffer, position_count) + if null_indicators is None: + size = position_count + else: + size = null_indicators.count(False) + + if TSDataType.INT32 == data_type or TSDataType.FLOAT == data_type: + values, buffer = read_from_buffer(buffer, size * 4) + return values, null_indicators, buffer + else: + raise Exception("Invalid data type: " + data_type) + + +# Serialized data layout: +# +---------------+-----------------+-------------+ +# | may have null | null indicators | values | +# +---------------+-----------------+-------------+ +# | byte | list[byte] | list[byte] | +# +---------------+-----------------+-------------+ + +def read_byte_column(buffer, data_type, position_count): + if data_type != TSDataType.BOOLEAN: + raise Exception("Invalid data type: " + data_type) + null_indicators, buffer = deserialize_null_indicators(buffer, position_count) + res, buffer = deserialize_from_boolean_array(buffer, position_count) + return res, null_indicators, buffer + + +def deserialize_from_boolean_array(buffer, size): + packed_boolean_array, 
buffer = read_from_buffer(buffer, (size + 7) // 8) + current_byte = 0 + output = [None] * size + position = 0 + # read null bits 8 at a time + while position < (size & ~0b111): + value = packed_boolean_array[current_byte] + output[position] = ((value & 0b1000_0000) != 0) + output[position + 1] = ((value & 0b0100_0000) != 0) + output[position + 2] = ((value & 0b0010_0000) != 0) + output[position + 3] = ((value & 0b0001_0000) != 0) + output[position + 4] = ((value & 0b0000_1000) != 0) + output[position + 5] = ((value & 0b0000_0100) != 0) + output[position + 6] = ((value & 0b0000_0010) != 0) + output[position + 7] = ((value & 0b0000_0001) != 0) + + position += 8 + current_byte += 1 + # read last null bits + if (size & 0b111) > 0: + value = packed_boolean_array[-1] + mask = 0b1000_0000 + position = size & ~0b111 + while position < size: + output[position] = ((value & mask) != 0) + mask >>= 1 + position += 1 + return output, buffer + + +# Serialized data layout: +# +---------------+-----------------+-------------+ +# | may have null | null indicators | values | +# +---------------+-----------------+-------------+ +# | byte | list[byte] | list[entry] | +# +---------------+-----------------+-------------+ +# +# Each entry is represented as: +# +---------------+-------+ +# | value length | value | +# +---------------+-------+ +# | int32 | bytes | +# +---------------+-------+ + +def read_binary_column(buffer, data_type, position_count): + if data_type != TSDataType.TEXT: + raise Exception("Invalid data type: " + data_type) + null_indicators, buffer = deserialize_null_indicators(buffer, position_count) + + if null_indicators is None: + size = position_count + else: + size = null_indicators.count(False) + values = [None] * size + for i in range(size): + length, buffer = read_int_from_buffer(buffer) + res, buffer = read_from_buffer(buffer, length) + values[i] = res + return values, null_indicators, buffer + + +def read_column(encoding, buffer, data_type, position_count): + if encoding == b'\x00': + return read_byte_column(buffer, data_type, position_count) + elif encoding == b'\x01': + return read_int32_column(buffer, data_type, position_count) + elif encoding == b'\x02': + return read_int64_column(buffer, data_type, position_count) + elif encoding == b'\x03': + return read_binary_column(buffer, data_type, position_count) + elif encoding == b'\x04': + return read_run_length_column(buffer, data_type, position_count) + else: + raise Exception("Unsupported encoding: " + encoding) + + +# Serialized data layout: +# +-----------+-------------------------+ +# | encoding | serialized inner column | +# +-----------+-------------------------+ +# | byte | list[byte] | +# +-----------+-------------------------+ + +def read_run_length_column(buffer, data_type, position_count): + encoding, buffer = read_byte_from_buffer(buffer) + column, null_indicators, buffer = read_column(encoding, buffer, data_type, 1) + + return repeat(column, data_type, position_count), null_indicators * position_count, buffer + + +def repeat(buffer, data_type, position_count): + if data_type == TSDataType.BOOLEAN or data_type == TSDataType.TEXT: + return buffer * position_count + else: + res = bytes(0) + for _ in range(position_count): + res.join(buffer) + return res diff --git a/iotdb-core/ainode/iotdb/ainode/service.py b/iotdb-core/ainode/iotdb/ainode/service.py new file mode 100644 index 000000000000..b6bd73bde9ba --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/service.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation 
(ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import os +import sys +import threading +from datetime import datetime + +import psutil +from thrift.protocol import TCompactProtocol, TBinaryProtocol +from thrift.server import TServer +from thrift.transport import TSocket, TTransport + +from iotdb.ainode.client import client_manager +from iotdb.ainode.config import descriptor +from iotdb.ainode.constant import AINODE_SYSTEM_FILE_NAME +from iotdb.ainode.handler import AINodeRPCServiceHandler +from iotdb.ainode.log import logger +from iotdb.thrift.ainode import IAINodeRPCService +from iotdb.thrift.common.ttypes import TAINodeConfiguration, TAINodeLocation, TEndPoint, TNodeResource +from iotdb.thrift.confignode.ttypes import TNodeVersionInfo + + +class RPCService(threading.Thread): + def __init__(self): + self.exit_code = 0 + super().__init__() + processor = IAINodeRPCService.Processor(handler=AINodeRPCServiceHandler()) + transport = TSocket.TServerSocket(host=descriptor.get_config().get_ain_inference_rpc_address(), + port=descriptor.get_config().get_ain_inference_rpc_port()) + transport_factory = TTransport.TFramedTransportFactory() + if descriptor.get_config().get_ain_thrift_compression_enabled(): + protocol_factory = TCompactProtocol.TCompactProtocolFactory() + else: + protocol_factory = TBinaryProtocol.TBinaryProtocolFactory() + + self.__pool_server = TServer.TThreadPoolServer(processor, transport, transport_factory, protocol_factory) + + def run(self) -> None: + logger.info("The RPC service thread begin to run...") + try: + self.__pool_server.serve() + except Exception as e: + self.exit_code = 1 + logger.error(e) + + +class AINode(object): + def __init__(self): + self.__rpc_service = RPCService() + + def start(self) -> None: + logger.info('IoTDB-AINode is starting...') + system_path = descriptor.get_config().get_ain_system_dir() + system_properties_file = os.path.join(descriptor.get_config().get_ain_system_dir(), AINODE_SYSTEM_FILE_NAME) + if not os.path.exists(system_path): + try: + os.makedirs(system_path) + os.chmod(system_path, 0o777) + except PermissionError as e: + logger.error(e) + raise e + + if not os.path.exists(system_properties_file): + # If the system.properties file does not exist, the AINode will register to ConfigNode. 
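+            # Registration below also persists a small system.properties file so
+            # later starts can take the restart branch instead. Illustrative
+            # contents (values are examples only):
+            #   ainode_id=2
+            #   cluster_name=defaultCluster
+            #   iotdb_version=1.3.3-SNAPSHOT
+            #   commit_id=<git build info>
+            #   ain_rpc_address=127.0.0.1
+            #   ain_rpc_port=10810
+            #   config_node_list=127.0.0.1:10710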
+ try: + logger.info('IoTDB-AINode is registering to ConfigNode...') + ainode_id = client_manager.borrow_config_node_client().node_register( + descriptor.get_config().get_cluster_name(), + self._generate_configuration(), + self._generate_version_info()) + descriptor.get_config().set_ainode_id(ainode_id) + system_properties = { + 'ainode_id': ainode_id, + 'cluster_name': descriptor.get_config().get_cluster_name(), + 'iotdb_version': descriptor.get_config().get_version_info(), + 'commit_id': descriptor.get_config().get_build_info(), + 'ain_rpc_address': descriptor.get_config().get_ain_inference_rpc_address(), + 'ain_rpc_port': descriptor.get_config().get_ain_inference_rpc_port(), + 'config_node_list': descriptor.get_config().get_ain_target_config_node_list(), + } + with open(system_properties_file, 'w') as f: + f.write('#' + str(datetime.now()) + '\n') + for key, value in system_properties.items(): + f.write(key + '=' + str(value) + '\n') + + except Exception as e: + logger.error('IoTDB-AINode failed to register to ConfigNode: {}'.format(e)) + sys.exit(1) + else: + # If the system.properties file does exist, the AINode will just restart. + try: + logger.info('IoTDB-AINode is restarting...') + client_manager.borrow_config_node_client().node_restart( + descriptor.get_config().get_cluster_name(), + self._generate_configuration(), + self._generate_version_info()) + + except Exception as e: + logger.error('IoTDB-AINode failed to restart: {}'.format(e)) + sys.exit(1) + + self.__rpc_service.start() + self.__rpc_service.join(1) + if self.__rpc_service.exit_code != 0: + return + + logger.info('IoTDB-AINode has successfully started.') + + @staticmethod + def _generate_configuration() -> TAINodeConfiguration: + location = TAINodeLocation(descriptor.get_config().get_ainode_id(), + TEndPoint(descriptor.get_config().get_ain_inference_rpc_address(), + descriptor.get_config().get_ain_inference_rpc_port())) + resource = TNodeResource( + int(psutil.cpu_count()), + int(psutil.virtual_memory()[0]) + ) + + return TAINodeConfiguration(location, resource) + + @staticmethod + def _generate_version_info() -> TNodeVersionInfo: + return TNodeVersionInfo(descriptor.get_config().get_version_info(), + descriptor.get_config().get_build_info()) diff --git a/iotdb-core/ainode/iotdb/ainode/storage.py b/iotdb-core/ainode/iotdb/ainode/storage.py new file mode 100644 index 000000000000..612693d614f3 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/storage.py @@ -0,0 +1,318 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
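+#
+# ModelStorage below is a process-wide singleton. Each registered model is kept
+# on disk as <ain_models_dir>/<model_id>/{model.pt, config.yaml}, and loaded
+# TorchScript modules are held in an LRU cache to avoid repeated deserialization.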
+# + +import json +import os +import shutil +import threading +from typing import Dict, Tuple +from urllib.parse import urljoin, urlparse + +import requests +import torch +import torch._dynamo +import torch.nn as nn +import yaml +from pylru import lrucache +from requests.adapters import HTTPAdapter + +from iotdb.ainode.config import descriptor +from iotdb.ainode.constant import (OptionsKey, DEFAULT_MODEL_FILE_NAME, + DEFAULT_CONFIG_FILE_NAME, DEFAULT_RECONNECT_TIMEOUT, + DEFAULT_RECONNECT_TIMES, DEFAULT_CHUNK_SIZE) +from iotdb.ainode.exception import ModelNotExistError, InvaildUriError +from iotdb.ainode.log import logger +from iotdb.ainode.parser import parse_inference_config +from iotdb.ainode.util import pack_input_dict + + +class ModelStorage(object): + _instance = None + _first_init = False + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if not self._first_init: + self.__model_dir = os.path.join(os.getcwd(), descriptor.get_config().get_ain_models_dir()) + if not os.path.exists(self.__model_dir): + try: + os.makedirs(self.__model_dir) + except PermissionError as e: + logger.error(e) + raise e + self.lock = threading.RLock() + self.__model_cache = lrucache(descriptor.get_config().get_mn_model_storage_cache_size()) + self._first_init = True + + def _parse_uri(self, uri): + ''' + Args: + uri (str): uri to parse + Returns: + is_network_path (bool): True if the url is a network path, False otherwise + parsed_uri (str): parsed uri to get related file + ''' + # remove quotation mark in uri + uri = uri[1:-1] + + parse_result = urlparse(uri) + is_network_path = parse_result.scheme in ('http', 'https') + if is_network_path: + return True, uri + + # handle file:// in uri + if parse_result.scheme == 'file': + uri = uri[7:] + + # handle ~ in uri + uri = os.path.expanduser(uri) + return False, uri + + def _download_file(self, url: str, storage_path: str) -> None: + ''' + Args: + url: url of file to download + storage_path: path to save the file + Returns: + None + ''' + logger.debug(f"download file from {url} to {storage_path}") + + session = requests.Session() + adapter = HTTPAdapter(max_retries=DEFAULT_RECONNECT_TIMES) + session.mount('http://', adapter) + session.mount('https://', adapter) + + response = session.get(url, timeout=DEFAULT_RECONNECT_TIMEOUT, stream=True) + response.raise_for_status() + + with self.lock: + with open(storage_path, 'wb') as file: + for chunk in response.iter_content(chunk_size=DEFAULT_CHUNK_SIZE): + if chunk: + file.write(chunk) + + logger.debug(f"download file from {url} to {storage_path} success") + + def _register_model_from_network(self, uri: str, storage_path: str, model_storage_path: str, + config_storage_path: str): + ''' + Args: + uri: network dir path of model to register, where model.pt and config.yaml are required, + e.g. 
https://huggingface.co/user/modelname/resolve/main/ + storage_path: dir path to save the model + model_storage_path: path to save model.pt + config_storage_path: path to save config.yaml + Returns: + configs: TConfigs + attributes: str + ''' + # concat uri to get commplete url + uri = uri if uri.endswith("/") else uri + "/" + target_model_path = urljoin(uri, DEFAULT_MODEL_FILE_NAME) + target_config_path = urljoin(uri, DEFAULT_CONFIG_FILE_NAME) + + # create storage dir if not exist + with self.lock: + if not os.path.exists(storage_path): + os.makedirs(storage_path) + + # download config file + self._download_file(target_config_path, config_storage_path) + + # read and parse config dict from config.yaml + with open(config_storage_path, 'r', encoding='utf-8') as file: + config_dict = yaml.safe_load(file) + configs, attributes = parse_inference_config(config_dict) + + # if config.yaml is correct, download model file + self._download_file(target_model_path, model_storage_path) + return configs, attributes + + def _register_model_from_local(self, uri: str, storage_path: str, model_storage_path: str, + config_storage_path: str): + ''' + Args: + uri: local dir path of model to register, where model.pt and config.yaml are required, + e.g. /Users/admin/Desktop/dlinear + storage_path: dir path to save the model + model_storage_path: path to save model.pt + config_storage_path: path to save config.yaml + Returns: + configs: TConfigs + attributes: str + ''' + # concat uri to get commplete path + target_model_path = os.path.join(uri, DEFAULT_MODEL_FILE_NAME) + target_config_path = os.path.join(uri, DEFAULT_CONFIG_FILE_NAME) + + # check if file exist + exist_model_file = os.path.exists(target_model_path) + exist_config_file = os.path.exists(target_config_path) + + if exist_model_file and exist_config_file: + # copy config.yaml + with self.lock: + if not os.path.exists(storage_path): + os.makedirs(storage_path) + + logger.debug(f"copy file from {target_config_path} to {storage_path}") + shutil.copy(target_config_path, config_storage_path) + logger.debug(f"copy file from {target_config_path} to {storage_path} success") + + # read and parse config dict from config.yaml + with open(config_storage_path, 'r', encoding='utf-8') as file: + config_dict = yaml.safe_load(file) + configs, attributes = parse_inference_config(config_dict) + + # if config.yaml is correct, copy model file + with self.lock: + logger.debug(f"copy file from {target_model_path} to {storage_path}") + shutil.copy(target_model_path, model_storage_path) + logger.debug(f"copy file from {target_model_path} to {storage_path} success") + elif not exist_model_file or not exist_config_file: + raise InvaildUriError(uri) + + return configs, attributes + + def register_model(self, model_id: str, uri: str): + ''' + Args: + model_id: id of model to register + uri: network dir path or local dir path of model to register, where model.pt and config.yaml are required, + e.g. 
https://huggingface.co/user/modelname/resolve/main/ or /Users/admin/Desktop/dlinear + Returns: + configs: TConfigs + attributes: str + ''' + storage_path = os.path.join(self.__model_dir, f'{model_id}') + model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME) + config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME) + + is_network_path, uri = self._parse_uri(uri) + + if is_network_path: + return self._register_model_from_network(uri, storage_path, model_storage_path, config_storage_path) + else: + return self._register_model_from_local(uri, storage_path, model_storage_path, config_storage_path) + + def save_model(self, + model: nn.Module, + model_config: Dict, + model_id: str, + trial_id: str) -> str: + model_dir_path = os.path.join(self.__model_dir, f'{model_id}') + logger.debug(f"save model {model_config} to {model_dir_path}") + with self.lock: + if not os.path.exists(model_dir_path): + os.makedirs(model_dir_path) + model_file_path = os.path.join(model_dir_path, f'{trial_id}.pt') + + # Note: model config for time series should contain 'input_len' and 'input_vars' + sample_input = ( + pack_input_dict( + torch.randn(1, model_config[OptionsKey.INPUT_LENGTH.name()], model_config[OptionsKey.INPUT_VARS.name()]) + ) + ) + with self.lock: + torch.jit.save(torch.jit.trace(model, sample_input), + model_file_path, + _extra_files={'model_config': json.dumps(model_config)}) + return os.path.abspath(model_file_path) + + def load_model( + self, file_path: str) -> Tuple[torch.jit.ScriptModule, Dict]: + """ + Returns: + jit_model: a ScriptModule contains model architecture and parameters, which can be deployed cross-platform + model_config: a dict contains model attributes + """ + logger.debug(f"load model from {file_path}") + file_path = os.path.join(self.__model_dir, file_path) + if file_path in self.__model_cache: + return self.__model_cache[file_path] + else: + if not os.path.exists(file_path): + raise ModelNotExistError(file_path) + else: + tmp_dict = {'model_config': ''} + jit_model = torch.jit.load(file_path, _extra_files=tmp_dict) + model_config = json.loads(tmp_dict['model_config']) + self.__model_cache[file_path] = jit_model, model_config + return jit_model, model_config + + def load_model_from_id(self, model_id: str, acceleration=False): + """ + Returns: + model: a ScriptModule contains model architecture and parameters, which can be deployed cross-platform + """ + ain_models_dir = os.path.join(self.__model_dir, f'{model_id}') + model_path = os.path.join(ain_models_dir, DEFAULT_MODEL_FILE_NAME) + logger.debug(f"load model from {model_path}") + if model_path in self.__model_cache: + model = self.__model_cache[model_path] + if isinstance(model, torch._dynamo.eval_frame.OptimizedModule) or not acceleration: + return model + else: + model = torch.compile(model) + self.__model_cache[model_path] = model + return model + else: + if not os.path.exists(model_path): + raise ModelNotExistError(model_path) + else: + model = torch.jit.load(model_path) + if acceleration: + try: + model = torch.compile(model) + except: + logger.warning("acceleration failed, fallback to normal mode") + self.__model_cache[model_path] = model + return model + + def delete_model(self, model_id: str) -> None: + ''' + Args: + model_id: id of model to delete + Returns: + None + ''' + storage_path = os.path.join(self.__model_dir, f'{model_id}') + + if os.path.exists(storage_path): + for file_name in os.listdir(storage_path): + self._remove_from_cache(os.path.join(storage_path, file_name)) + 
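+        # After evicting any cached entries, remove the whole model directory
+        # (model.pt, config.yaml and any <trial_id>.pt checkpoints) from disk.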
shutil.rmtree(storage_path) + + def delete_trial(self, model_id: str, trial_id: str) -> None: + logger.debug(f"delete trial {trial_id} of model {model_id}") + model_file_path = os.path.join(self.__model_dir, f'{model_id}', f'{trial_id}.pt') + self._remove_from_cache(model_file_path) + if os.path.exists(model_file_path): + os.remove(model_file_path) + + def _remove_from_cache(self, file_path: str) -> None: + if file_path in self.__model_cache: + del self.__model_cache[file_path] + + +model_storage = ModelStorage() diff --git a/iotdb-core/ainode/iotdb/ainode/util.py b/iotdb-core/ainode/iotdb/ainode/util.py new file mode 100644 index 000000000000..af605bb30c20 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/util.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import torch + +from iotdb.ainode.constant import TSStatusCode, ModelInputName +from iotdb.ainode.exception import BadNodeUrlError +from iotdb.ainode.log import logger +from iotdb.thrift.common.ttypes import TEndPoint, TSStatus + + +def parse_endpoint_url(endpoint_url: str) -> TEndPoint: + """ Parse TEndPoint from a given endpoint url. 
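+    For example, "127.0.0.1:10810" is parsed into TEndPoint("127.0.0.1", 10810).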
+ Args: + endpoint_url: an endpoint url, format: ip:port + Returns: + TEndPoint + Raises: + BadNodeUrlError + """ + split = endpoint_url.split(":") + if len(split) != 2: + logger.warning("Illegal endpoint url format: {}".format(endpoint_url)) + raise BadNodeUrlError(endpoint_url) + + ip = split[0] + try: + port = int(split[1]) + result = TEndPoint(ip, port) + return result + except ValueError as e: + logger.warning("Illegal endpoint url format: {} ({})".format(endpoint_url, e)) + raise BadNodeUrlError(endpoint_url) + + +def get_status(status_code: TSStatusCode, message: str = None) -> TSStatus: + status = TSStatus(status_code.get_status_code()) + status.message = message + return status + + +def verify_success(status: TSStatus, err_msg: str) -> None: + if status.code != TSStatusCode.SUCCESS_STATUS.get_status_code(): + logger.warning(err_msg + ", error status is ", status) + raise RuntimeError(str(status.code) + ": " + status.message) + + +def pack_input_dict(batch_x: torch.Tensor, + batch_x_mark: torch.Tensor = None, + dec_inp: torch.Tensor = None, + batch_y_mark: torch.Tensor = None): + """ + pack up inputs as a dict to adapt for different models + """ + input_dict = {} + if batch_x is not None: + input_dict[ModelInputName.DATA_X.value] = batch_x + if batch_x_mark is not None: + input_dict[ModelInputName.TIME_STAMP_X] = batch_x_mark + if dec_inp is not None: + input_dict[ModelInputName.DEC_INP] = dec_inp + if batch_y_mark is not None: + input_dict[ModelInputName.TIME_STAMP_Y.value] = batch_y_mark + return input_dict diff --git a/iotdb-core/ainode/pom.xml b/iotdb-core/ainode/pom.xml new file mode 100644 index 000000000000..a6a18fbd1374 --- /dev/null +++ b/iotdb-core/ainode/pom.xml @@ -0,0 +1,347 @@ + + + + 4.0.0 + + org.apache.iotdb + iotdb-core + 1.3.3-SNAPSHOT + + iotdb-ainode + IoTDB: Core: AINode + + + + org.apache.iotdb + iotdb-thrift-commons + 1.3.3-SNAPSHOT + provided + + + org.apache.iotdb + iotdb-thrift + 1.3.3-SNAPSHOT + provided + + + org.apache.iotdb + iotdb-thrift-confignode + 1.3.3-SNAPSHOT + provided + + + org.apache.iotdb + iotdb-thrift-ainode + 1.3.3-SNAPSHOT + provided + + + + + + true + ${basedir}/resources + + pyproject.toml + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + + true + + + + + org.apache.maven.plugins + maven-clean-plugin + + + + iotdb + + thrift/ + + false + + + ./ + + LICENSE + + + + target + + **/*ainode* + + + + dist + + + + + + + pl.project13.maven + git-commit-id-plugin + + true + ${project.basedir}/iotdb/conf/git.properties + + ^git.commit.id.abbrev$ + ^git.dirty$ + + full + false + true + + -dev + + + + + + revision + + + + + + + org.apache.maven.plugins + maven-resources-plugin + + ${project.build.sourceEncoding} + + + + copy-thrift-python-resources + + copy-resources + + + generate-sources + + utf-8 + ${basedir}/iotdb/thrift/ + + + ${basedir}/../../iotdb-protocol/thrift-commons/target/generated-sources-python/iotdb/thrift/ + + + ${basedir}/../../iotdb-protocol/thrift-confignode/target/generated-sources-python/iotdb/thrift/ + + + ${basedir}/../../iotdb-protocol/thrift-ainode/target/generated-sources-python/iotdb/thrift/ + + + + + + copy-thrift-python-resources-datanode + + copy-resources + + + generate-sources + + utf-8 + ${basedir}/iotdb/thrift/datanode + + + ${basedir}/../../iotdb-protocol/thrift-datanode/target/generated-sources-python/iotdb/thrift/datanode/ + + + + + + + copy-pypi-file-resources + + copy-resources + + + generate-sources + + utf-8 + ${basedir}/ + + + ${basedir}/.. 
+ + LICENSE + + + + + + + copy-pom-properties + + copy-resources + + generate-sources + + utf-8 + ${basedir}/iotdb/conf/ + + + ${basedir}/target/maven-archiver/ + + + + + + copy-pyproject-toml + + copy-resources + + generate-sources + + utf-8 + ${basedir}/ + + + ${basedir}/target/classes/ + + pyproject.toml + + + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + write-ai-node-version + + regex-property + + generate-resources + + ainode_version + -SNAPSHOT + ${project.version} + \.dev + false + + + + + + org.apache.maven.plugins + maven-antrun-plugin + 3.0.0 + + + copy-whl-file + package + + run + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + clean-temp-file + pre-integration-test + + run + + + + + + + + + + + + + + + + + + org.apache.maven.plugins + maven-assembly-plugin + 3.3.0 + + apache-iotdb-ainode-${project.version} + false + + ainode.xml + + + + + create-ainode-zip + package + + single + + + + + + + diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml new file mode 100644 index 000000000000..724017072595 --- /dev/null +++ b/iotdb-core/ainode/pyproject.toml @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "apache-iotdb-ainode" +version = "1.4.0.dev" +description = "Apache IoTDB AINode" +readme = "README.md" +authors = ["Apache Software Foundation "] +license = "Apache License, Version 2.0" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", +] +include = [ + "iotdb/thrift/*", + "iotdb/thrift/common/*", + "iotdb/thrift/confignode/*", + "iotdb/thrift/datanode/*", + "iotdb/thrift/ainode/*", + "iotdb/conf/*", +] +packages = [ + { include = "iotdb" } +] + +[tool.poetry.dependencies] +python = "^3.8" + +numpy = "^1.21.4" +pandas = "^1.3.5" +torch = "2.1.0" +pylru = "^1.2.1" + +thrift = "^0.13.0" +dynaconf = "^3.1.11" +requests = "^2.31.0" +optuna = "^3.2.0" +psutil = "^5.9.5" +sktime = "^0.24.1" +pmdarima = "^2.0.4" +hmmlearn = "^0.3.0" + +[tool.poetry.scripts] +ainode = "iotdb.ainode.script:main" \ No newline at end of file diff --git a/iotdb-core/ainode/resources/conf/ainode-env.bat b/iotdb-core/ainode/resources/conf/ainode-env.bat new file mode 100644 index 000000000000..6676b8a842b8 --- /dev/null +++ b/iotdb-core/ainode/resources/conf/ainode-env.bat @@ -0,0 +1,128 @@ +@REM +@REM Licensed to the Apache Software Foundation (ASF) under one +@REM or more contributor license agreements. 
See the NOTICE file +@REM distributed with this work for additional information +@REM regarding copyright ownership. The ASF licenses this file +@REM to you under the Apache License, Version 2.0 (the +@REM "License"); you may not use this file except in compliance +@REM with the License. You may obtain a copy of the License at +@REM +@REM http://www.apache.org/licenses/LICENSE-2.0 +@REM +@REM Unless required by applicable law or agreed to in writing, +@REM software distributed under the License is distributed on an +@REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +@REM KIND, either express or implied. See the License for the +@REM specific language governing permissions and limitations +@REM under the License. +@REM + +@echo off + +@REM The defaulte venv environment is used if ain_interpreter_dir is not set. Please use absolute path without quotation mark +@REM set ain_interpreter_dir= + +@REM Set ain_force_reinstall to 1 to force reinstall ainode +set ain_force_reinstall=0 + +@REM don't install dependencies online +set ain_install_offline=0 + +set ENV_SCRIPT_DIR=%~dp0 + +:initial +if "%1"=="" goto done +set aux=%1 +if "%aux:~0,2%"=="-r" ( + set ain_force_reinstall=1 + shift + goto initial +) +if "%aux:~0,2%"=="-n" ( + set ain_no_dependencies=--no-dependencies + shift + goto initial +) +if "%aux:~0,1%"=="-" ( + set nome=%aux:~1,250% +) else ( + set "%nome%=%1" + set nome= +) +shift +goto initial + +:done +@REM check if the parameters are set +if "%i%"=="" ( + echo No interpreter_dir is set, use default value. +) else ( + set ain_interpreter_dir=%i% +) + +echo Script got inputs: ain_interpreter_dir: %ain_interpreter_dir% , ain_force_reinstall: %ain_force_reinstall% +if "%ain_interpreter_dir%"=="" ( + %ENV_SCRIPT_DIR%//..//venv//Scripts//python.exe -c "import sys; print(sys.executable)" && ( + echo Activate default venv environment + ) || ( + echo Creating default venv environment + python -m venv "%ENV_SCRIPT_DIR%//..//venv" + ) + set ain_interpreter_dir="%ENV_SCRIPT_DIR%//..//venv//Scripts//python.exe" +) + +@REM Switch the working directory to the directory one level above the script +cd %ENV_SCRIPT_DIR%/../ + +echo Confirming ainode +%ain_interpreter_dir% -m pip config set global.disable-pip-version-check true +%ain_interpreter_dir% -m pip list | findstr /C:"apache-iotdb-ainode" >nul +if %errorlevel% == 0 ( + if %ain_force_reinstall% == 0 ( + echo ainode is already installed + exit /b 0 + ) +) + +set ain_only_ainode=1 +@REM if $ain_install_offline is 1 then do not install dependencies +if %ain_install_offline% == 1 ( + @REM if offline and not -n, then install dependencies + if "%ain_no_dependencies%"=="" ( + set ain_only_ainode=0 + ) else ( + set ain_only_ainode=1 + ) + set ain_no_dependencies=--no-dependencies + echo Installing ainode offline----without dependencies... +) + +if %ain_force_reinstall% == 1 ( + set ain_force_reinstall=--force-reinstall +) else ( + set ain_force_reinstall= +) + +echo Installing ainode... 
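+@REM For reference, the install command in the loop below expands to something
+@REM like the following (the wheel file name is an example only):
+@REM   %ain_interpreter_dir% -m pip install apache_iotdb_ainode-1.4.0.dev0-py3-none-any.whl --force-reinstall -i https://pypi.tuna.tsinghua.edu.cn/simple --no-warn-script-location --find-links https://download.pytorch.org/whl/cpu/torch_stable.html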
+@REM Print current work dir +cd lib +for %%i in (*.whl *.tar.gz) do ( + echo %%i | findstr "ainode" >nul && ( + echo Installing ainode body: %%i + %ain_interpreter_dir% -m pip install %%i %ain_force_reinstall% -i https://pypi.tuna.tsinghua.edu.cn/simple --no-warn-script-location %ain_no_dependencies% --find-links https://download.pytorch.org/whl/cpu/torch_stable.html + ) || ( + @REM if ain_only_ainode is 0 then install dependencies + if %ain_only_ainode% == 0 ( + echo Installing dependencies: %%i + set ain_force_reinstall=--force-reinstall + %ain_interpreter_dir% -m pip install %%i %ain_force_reinstall% -i https://pypi.tuna.tsinghua.edu.cn/simple --no-warn-script-location %ain_no_dependencies% --find-links https://download.pytorch.org/whl/cpu/torch_stable.html + ) + ) + if %errorlevel% == 1 ( + echo Failed to install ainode + exit /b 1 + ) +) +echo ainode is installed successfully +cd .. +exit /b 0 diff --git a/iotdb-core/ainode/resources/conf/ainode-env.sh b/iotdb-core/ainode/resources/conf/ainode-env.sh new file mode 100644 index 000000000000..6e9eec3d7a11 --- /dev/null +++ b/iotdb-core/ainode/resources/conf/ainode-env.sh @@ -0,0 +1,138 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# The defaulte venv environment is used if ain_interpreter_dir is not set. Please use absolute path without quotation mark +# ain_interpreter_dir= + +# Set ain_force_reinstall to 1 to force reinstall AINode +ain_force_reinstall=0 + +# don't install dependencies online +ain_install_offline=0 + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" + +# fetch parameters with names +while getopts "i:t:rnm:" opt; do + case $opt in + i) + p_ain_interpreter_dir="$OPTARG" + ;; + r) + p_ain_force_reinstall=1 + ;; + t) ;; + n) + p_ain_no_dependencies="--no-dependencies" + ;; + m) + p_pypi_mirror="$OPTARG" + ;; + \?) + echo "Invalid option -$OPTARG" >&2 + exit 1 + ;; + esac +done + +if [ -z "$p_ain_interpreter_dir" ]; then + echo "No interpreter_dir is set, use default value." +else + ain_interpreter_dir="$p_ain_interpreter_dir" +fi + +if [ -z "$p_ain_force_reinstall" ]; then + echo "No check_version is set, use default value." +else + ain_force_reinstall="$p_ain_force_reinstall" +fi +echo Script got inputs: "ain_interpreter_dir: $ain_interpreter_dir", "ain_force_reinstall: $ain_force_reinstall" + +if [ -z $ain_interpreter_dir ]; then + $(dirname "$0")/../venv/bin/python3 -c "import sys; print(sys.executable)" && + echo "Activate default venv environment" || ( + echo "Creating default venv environment" && python3 -m venv "$(dirname "$0")/../venv" + ) + ain_interpreter_dir="$SCRIPT_DIR/../venv/bin/python3" +fi +echo "Calling venv to check: $ain_interpreter_dir" + +# Change the working directory to the parent directory +cd "$SCRIPT_DIR/.." 
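+# Example invocation (flags as parsed by getopts above; paths and mirror are
+# illustrative):
+#   bash conf/ainode-env.sh -i /path/to/venv/bin/python3 -r -m https://pypi.org/simple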
+ +echo "Confirming AINode..." +$ain_interpreter_dir -m pip config set global.disable-pip-version-check true +$ain_interpreter_dir -m pip list | grep "apache-iotdb-ainode" >/dev/null +if [ $? -eq 0 ]; then + if [ $ain_force_reinstall -eq 0 ]; then + echo "AINode is already installed" + exit 0 + fi +fi + +ain_only_ainode=1 + +# if $ain_install_offline is 1 then do not install dependencies +if [ $ain_install_offline -eq 1 ]; then + # if offline and not -n, then install dependencies + if [ -z "$p_ain_no_dependencies" ]; then + ain_only_ainode=0 + else + ain_only_ainode=1 + fi + p_ain_no_dependencies="--no-dependencies" + echo "Installing AINode offline----without dependencies..." +fi + +if [ $ain_force_reinstall -eq 1 ]; then + p_ain_force_reinstall="--force-reinstall" +else + p_ain_force_reinstall="" +fi + +echo "Installing AINode..." +cd "$SCRIPT_DIR/../lib/" +shopt -s nullglob +for i in *.whl *.tar.gz; do + if [[ $i =~ "ainode" ]]; then + echo Installing AINode body: $i + if [ -z "$p_pypi_mirror" ]; then + $ain_interpreter_dir -m pip install "$i" $p_ain_force_reinstall --no-warn-script-location $p_ain_no_dependencies + else + $ain_interpreter_dir -m pip install "$i" $p_ain_force_reinstall -i $p_pypi_mirror --no-warn-script-location $p_ain_no_dependencies + fi + else + # if ain_only_ainode is 0 then install dependencies + if [ $ain_only_ainode -eq 0 ]; then + echo Installing dependencies $i + if [ -z "$p_pypi_mirror" ]; then + $ain_interpreter_dir -m pip install "$i" $p_ain_force_reinstall --no-warn-script-location $p_ain_no_dependencies + else + $ain_interpreter_dir -m pip install "$i" $p_ain_force_reinstall -i $p_pypi_mirror --no-warn-script-location $p_ain_no_dependencies + fi + fi + fi + if [ $? -eq 1 ]; then + echo "Failed to install AINode" + exit 1 + fi +done +echo "AINode is installed successfully" +exit 0 diff --git a/iotdb-core/ainode/resources/conf/iotdb-ainode.properties b/iotdb-core/ainode/resources/conf/iotdb-ainode.properties new file mode 100644 index 000000000000..52b3dbb1bb0f --- /dev/null +++ b/iotdb-core/ainode/resources/conf/iotdb-ainode.properties @@ -0,0 +1,60 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# Used for indicate cluster name and distinguish different cluster. 
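+# It should be consistent with the cluster_name of the cluster this AINode registers to.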
+# Datatype: string +cluster_name=defaultCluster + +# ConfigNode address registered at AINode startup +# Allow modifications only before starting the service for the first time +# Datatype: String +ain_seed_config_node=127.0.0.1:10710 + +# Used for connection of DataNode/ConfigNode clients +# Could set 127.0.0.1(for local test) or ipv4 address +# Datatype: String +ain_inference_rpc_address=127.0.0.1 + +# Used for connection of DataNode/ConfigNode clients +# Bind with MN_RPC_ADDRESS +# Datatype: String +ain_inference_rpc_port=10810 + +# The AINode metadata storage path. +# The starting directory of the relative path is related to the operating system. +# It is recommended to use an absolute path. +# Datatype: String +# ain_system_dir=data/ainode/system + +# The path where AINode stores model files +# The starting directory of the relative path is related to the operating system. +# It is recommended to use an absolute path. +# Datatype: String +# ain_models_dir=data/ainode/models + +# The path where AINode stores logs +# The starting directory of the relative path is related to the operating system. +# It is recommended to use an absolute path. +# Datatype: String +# ain_logs_dir=logs/ainode + +# Whether to use compression in Thrift +# Please use 0 or 1 +# Datatype: Boolean +# ain_thrift_compression_enabled=0 \ No newline at end of file diff --git a/iotdb-core/ainode/resources/sbin/remove-ainode.bat b/iotdb-core/ainode/resources/sbin/remove-ainode.bat new file mode 100644 index 000000000000..b97c688ea051 --- /dev/null +++ b/iotdb-core/ainode/resources/sbin/remove-ainode.bat @@ -0,0 +1,107 @@ +@REM +@REM Licensed to the Apache Software Foundation (ASF) under one +@REM or more contributor license agreements. See the NOTICE file +@REM distributed with this work for additional information +@REM regarding copyright ownership. The ASF licenses this file +@REM to you under the Apache License, Version 2.0 (the +@REM "License"); you may not use this file except in compliance +@REM with the License. You may obtain a copy of the License at +@REM +@REM http://www.apache.org/licenses/LICENSE-2.0 +@REM +@REM Unless required by applicable law or agreed to in writing, +@REM software distributed under the License is distributed on an +@REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +@REM KIND, either express or implied. See the License for the +@REM specific language governing permissions and limitations +@REM under the License. +@REM + +@echo off + +IF "%~1"=="--help" ( + echo The script will remove an AINode. + echo When it is necessary to move an already connected AINode out of the cluster, the corresponding removal script can be executed. + echo Usage: + echo Remove the AINode with ainode_id + echo ./sbin/remove-ainode.bat -t [ainode_id] + echo Remove the AINode with address:port + echo ./sbin/remove-ainode.bat -t [ain_inference_rpc_address:ain_inference_rpc_port] + echo. + echo Options: + echo ^ ^ -t = ainode_id or [ain_inference_rpc_address:ain_inference_rpc_port] + echo ^ ^ -i = When specifying the Python interpreter please enter the address of the executable file of the Python interpreter in the virtual environment. Currently AINode supports virtual environments such as venv, conda, etc. Inputting the system Python interpreter as the installation location is not supported. In order to ensure that scripts are recognized properly, please use absolute paths whenever possible! 
+ EXIT /B 0 +) + +echo ``````````````````````````` +echo Removing IoTDB AINode +echo ``````````````````````````` + +set REMOVE_SCRIPT_DIR=%~dp0 +call %REMOVE_SCRIPT_DIR%\\..\\conf\\\ainode-env.bat %* +if %errorlevel% neq 0 ( + echo Environment check failed. Exiting... + exit /b 1 +) + +:initial +if "%1"=="" goto interpreter +set aux=%1 +if "%aux:~0,1%"=="-" ( + set nome=%aux:~1,250% +) else ( + set "%nome%=%1" + set nome= +) +shift +goto initial + +for /f "tokens=2 delims==" %%a in ('findstr /i /c:"^ain_interpreter_dir" "%REMOVE_SCRIPT_DIR%\\..\\conf\\\ainode-env.bat"') do ( + set _ain_interpreter_dir=%%a + goto :interpreter +) + +:interpreter +if "%i%"=="" ( + if "%_ain_interpreter_dir%"=="" ( + set _ain_interpreter_dir=%REMOVE_SCRIPT_DIR%\\..\\venv\\Scripts\\python.exe + ) +) else ( + set _ain_interpreter_dir=%i% +) + + +for /f "tokens=2 delims==" %%a in ('findstr /i /c:"^ain_system_dir" "%REMOVE_SCRIPT_DIR%\\..\\conf\\iotdb-\ainode.properties"') do ( + set _ain_system_dir=%%a + goto :system +) + +:system +if "%_ain_system_dir%"=="" ( + set _ain_system_dir=%REMOVE_SCRIPT_DIR%\\..\\data\\\ainode\\system +) + +echo Script got parameters: ain_interpreter_dir: %_ain_interpreter_dir%, ain_system_dir: %_ain_system_dir% + +cd %REMOVE_SCRIPT_DIR%\\.. +for %%i in ("%_ain_interpreter_dir%") do set "parent=%%~dpi" +set ain_\ainode_dir=%parent%\\\ainode.exe + +if "%t%"=="" ( + echo No target AINode set, use system.properties + %ain_\ainode_dir% remove +) else ( + %ain_\ainode_dir% remove %t% +) + +if %errorlevel% neq 0 ( + echo Remove AINode failed. Exiting... + exit /b 1 +) + +call %REMOVE_SCRIPT_DIR%\\stop-\ainode.bat %* + +rd /s /q %_ain_system_dir% + +pause \ No newline at end of file diff --git a/iotdb-core/ainode/resources/sbin/remove-ainode.sh b/iotdb-core/ainode/resources/sbin/remove-ainode.sh new file mode 100755 index 000000000000..dc6444b4c14b --- /dev/null +++ b/iotdb-core/ainode/resources/sbin/remove-ainode.sh @@ -0,0 +1,112 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +if [ "$#" -eq 1 ] && [ "$1" == "--help" ]; then + echo "The script will remove an AINode." + echo "When it is necessary to move an already connected AINode out of the cluster, the corresponding removal script can be executed." + echo "Usage:" + echo "Remove the AINode with ainode_id" + echo "./sbin/remove-ainode.sh -t [ainode_id]" + echo "Remove the AINode with address:port" + echo "./sbin/remove-ainode.sh -t [ain_inference_rpc_address:ain_inference_rpc_port]" + echo "" + echo "Options:" + echo " -t = ainode_id or [ain_inference_rpc_address:ain_inference_rpc_port]" + echo " -i = When specifying the Python interpreter please enter the address of the executable file of the Python interpreter in the virtual environment. 
Currently AINode supports virtual environments such as venv, conda, etc. Inputting the system Python interpreter as the installation location is not supported. In order to ensure that scripts are recognized properly, please use absolute paths whenever possible!" + exit 0 +fi + +echo --------------------------- +echo Removing IoTDB AINode +echo --------------------------- + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +echo "SCRIPT_DIR: $SCRIPT_DIR" +chmod u+x $(dirname "$0")/../conf/ainode-env.sh +ain_interpreter_dir=$(sed -n 's/^ain_interpreter_dir=\(.*\)$/\1/p' $(dirname "$0")/../conf/ainode-env.sh) +ain_system_dir=$(sed -n 's/^ain_system_dir=\(.*\)$/\1/p' $(dirname "$0")/../conf/iotdb-ainode.properties) +bash $(dirname "$0")/../conf/ainode-env.sh $* +if [ $? -eq 1 ]; then + echo "Environment check failed. Exiting..." + exit 1 +fi + +# fetch parameters with names +while getopts "i:t:rn" opt; do + case $opt in + i) p_ain_interpreter_dir="$OPTARG" + ;; + r) p_ain_force_reinstall="$OPTARG" + ;; + t) p_ain_remove_target="$OPTARG" + ;; + n) + ;; + \?) echo "Invalid option -$OPTARG" >&2 + exit 1 + ;; + esac +done + +# If ain_interpreter_dir in parameters is empty: +if [ -z "$p_ain_interpreter_dir" ]; then + # If ain_interpreter_dir in ../conf/ainode-env.sh is empty, set default value to ../venv/bin/python3 + if [ -z "$ain_interpreter_dir" ]; then + ain_interpreter_dir="$SCRIPT_DIR/../venv/bin/python3" + fi +else + # If ain_interpreter_dir in parameters is not empty, set ain_interpreter_dir to the value in parameters + ain_interpreter_dir="$p_ain_interpreter_dir" +fi + +# If ain_system_dir is empty, set default value to ../data/ainode/system +if [ -z "$ain_system_dir" ] +then + ain_system_dir="$SCRIPT_DIR/../data/ainode/system" +fi + +echo "Script got parameters: ain_interpreter_dir: $ain_interpreter_dir, ain_system_dir: $ain_system_dir" + +# check if ain_interpreter_dir is an absolute path +if [[ "$ain_interpreter_dir" != /* ]]; then + ain_interpreter_dir="$SCRIPT_DIR/$ain_interpreter_dir" +fi + +# Change the working directory to the parent directory +cd "$SCRIPT_DIR/.." +ain_ainode_dir=$(dirname "$ain_interpreter_dir")/ainode + + +if [ -z "$p_ain_remove_target" ]; then + echo No target AINode set, use system.properties + $ain_ainode_dir remove +else + $ain_ainode_dir remove $p_ain_remove_target +fi + +if [ $? -eq 1 ]; then + echo "Remove AINode failed. Exiting..." + exit 1 +fi + +bash $SCRIPT_DIR/stop-ainode.sh $* + +# Remove system directory +rm -rf $ain_system_dir \ No newline at end of file diff --git a/iotdb-core/ainode/resources/sbin/start-ainode.bat b/iotdb-core/ainode/resources/sbin/start-ainode.bat new file mode 100644 index 000000000000..e29109bbc4e4 --- /dev/null +++ b/iotdb-core/ainode/resources/sbin/start-ainode.bat @@ -0,0 +1,77 @@ +@REM +@REM Licensed to the Apache Software Foundation (ASF) under one +@REM or more contributor license agreements. See the NOTICE file +@REM distributed with this work for additional information +@REM regarding copyright ownership. The ASF licenses this file +@REM to you under the Apache License, Version 2.0 (the +@REM "License"); you may not use this file except in compliance +@REM with the License. 
You may obtain a copy of the License at +@REM +@REM http://www.apache.org/licenses/LICENSE-2.0 +@REM +@REM Unless required by applicable law or agreed to in writing, +@REM software distributed under the License is distributed on an +@REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +@REM KIND, either express or implied. See the License for the +@REM specific language governing permissions and limitations +@REM under the License. +@REM + +@echo off + +echo ``````````````````````````` +echo Starting IoTDB AINode +echo ``````````````````````````` + +set START_SCRIPT_DIR=%~dp0 +call %START_SCRIPT_DIR%\\..\\conf\\ainode-env.bat %* +if %errorlevel% neq 0 ( + echo Environment check failed. Exiting... + exit /b 1 +) + +for /f "tokens=2 delims==" %%a in ('findstr /i /c:"^ain_interpreter_dir" "%START_SCRIPT_DIR%\\..\\conf\\ainode-env.bat"') do ( + set _ain_interpreter_dir=%%a + goto :done +) + +:initial +if "%1"=="" goto done +set aux=%1 +if "%aux:~0,1%"=="-" ( + set nome=%aux:~1,250% +) else ( + set "%nome%=%1" + set nome= +) +shift +goto initial + +:done +if "%i%"=="" ( + if "%_ain_interpreter_dir%"=="" ( + set _ain_interpreter_dir=%START_SCRIPT_DIR%\\..\\venv\\Scripts\\python.exe + ) +) else ( + set _ain_interpreter_dir=%i% +) + +echo Script got parameter: ain_interpreter_dir: %_ain_interpreter_dir% + +cd %START_SCRIPT_DIR%\\.. + +for %%i in ("%_ain_interpreter_dir%") do set "parent=%%~dpi" + +set ain_ainode_dir=%parent%\ainode.exe + +set ain_ainode_dir_new=%parent%\Scripts\\ainode.exe + +echo Starting AINode... + +%ain_ainode_dir% start +if %errorlevel% neq 0 ( + echo ain_ainode_dir_new is %ain_ainode_dir_new% + %ain_ainode_dir_new% start +) + +pause \ No newline at end of file diff --git a/iotdb-core/ainode/resources/sbin/start-ainode.sh b/iotdb-core/ainode/resources/sbin/start-ainode.sh new file mode 100644 index 000000000000..dd1afbd8bda4 --- /dev/null +++ b/iotdb-core/ainode/resources/sbin/start-ainode.sh @@ -0,0 +1,78 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +echo --------------------------- +echo Starting IoTDB AINode +echo --------------------------- + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" +echo "SCRIPT_DIR: $SCRIPT_DIR" +chmod u+x $(dirname "$0")/../conf/ainode-env.sh +ain_interpreter_dir=$(sed -n 's/^ain_interpreter_dir=\(.*\)$/\1/p' $(dirname "$0")/../conf/ainode-env.sh) +bash $(dirname "$0")/../conf/ainode-env.sh $* +if [ $? -eq 1 ]; then + echo "Environment check failed. Exiting..." + exit 1 +fi + + +# fetch parameters with names +while getopts "i:rn" opt; do + case $opt in + i) p_ain_interpreter_dir="$OPTARG" + ;; + r) p_ain_force_reinstall="$OPTARG" + ;; + n) + ;; + \?) 
echo "Invalid option -$OPTARG" >&2 + exit 1 + ;; + esac +done + +# If ain_interpreter_dir in parameters is empty: +if [ -z "$p_ain_interpreter_dir" ]; then + # If ain_interpreter_dir in ../conf/ainode-env.sh is empty, set default value to ../venv/bin/python3 + if [ -z "$ain_interpreter_dir" ]; then + ain_interpreter_dir="$SCRIPT_DIR/../venv/bin/python3" + fi +else + # If ain_interpreter_dir in parameters is not empty, set ain_interpreter_dir to the value in parameters + ain_interpreter_dir="$p_ain_interpreter_dir" +fi + +# check if ain_interpreter_dir is an absolute path +if [[ "$ain_interpreter_dir" != /* ]]; then + ain_interpreter_dir="$SCRIPT_DIR/$ain_interpreter_dir" +fi + +echo Script got parameter: ain_interpreter_dir: $ain_interpreter_dir + +# Change the working directory to the parent directory +cd "$SCRIPT_DIR/.." + +ain_ainode_dir=$(dirname "$ain_interpreter_dir")/ainode + +echo Script got ainode dir: ain_ainode_dir: $ain_ainode_dir + +echo Starting AINode... + +$ain_ainode_dir start diff --git a/iotdb-core/ainode/resources/sbin/stop-ainode.bat b/iotdb-core/ainode/resources/sbin/stop-ainode.bat new file mode 100644 index 000000000000..a4f302b3f936 --- /dev/null +++ b/iotdb-core/ainode/resources/sbin/stop-ainode.bat @@ -0,0 +1,61 @@ +@REM +@REM Licensed to the Apache Software Foundation (ASF) under one +@REM or more contributor license agreements. See the NOTICE file +@REM distributed with this work for additional information +@REM regarding copyright ownership. The ASF licenses this file +@REM to you under the Apache License, Version 2.0 (the +@REM "License"); you may not use this file except in compliance +@REM with the License. You may obtain a copy of the License at +@REM +@REM http://www.apache.org/licenses/LICENSE-2.0 +@REM +@REM Unless required by applicable law or agreed to in writing, +@REM software distributed under the License is distributed on an +@REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +@REM KIND, either express or implied. See the License for the +@REM specific language governing permissions and limitations +@REM under the License. +@REM + +@echo off + +set current_dir=%~dp0 +set superior_dir=%current_dir%\..\ + +:initial +if "%1"=="" goto done +set aux=%1 +if "%aux:~0,1%"=="-" ( + set nome=%aux:~1,250% +) else ( + set "%nome%=%1" + set nome= +) +shift +goto initial + +:done +for /f "eol=# tokens=2 delims==" %%i in ('findstr /i "^ain_inference_rpc_port" +%superior_dir%\conf\iotdb-ainode.properties') do ( + set ain_inference_rpc_port=%%i +) + +echo Check whether the rpc_port is used..., port is %ain_inference_rpc_port% + +for /f "eol=# tokens=2 delims==" %%i in ('findstr /i "ain_inference_rpc_address" +%superior_dir%\conf\iotdb-ainode.properties') do ( + set ain_inference_rpc_address=%%i +) + +if defined t ( + for /f "tokens=2 delims=/" %%a in ("%t%") do set "ain_inference_rpc=%%a" +) else ( + set ain_inference_rpc=%ain_inference_rpc_address%:%ain_inference_rpc_port% +) + +echo Target AINode to be stopped: %ain_inference_rpc% + +for /f "tokens=5" %%a in ('netstat /ano ^| findstr /r /c:"^ *TCP *%ain_inference_rpc%.*$"') do ( + taskkill /f /pid %%a + echo Close AINode, PID: %%a +) diff --git a/iotdb-core/ainode/resources/sbin/stop-ainode.sh b/iotdb-core/ainode/resources/sbin/stop-ainode.sh new file mode 100644 index 000000000000..4580ce6fd3e2 --- /dev/null +++ b/iotdb-core/ainode/resources/sbin/stop-ainode.sh @@ -0,0 +1,73 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +AINODE_CONF="`dirname "$0"`/../conf" +ain_inference_rpc_port=`sed '/^ain_inference_rpc_port=/!d;s/.*=//' ${AINODE_CONF}/iotdb-ainode.properties` + +# fetch parameters with names +while getopts "i:t:r" opt; do + case $opt in + i) + ;; + r) + ;; + t) p_ain_remove_target="$OPTARG" + ;; + \?) echo "Invalid option -$OPTARG" >&2 + exit 1 + ;; + esac +done + +# If p_ain_remove_target exists, take the value after the colon of p_ain_remove_target as ain_inference_rpc_port +if [ -n "$p_ain_remove_target" ]; then + ain_inference_rpc_port=${p_ain_remove_target#*:} +fi + +echo "Check whether the rpc_port is used..., port is" $ain_inference_rpc_port + +if type lsof > /dev/null 2>&1 ; then + echo $(lsof -t -i:"${ain_inference_rpc_port}" -sTCP:LISTEN) + PID=$(lsof -t -i:"${ain_inference_rpc_port}" -sTCP:LISTEN) +elif type netstat > /dev/null 2>&1 ; then + PID=$(netstat -anp 2>/dev/null | grep ":${ain_inference_rpc_port} " | grep ' LISTEN ' | awk '{print $NF}' | sed "s|/.*||g" ) +else + echo "" + echo " Error: No necessary tool." + echo " Please install 'lsof' or 'netstat'." + exit 1 +fi + +PID_VERIFY=$(ps ax | grep -i 'ainode' | grep -v grep | awk '{print $1}') + +if [ -z "$PID" ]; then + echo "No AINode to stop" + if [ "$(id -u)" -ne 0 ]; then + echo "Maybe you can try to run in sudo mode to detect the process." 
+ fi + exit 1 +elif [[ "${PID_VERIFY}" =~ ${PID} ]]; then + kill -s TERM "$PID" + echo "Stop AINode, PID:" "$PID" +else + echo "No AINode to stop" + exit 1 +fi + diff --git a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IdentifierParser.g4 b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IdentifierParser.g4 index c9cfc400dcee..fac1ad4b46e5 100644 --- a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IdentifierParser.g4 +++ b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IdentifierParser.g4 @@ -53,6 +53,7 @@ keyWords | BOUNDARY | BY | CACHE + | CALL | CASE | CAST | CHILD @@ -112,9 +113,11 @@ keyWords | GRANT | GROUP | HAVING + | HEAD | HYPERPARAMETERS | IN | INDEX + | INFERENCE | INFO | INSERT | INTO @@ -135,6 +138,9 @@ keyWords | MERGE | METADATA | MIGRATE + | AINODES + | MODEL + | MODELS | MODIFY | NAN | NODEID @@ -217,6 +223,7 @@ keyWords | SUBSTRING | SYSTEM | TAGS + | TAIL | TASK | TEMPLATE | TEMPLATES diff --git a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4 b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4 index 8986fce18e14..1d45ccb3d365 100644 --- a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4 +++ b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4 @@ -65,6 +65,8 @@ ddlStatement // Cluster | showVariables | showCluster | showRegions | showDataNodes | showConfigNodes | showClusterId | getRegionId | getTimeSlotList | countTimeSlotList | getSeriesSlotList | migrateRegion | verifyConnection + // AINode + | showAINodes | createModel | dropModel | showModels | callInference // Quota | setSpaceQuota | showSpaceQuota | setThrottleQuota | showThrottleQuota // View @@ -489,6 +491,11 @@ showConfigNodes : SHOW CONFIGNODES ; +// ---- Show AI Nodes +showAINodes + : SHOW AINODES + ; + // ---- Show Cluster Id showClusterId : SHOW CLUSTERID @@ -657,6 +664,42 @@ showSubscriptions : SHOW SUBSCRIPTIONS (ON topicName=identifier)? 
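+// Illustrative statements matched by the AI-model rules added below
+// (model name, input SQL and hyper-parameter key are examples only):
+//   CREATE MODEL dlinear USING URI 'https://example.com/dlinear/'
+//   CALL INFERENCE(dlinear, 'select s0 from root.sg.d1', window=head(96))
+//   SHOW MODELS dlinear
+//   DROP MODEL dlinear
+//   SHOW AINODES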
; +// AI Model ========================================================================================= +// ---- Create Model +createModel + : CREATE MODEL modelName=identifier USING URI modelUri=STRING_LITERAL + ; + +windowFunction + : TAIL LR_BRACKET windowSize=INTEGER_LITERAL RR_BRACKET + | HEAD LR_BRACKET windowSize=INTEGER_LITERAL RR_BRACKET + | COUNT LR_BRACKET interval=INTEGER_LITERAL COMMA step=INTEGER_LITERAL RR_BRACKET + ; + +callInference + : CALL INFERENCE LR_BRACKET modelId=identifier COMMA inputSql=STRING_LITERAL (COMMA hparamPair)* RR_BRACKET + ; + +hparamPair + : hparamKey=attributeKey operator_eq hparamValue + ; + +hparamValue + : attributeValue + | windowFunction + ; + +// ---- Drop Model +dropModel + : DROP MODEL modelId=identifier + ; + +// ---- Show Models +showModels + : SHOW MODELS + | SHOW MODELS modelId=identifier + ; + // Create Logical View createLogicalView : CREATE VIEW viewTargetPaths AS viewSourcePaths diff --git a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4 b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4 index dc9c12c7c8c5..58a35d494f22 100644 --- a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4 +++ b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4 @@ -129,6 +129,10 @@ CACHE : C A C H E ; +CALL + : C A L L + ; + CAST : C A S T ; @@ -358,6 +362,10 @@ HAVING : H A V I N G ; +HEAD + : H E A D + ; + HYPERPARAMETERS : H Y P E R P A R A M E T E R S ; @@ -370,6 +378,10 @@ INDEX : I N D E X ; +INFERENCE + : I N F E R E N C E + ; + INFO : I N F O ; @@ -450,6 +462,18 @@ MIGRATE : M I G R A T E ; +AINODES + : A I N O D E S + ; + +MODEL + : M O D E L + ; + +MODELS + : M O D E L S + ; + MODIFY : M O D I F Y ; @@ -766,6 +790,10 @@ TAGS : T A G S ; +TAIL + : T A I L + ; + TASK : T A S K ; diff --git a/iotdb-core/confignode/pom.xml b/iotdb-core/confignode/pom.xml index 168815c89dff..a08b7944e8e9 100644 --- a/iotdb-core/confignode/pom.xml +++ b/iotdb-core/confignode/pom.xml @@ -84,6 +84,11 @@ iotdb-thrift-confignode 1.3.3-SNAPSHOT + + org.apache.iotdb + iotdb-thrift-ainode + 1.3.3-SNAPSHOT + org.apache.iotdb node-commons diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java new file mode 100644 index 000000000000..e09ccc79becb --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.iotdb.confignode.client.async;
+
+import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatReq;
+import org.apache.iotdb.common.rpc.thrift.TEndPoint;
+import org.apache.iotdb.commons.client.ClientPoolFactory;
+import org.apache.iotdb.commons.client.IClientManager;
+import org.apache.iotdb.commons.client.ainode.AsyncAINodeServiceClient;
+import org.apache.iotdb.confignode.client.async.handlers.heartbeat.AINodeHeartbeatHandler;
+
+public class AsyncAINodeHeartbeatClientPool {
+
+  private final IClientManager<TEndPoint, AsyncAINodeServiceClient> clientManager;
+
+  private AsyncAINodeHeartbeatClientPool() {
+    clientManager =
+        new IClientManager.Factory<TEndPoint, AsyncAINodeServiceClient>()
+            .createClientManager(
+                new ClientPoolFactory.AsyncAINodeHeartbeatServiceClientPoolFactory());
+  }
+
+  public void getAINodeHeartBeat(
+      TEndPoint endPoint, TAIHeartbeatReq req, AINodeHeartbeatHandler handler) {
+    try {
+      clientManager.borrowClient(endPoint).getAIHeartbeat(req, handler);
+    } catch (Exception ignore) {
+      // Just ignore
+    }
+  }
+
+  private static class AsyncAINodeHeartbeatClientPoolHolder {
+
+    private static final AsyncAINodeHeartbeatClientPool INSTANCE =
+        new AsyncAINodeHeartbeatClientPool();
+
+    private AsyncAINodeHeartbeatClientPoolHolder() {
+      // Empty constructor
+    }
+  }
+
+  public static AsyncAINodeHeartbeatClientPool getInstance() {
+    return AsyncAINodeHeartbeatClientPool.AsyncAINodeHeartbeatClientPoolHolder.INSTANCE;
+  }
+}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/handlers/heartbeat/AINodeHeartbeatHandler.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/handlers/heartbeat/AINodeHeartbeatHandler.java
new file mode 100644
index 000000000000..9d8e0b6e8474
--- /dev/null
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/handlers/heartbeat/AINodeHeartbeatHandler.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.confignode.client.async.handlers.heartbeat;
+
+import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatResp;
+import org.apache.iotdb.commons.client.ThriftClient;
+import org.apache.iotdb.commons.cluster.NodeStatus;
+import org.apache.iotdb.commons.cluster.NodeType;
+import org.apache.iotdb.confignode.manager.load.LoadManager;
+import org.apache.iotdb.confignode.manager.load.cache.node.NodeHeartbeatSample;
+
+import org.apache.thrift.async.AsyncMethodCallback;
+
+public class AINodeHeartbeatHandler implements AsyncMethodCallback<TAIHeartbeatResp> {
+
+  private final int nodeId;
+
+  private final LoadManager loadManager;
+
+  public AINodeHeartbeatHandler(int nodeId, LoadManager loadManager) {
+    this.nodeId = nodeId;
+    this.loadManager = loadManager;
+  }
+
+  @Override
+  public void onComplete(TAIHeartbeatResp aiHeartbeatResp) {
+    loadManager
+        .getLoadCache()
+        .cacheAINodeHeartbeatSample(nodeId, new NodeHeartbeatSample(aiHeartbeatResp));
+  }
+
+  @Override
+  public void onError(Exception e) {
+    if (ThriftClient.isConnectionBroken(e)) {
+      loadManager.forceUpdateNodeCache(
+          NodeType.AINode, nodeId, new NodeHeartbeatSample(NodeStatus.Unknown));
+    }
+    loadManager.getLoadCache().resetHeartbeatProcessing(nodeId);
+  }
+}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/conf/ConfigNodeConstant.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/conf/ConfigNodeConstant.java
index 5724eb1862fe..17d23193b9d2 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/conf/ConfigNodeConstant.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/conf/ConfigNodeConstant.java
@@ -34,6 +34,7 @@ public class ConfigNodeConstant {
       "Executed failed, check usage: /:";
 
   public static final String REMOVE_DATANODE_PROCESS = "[REMOVE_DATANODE_PROCESS]";
+  public static final String REMOVE_AINODE_PROCESS = "[REMOVE_AINODE_PROCESS]";
   public static final String REGION_MIGRATE_PROCESS = "[REGION_MIGRATE_PROCESS]";
 
   private ConfigNodeConstant() {
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java
index 84728b3ca9b7..9cf54ed9f9ef 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlanType.java
@@ -36,6 +36,12 @@ public enum ConfigPhysicalPlanType {
   RemoveDataNode((short) 102),
   UpdateDataNodeConfiguration((short) 103),
 
+  /** AINode. */
+  RegisterAINode((short) 104),
+  UpdateAINodeConfiguration((short) 105),
+  RemoveAINode((short) 106),
+  GetAINodeConfiguration((short) 107),
+
   /** Database. */
   CreateDatabase((short) 200),
   SetTTL((short) 201),
@@ -184,7 +190,14 @@ public enum ConfigPhysicalPlanType {
   UPDATE_CQ_LAST_EXEC_TIME((short) 1103),
   SHOW_CQ((short) 1104),
 
-  // 1200-1299 planId is used by IoTDB-ML.
+  /** AI model. */
+  CreateModel((short) 1200),
+  UpdateModelInfo((short) 1201),
+  UpdateModelState((short) 1202),
+  DropModel((short) 1203),
+  ShowModel((short) 1204),
+  GetModelInfo((short) 1206),
+  DropModelInNode((short) 1207),
 
   /** Pipe Plugin.
*/ CreatePipePlugin((short) 1300), diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/ainode/GetAINodeConfigurationPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/ainode/GetAINodeConfigurationPlan.java new file mode 100644 index 000000000000..f13cbd276cfc --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/ainode/GetAINodeConfigurationPlan.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.confignode.consensus.request.read.ainode; + +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; + +public class GetAINodeConfigurationPlan extends ConfigPhysicalPlan { + + // if aiNodeId is set to -1, return all AINode configurations. + private int aiNodeId; + + public GetAINodeConfigurationPlan() { + super(ConfigPhysicalPlanType.GetAINodeConfiguration); + } + + public GetAINodeConfigurationPlan(int aiNodeId) { + this(); + this.aiNodeId = aiNodeId; + } + + public int getAiNodeId() { + return aiNodeId; + } + + @Override + protected void serializeImpl(DataOutputStream stream) throws IOException { + stream.writeShort(getType().getPlanType()); + stream.writeInt(aiNodeId); + } + + @Override + protected void deserializeImpl(ByteBuffer buffer) throws IOException { + this.aiNodeId = buffer.getInt(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof GetAINodeConfigurationPlan)) { + return false; + } + GetAINodeConfigurationPlan that = (GetAINodeConfigurationPlan) o; + return aiNodeId == that.aiNodeId; + } + + @Override + public int hashCode() { + return Integer.hashCode(aiNodeId); + } +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java new file mode 100644 index 000000000000..9be7c59b1e7e --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.confignode.consensus.request.read.model; + +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; + +import org.apache.tsfile.utils.ReadWriteIOUtils; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Objects; + +public class GetModelInfoPlan extends ConfigPhysicalPlan { + + private String modelId; + + public GetModelInfoPlan() { + super(ConfigPhysicalPlanType.GetModelInfo); + } + + public GetModelInfoPlan(TGetModelInfoReq getModelInfoReq) { + super(ConfigPhysicalPlanType.GetModelInfo); + this.modelId = getModelInfoReq.getModelId(); + } + + public String getModelId() { + return modelId; + } + + @Override + protected void serializeImpl(DataOutputStream stream) throws IOException { + stream.writeShort(getType().getPlanType()); + ReadWriteIOUtils.write(modelId, stream); + } + + @Override + protected void deserializeImpl(ByteBuffer buffer) throws IOException { + this.modelId = ReadWriteIOUtils.readString(buffer); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + GetModelInfoPlan that = (GetModelInfoPlan) o; + return Objects.equals(modelId, that.modelId); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), modelId); + } +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java new file mode 100644 index 000000000000..b9dddf0cf491 --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.iotdb.confignode.consensus.request.read.model;
+
+import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan;
+import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType;
+import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq;
+
+import org.apache.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Objects;
+
+public class ShowModelPlan extends ConfigPhysicalPlan {
+
+  private String modelName;
+
+  public ShowModelPlan() {
+    super(ConfigPhysicalPlanType.ShowModel);
+  }
+
+  public ShowModelPlan(TShowModelReq showModelReq) {
+    super(ConfigPhysicalPlanType.ShowModel);
+    if (showModelReq.isSetModelId()) {
+      this.modelName = showModelReq.getModelId();
+    }
+  }
+
+  public boolean isSetModelName() {
+    return modelName != null;
+  }
+
+  public String getModelName() {
+    return modelName;
+  }
+
+  @Override
+  protected void serializeImpl(DataOutputStream stream) throws IOException {
+    stream.writeShort(getType().getPlanType());
+    ReadWriteIOUtils.write(modelName != null, stream);
+    // Only write modelName when it is set, mirroring deserializeImpl
+    if (modelName != null) {
+      ReadWriteIOUtils.write(modelName, stream);
+    }
+  }
+
+  @Override
+  protected void deserializeImpl(ByteBuffer buffer) throws IOException {
+    boolean isSetModelId = ReadWriteIOUtils.readBool(buffer);
+    if (isSetModelId) {
+      this.modelName = ReadWriteIOUtils.readString(buffer);
+    }
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    if (!super.equals(o)) {
+      return false;
+    }
+    ShowModelPlan that = (ShowModelPlan) o;
+    return Objects.equals(modelName, that.modelName);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(super.hashCode(), modelName);
+  }
+}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/ainode/RegisterAINodePlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/ainode/RegisterAINodePlan.java
new file mode 100644
index 000000000000..5f5cb9ae1ecf
--- /dev/null
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/ainode/RegisterAINodePlan.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.confignode.consensus.request.write.ainode;
+
+import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration;
+import org.apache.iotdb.commons.utils.ThriftCommonsSerDeUtils;
+import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan;
+import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Objects;
+
+public class RegisterAINodePlan extends ConfigPhysicalPlan {
+
+  private TAINodeConfiguration aiNodeConfiguration;
+
+  public RegisterAINodePlan() {
+    super(ConfigPhysicalPlanType.RegisterAINode);
+  }
+
+  public RegisterAINodePlan(TAINodeConfiguration aiNodeConfiguration) {
+    this();
+    this.aiNodeConfiguration = aiNodeConfiguration;
+  }
+
+  public TAINodeConfiguration getAINodeConfiguration() {
+    return aiNodeConfiguration;
+  }
+
+  @Override
+  protected void serializeImpl(DataOutputStream stream) throws IOException {
+    stream.writeShort(getType().getPlanType());
+    ThriftCommonsSerDeUtils.serializeTAINodeConfiguration(aiNodeConfiguration, stream);
+  }
+
+  @Override
+  protected void deserializeImpl(ByteBuffer buffer) throws IOException {
+    aiNodeConfiguration = ThriftCommonsSerDeUtils.deserializeTAINodeConfiguration(buffer);
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    RegisterAINodePlan that = (RegisterAINodePlan) o;
+    return aiNodeConfiguration.equals(that.aiNodeConfiguration);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(aiNodeConfiguration);
+  }
+}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/ainode/RemoveAINodePlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/ainode/RemoveAINodePlan.java
new file mode 100644
index 000000000000..92bfb8b7017f
--- /dev/null
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/ainode/RemoveAINodePlan.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +package org.apache.iotdb.confignode.consensus.request.write.ainode; + +import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; +import org.apache.iotdb.commons.utils.ThriftCommonsSerDeUtils; +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Objects; + +public class RemoveAINodePlan extends ConfigPhysicalPlan { + + private TAINodeLocation aiNodeLocation; + + public RemoveAINodePlan() { + super(ConfigPhysicalPlanType.RemoveAINode); + } + + public RemoveAINodePlan(TAINodeLocation taiNodeLocation) { + this(); + this.aiNodeLocation = taiNodeLocation; + } + + @Override + protected void serializeImpl(DataOutputStream stream) throws IOException { + stream.writeShort(getType().getPlanType()); + ThriftCommonsSerDeUtils.serializeTAINodeLocation(aiNodeLocation, stream); + } + + @Override + protected void deserializeImpl(ByteBuffer buffer) throws IOException { + this.aiNodeLocation = ThriftCommonsSerDeUtils.deserializeTAINodeLocation(buffer); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + RemoveAINodePlan that = (RemoveAINodePlan) o; + return aiNodeLocation.equals(that.aiNodeLocation); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), aiNodeLocation); + } + + public TAINodeLocation getAINodeLocation() { + return aiNodeLocation; + } +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/ainode/UpdateAINodePlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/ainode/UpdateAINodePlan.java new file mode 100644 index 000000000000..5ef885551d7a --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/ainode/UpdateAINodePlan.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.confignode.consensus.request.write.ainode; + +import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; +import org.apache.iotdb.commons.utils.ThriftCommonsSerDeUtils; +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Objects; + +public class UpdateAINodePlan extends ConfigPhysicalPlan { + + private TAINodeConfiguration aiNodeConfiguration; + + public UpdateAINodePlan() { + super(ConfigPhysicalPlanType.UpdateAINodeConfiguration); + } + + public UpdateAINodePlan(TAINodeConfiguration aiNodeConfiguration) { + this(); + this.aiNodeConfiguration = aiNodeConfiguration; + } + + public TAINodeConfiguration getAINodeConfiguration() { + return aiNodeConfiguration; + } + + @Override + protected void serializeImpl(DataOutputStream stream) throws IOException { + stream.writeShort(getType().getPlanType()); + ThriftCommonsSerDeUtils.serializeTAINodeConfiguration(aiNodeConfiguration, stream); + } + + @Override + protected void deserializeImpl(ByteBuffer buffer) throws IOException { + aiNodeConfiguration = ThriftCommonsSerDeUtils.deserializeTAINodeConfiguration(buffer); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!getType().equals(((UpdateAINodePlan) o).getType())) { + return false; + } + UpdateAINodePlan that = (UpdateAINodePlan) o; + return aiNodeConfiguration.equals(that.aiNodeConfiguration); + } + + @Override + public int hashCode() { + return Objects.hash(getType(), aiNodeConfiguration); + } +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java new file mode 100644 index 000000000000..61e37cdd2187 --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.confignode.consensus.request.write.model; + +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; + +import org.apache.tsfile.utils.ReadWriteIOUtils; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Objects; + +public class CreateModelPlan extends ConfigPhysicalPlan { + + private String modelName; + + public CreateModelPlan() { + super(ConfigPhysicalPlanType.CreateModel); + } + + public CreateModelPlan(String modelName) { + super(ConfigPhysicalPlanType.CreateModel); + this.modelName = modelName; + } + + public String getModelName() { + return modelName; + } + + @Override + protected void serializeImpl(DataOutputStream stream) throws IOException { + stream.writeShort(getType().getPlanType()); + ReadWriteIOUtils.write(modelName, stream); + } + + @Override + protected void deserializeImpl(ByteBuffer buffer) throws IOException { + modelName = ReadWriteIOUtils.readString(buffer); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + CreateModelPlan that = (CreateModelPlan) o; + return Objects.equals(modelName, that.modelName); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), modelName); + } +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java new file mode 100644 index 000000000000..885543f84e15 --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.confignode.consensus.request.write.model; + +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Objects; + +public class DropModelInNodePlan extends ConfigPhysicalPlan { + + private int nodeId; + + public DropModelInNodePlan() { + super(ConfigPhysicalPlanType.DropModelInNode); + } + + public DropModelInNodePlan(int nodeId) { + super(ConfigPhysicalPlanType.DropModelInNode); + this.nodeId = nodeId; + } + + public int getNodeId() { + return nodeId; + } + + @Override + protected void serializeImpl(DataOutputStream stream) throws IOException { + stream.writeShort(getType().getPlanType()); + stream.writeInt(nodeId); + } + + @Override + protected void deserializeImpl(ByteBuffer buffer) throws IOException { + nodeId = buffer.getInt(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof DropModelInNodePlan)) return false; + DropModelInNodePlan that = (DropModelInNodePlan) o; + return nodeId == that.nodeId; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), nodeId); + } +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java new file mode 100644 index 000000000000..813b116c645c --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.confignode.consensus.request.write.model; + +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; +import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; + +import org.apache.tsfile.utils.ReadWriteIOUtils; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Objects; + +public class DropModelPlan extends ConfigPhysicalPlan { + + private String modelName; + + public DropModelPlan() { + super(ConfigPhysicalPlanType.DropModel); + } + + public DropModelPlan(String modelName) { + super(ConfigPhysicalPlanType.DropModel); + this.modelName = modelName; + } + + public String getModelName() { + return modelName; + } + + @Override + protected void serializeImpl(DataOutputStream stream) throws IOException { + stream.writeShort(getType().getPlanType()); + ReadWriteIOUtils.write(modelName, stream); + } + + @Override + protected void deserializeImpl(ByteBuffer buffer) throws IOException { + modelName = ReadWriteIOUtils.readString(buffer); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + DropModelPlan that = (DropModelPlan) o; + return modelName.equals(that.modelName); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), modelName); + } +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java new file mode 100644 index 000000000000..ca74d2daf69d --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.iotdb.confignode.consensus.request.write.model;
+
+import org.apache.iotdb.commons.model.ModelInformation;
+import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan;
+import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType;
+
+import org.apache.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+public class UpdateModelInfoPlan extends ConfigPhysicalPlan {
+
+  private String modelName;
+  private ModelInformation modelInformation;
+
+  // The nodes that hold this model; only updated during model registration
+  private List<Integer> nodeIds;
+
+  public UpdateModelInfoPlan() {
+    super(ConfigPhysicalPlanType.UpdateModelInfo);
+  }
+
+  public UpdateModelInfoPlan(String modelName, ModelInformation modelInformation) {
+    super(ConfigPhysicalPlanType.UpdateModelInfo);
+    this.modelName = modelName;
+    this.modelInformation = modelInformation;
+    this.nodeIds = Collections.emptyList();
+  }
+
+  public UpdateModelInfoPlan(
+      String modelName, ModelInformation modelInformation, List<Integer> nodeIds) {
+    super(ConfigPhysicalPlanType.UpdateModelInfo);
+    this.modelName = modelName;
+    this.modelInformation = modelInformation;
+    this.nodeIds = nodeIds;
+  }
+
+  public String getModelName() {
+    return modelName;
+  }
+
+  public ModelInformation getModelInformation() {
+    return modelInformation;
+  }
+
+  public List<Integer> getNodeIds() {
+    return nodeIds;
+  }
+
+  @Override
+  protected void serializeImpl(DataOutputStream stream) throws IOException {
+    stream.writeShort(getType().getPlanType());
+    ReadWriteIOUtils.write(modelName, stream);
+    this.modelInformation.serialize(stream);
+    ReadWriteIOUtils.write(nodeIds.size(), stream);
+    for (Integer nodeId : nodeIds) {
+      ReadWriteIOUtils.write(nodeId, stream);
+    }
+  }
+
+  @Override
+  protected void deserializeImpl(ByteBuffer buffer) throws IOException {
+    this.modelName = ReadWriteIOUtils.readString(buffer);
+    this.modelInformation = ModelInformation.deserialize(buffer);
+    int size = ReadWriteIOUtils.readInt(buffer);
+    this.nodeIds = new ArrayList<>();
+    for (int i = 0; i < size; i++) {
+      this.nodeIds.add(ReadWriteIOUtils.readInt(buffer));
+    }
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    if (!super.equals(o)) {
+      return false;
+    }
+    UpdateModelInfoPlan that = (UpdateModelInfoPlan) o;
+    return modelName.equals(that.modelName)
+        && modelInformation.equals(that.modelInformation)
+        && nodeIds.equals(that.nodeIds);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(super.hashCode(), modelName, modelInformation, nodeIds);
+  }
+}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ainode/AINodeConfigurationResp.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ainode/AINodeConfigurationResp.java
new file mode 100644
index 000000000000..018ed605bb31
--- /dev/null
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ainode/AINodeConfigurationResp.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.confignode.consensus.response.ainode;
+
+import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration;
+import org.apache.iotdb.common.rpc.thrift.TSStatus;
+import org.apache.iotdb.confignode.rpc.thrift.TAINodeConfigurationResp;
+import org.apache.iotdb.consensus.common.DataSet;
+import org.apache.iotdb.rpc.TSStatusCode;
+
+import java.util.Map;
+
+public class AINodeConfigurationResp implements DataSet {
+
+  private TSStatus status;
+  private Map<Integer, TAINodeConfiguration> aiNodeConfigurationMap;
+
+  public AINodeConfigurationResp() {
+    // empty constructor
+  }
+
+  public void setStatus(TSStatus status) {
+    this.status = status;
+  }
+
+  public TSStatus getStatus() {
+    return status;
+  }
+
+  public void setAiNodeConfigurationMap(
+      Map<Integer, TAINodeConfiguration> aiNodeConfigurationMap) {
+    this.aiNodeConfigurationMap = aiNodeConfigurationMap;
+  }
+
+  public void convertToRpcAINodeLocationResp(TAINodeConfigurationResp resp) {
+    resp.setStatus(status);
+    if (status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+      resp.setAiNodeConfigurationMap(aiNodeConfigurationMap);
+    }
+  }
+}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ainode/AINodeRegisterResp.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ainode/AINodeRegisterResp.java
new file mode 100644
index 000000000000..c5b9e4b02260
--- /dev/null
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ainode/AINodeRegisterResp.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.confignode.consensus.response.ainode;
+
+import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation;
+import org.apache.iotdb.common.rpc.thrift.TSStatus;
+import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterResp;
+import org.apache.iotdb.consensus.common.DataSet;
+import org.apache.iotdb.rpc.TSStatusCode;
+
+import java.util.List;
+
+public class AINodeRegisterResp implements DataSet {
+
+  private TSStatus status;
+  private List<TConfigNodeLocation> configNodeList;
+  private Integer aiNodeId;
+
+  public AINodeRegisterResp() {
+    this.aiNodeId = null;
+  }
+
+  public TSStatus getStatus() {
+    return status;
+  }
+
+  public void setStatus(TSStatus status) {
+    this.status = status;
+  }
+
+  public void setConfigNodeList(List<TConfigNodeLocation> configNodeList) {
+    this.configNodeList = configNodeList;
+  }
+
+  public void setAINodeId(Integer aiNodeId) {
+    this.aiNodeId = aiNodeId;
+  }
+
+  public TAINodeRegisterResp convertToAINodeRegisterResp() {
+    TAINodeRegisterResp resp = new TAINodeRegisterResp();
+    resp.setStatus(status);
+    resp.setConfigNodeList(configNodeList);
+
+    if (status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+      resp.setAiNodeId(aiNodeId);
+    }
+
+    return resp;
+  }
+}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java
new file mode 100644
index 000000000000..14101b95d123
--- /dev/null
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +package org.apache.iotdb.confignode.consensus.response.model; + +import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; +import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; +import org.apache.iotdb.consensus.common.DataSet; + +import java.nio.ByteBuffer; + +public class GetModelInfoResp implements DataSet { + + private final TSStatus status; + private ByteBuffer serializedModelInformation; + + private int targetAINodeId; + private TEndPoint targetAINodeAddress; + + public TSStatus getStatus() { + return status; + } + + public GetModelInfoResp(TSStatus status) { + this.status = status; + } + + public void setModelInfo(ByteBuffer serializedModelInformation) { + this.serializedModelInformation = serializedModelInformation; + } + + public int getTargetAINodeId() { + return targetAINodeId; + } + + public void setTargetAINodeId(int targetAINodeId) { + this.targetAINodeId = targetAINodeId; + } + + public void setTargetAINodeAddress(TAINodeConfiguration aiNodeConfiguration) { + if (aiNodeConfiguration.getLocation() == null) { + return; + } + this.targetAINodeAddress = aiNodeConfiguration.getLocation().getInternalEndPoint(); + } + + public TGetModelInfoResp convertToThriftResponse() { + TGetModelInfoResp resp = new TGetModelInfoResp(status); + resp.setModelInfo(serializedModelInformation); + resp.setAiNodeAddress(targetAINodeAddress); + return resp; + } +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java new file mode 100644 index 000000000000..9a23d9ed7130 --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.iotdb.confignode.consensus.response.model;
+
+import org.apache.iotdb.common.rpc.thrift.TSStatus;
+import org.apache.iotdb.commons.model.ModelInformation;
+import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp;
+import org.apache.iotdb.consensus.common.DataSet;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+public class ModelTableResp implements DataSet {
+
+  private final TSStatus status;
+  private final List<ByteBuffer> serializedAllModelInformation;
+
+  public ModelTableResp(TSStatus status) {
+    this.status = status;
+    this.serializedAllModelInformation = new ArrayList<>();
+  }
+
+  public void addModelInformation(List<ModelInformation> modelInformationList)
+      throws IOException {
+    for (ModelInformation modelInformation : modelInformationList) {
+      this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult());
+    }
+  }
+
+  public void addModelInformation(ModelInformation modelInformation) throws IOException {
+    this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult());
+  }
+
+  public TShowModelResp convertToThriftResponse() throws IOException {
+    return new TShowModelResp(status, serializedAllModelInformation);
+  }
+}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
index 59dacc5ad5ca..2e40eb4c572e 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
@@ -19,6 +19,8 @@
 
 package org.apache.iotdb.confignode.manager;
 
+import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration;
+import org.apache.iotdb.common.rpc.thrift.TAINodeLocation;
 import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation;
 import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId;
 import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration;
@@ -57,6 +59,7 @@
 import org.apache.iotdb.confignode.conf.SystemPropertiesUtils;
 import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType;
 import org.apache.iotdb.confignode.consensus.request.auth.AuthorPlan;
+import org.apache.iotdb.confignode.consensus.request.read.ainode.GetAINodeConfigurationPlan;
 import org.apache.iotdb.confignode.consensus.request.read.database.CountDatabasePlan;
 import org.apache.iotdb.confignode.consensus.request.read.database.GetDatabasePlan;
 import org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan;
@@ -67,6 +70,7 @@
 import org.apache.iotdb.confignode.consensus.request.read.partition.GetSchemaPartitionPlan;
 import org.apache.iotdb.confignode.consensus.request.read.region.GetRegionInfoListPlan;
 import org.apache.iotdb.confignode.consensus.request.read.ttl.ShowTTLPlan;
+import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan;
 import org.apache.iotdb.confignode.consensus.request.write.confignode.RemoveConfigNodePlan;
 import org.apache.iotdb.confignode.consensus.request.write.database.DatabaseSchemaPlan;
 import org.apache.iotdb.confignode.consensus.request.write.database.SetDataReplicationFactorPlan;
@@ -75,6 +79,7 @@
 import org.apache.iotdb.confignode.consensus.request.write.database.SetTimePartitionIntervalPlan;
 import org.apache.iotdb.confignode.consensus.request.write.datanode.RemoveDataNodePlan;
 import
org.apache.iotdb.confignode.consensus.request.write.template.CreateSchemaTemplatePlan; +import org.apache.iotdb.confignode.consensus.response.ainode.AINodeRegisterResp; import org.apache.iotdb.confignode.consensus.response.auth.PermissionInfoResp; import org.apache.iotdb.confignode.consensus.response.database.CountDatabaseResp; import org.apache.iotdb.confignode.consensus.response.database.DatabaseSchemaResp; @@ -105,6 +110,7 @@ import org.apache.iotdb.confignode.manager.subscription.SubscriptionManager; import org.apache.iotdb.confignode.persistence.AuthorInfo; import org.apache.iotdb.confignode.persistence.ClusterInfo; +import org.apache.iotdb.confignode.persistence.ModelInfo; import org.apache.iotdb.confignode.persistence.ProcedureInfo; import org.apache.iotdb.confignode.persistence.TTLInfo; import org.apache.iotdb.confignode.persistence.TriggerInfo; @@ -118,6 +124,9 @@ import org.apache.iotdb.confignode.persistence.schema.ClusterSchemaInfo; import org.apache.iotdb.confignode.persistence.subscription.SubscriptionInfo; import org.apache.iotdb.confignode.procedure.impl.schema.SchemaUtils; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp; import org.apache.iotdb.confignode.rpc.thrift.TAlterLogicalViewReq; import org.apache.iotdb.confignode.rpc.thrift.TAlterPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TAlterSchemaTemplateReq; @@ -131,6 +140,7 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; +import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -146,6 +156,7 @@ import org.apache.iotdb.confignode.rpc.thrift.TDeleteLogicalViewReq; import org.apache.iotdb.confignode.rpc.thrift.TDeleteTimeSeriesReq; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; +import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropTopicReq; @@ -159,6 +170,8 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -181,11 +194,14 @@ import org.apache.iotdb.confignode.rpc.thrift.TSchemaPartitionTableResp; import org.apache.iotdb.confignode.rpc.thrift.TSetDataNodeStatusReq; import org.apache.iotdb.confignode.rpc.thrift.TSetSchemaTemplateReq; +import org.apache.iotdb.confignode.rpc.thrift.TShowAINodesResp; import org.apache.iotdb.confignode.rpc.thrift.TShowCQResp; import org.apache.iotdb.confignode.rpc.thrift.TShowClusterResp; import org.apache.iotdb.confignode.rpc.thrift.TShowConfigNodesResp; import 
org.apache.iotdb.confignode.rpc.thrift.TShowDataNodesResp; import org.apache.iotdb.confignode.rpc.thrift.TShowDatabaseResp; +import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq; +import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp; import org.apache.iotdb.confignode.rpc.thrift.TShowPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TShowPipeResp; import org.apache.iotdb.confignode.rpc.thrift.TShowSubscriptionReq; @@ -277,6 +293,9 @@ public class ConfigManager implements IManager { /** CQ. */ private final CQManager cqManager; + /** AI Model. */ + private final ModelManager modelManager; + /** Pipe */ private final PipeManager pipeManager; @@ -306,6 +325,7 @@ public ConfigManager() throws IOException { UDFInfo udfInfo = new UDFInfo(); TriggerInfo triggerInfo = new TriggerInfo(); CQInfo cqInfo = new CQInfo(); + ModelInfo modelInfo = new ModelInfo(); PipeInfo pipeInfo = new PipeInfo(); QuotaInfo quotaInfo = new QuotaInfo(); TTLInfo ttlInfo = new TTLInfo(); @@ -323,6 +343,7 @@ public ConfigManager() throws IOException { udfInfo, triggerInfo, cqInfo, + modelInfo, pipeInfo, subscriptionInfo, quotaInfo, @@ -344,6 +365,7 @@ public ConfigManager() throws IOException { this.udfManager = new UDFManager(this, udfInfo); this.triggerManager = new TriggerManager(this, triggerInfo); this.cqManager = new CQManager(this); + this.modelManager = new ModelManager(this, modelInfo); this.pipeManager = new PipeManager(this, pipeInfo); this.subscriptionManager = new SubscriptionManager(this, subscriptionInfo); @@ -401,7 +423,7 @@ public DataSet getSystemConfiguration() { } @Override - public DataSet registerDataNode(TDataNodeRegisterReq req) { + public synchronized DataSet registerDataNode(TDataNodeRegisterReq req) { TSStatus status = confirmLeader(); if (status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { status = ClusterNodeStartUtils.confirmDataNodeRegistration(req, this); @@ -447,6 +469,36 @@ public DataSet removeDataNode(RemoveDataNodePlan removeDataNodePlan) { } } + @Override + public TAINodeRestartResp restartAINode(TAINodeRestartReq req) { + TSStatus status = confirmLeader(); + if (status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + status = + ClusterNodeStartUtils.confirmNodeRestart( + NodeType.AINode, + req.getClusterName(), + req.getAiNodeConfiguration().getLocation().getAiNodeId(), + req.getAiNodeConfiguration().getLocation(), + this); + if (status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + return nodeManager.updateAINodeIfNecessary(req); + } + } + return new TAINodeRestartResp() + .setStatus(status) + .setConfigNodeList(getNodeManager().getRegisteredConfigNodes()); + } + + @Override + public TSStatus removeAINode(RemoveAINodePlan removeAINodePlan) { + TSStatus status = confirmLeader(); + if (status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + return nodeManager.removeAINode(removeAINodePlan); + } else { + return status; + } + } + @Override public TSStatus reportDataNodeShutdown(TDataNodeLocation dataNodeLocation) { TSStatus status = confirmLeader(); @@ -477,6 +529,11 @@ public DataSet getDataNodeConfiguration( } } + @Override + public DataSet getAINodeConfiguration(GetAINodeConfigurationPlan getAINodeConfigurationPlan) { + return nodeManager.getAINodeConfiguration(getAINodeConfigurationPlan); + } + @Override public TShowClusterResp showCluster() { TSStatus status = confirmLeader(); @@ -499,10 +556,22 @@ public TShowClusterResp showCluster() { nodeStatus.putIfAbsent( dataNodeLocation.getDataNodeId(), 
NodeStatus.Unknown.toString()));
+      List<TAINodeLocation> aiNodeLocations =
+          getNodeManager().getRegisteredAINodes().stream()
+              .map(TAINodeConfiguration::getLocation)
+              .sorted(Comparator.comparingInt(TAINodeLocation::getAiNodeId))
+              .collect(Collectors.toList());
+      // Registered AINodes without a heartbeat sample are reported as Unknown
+      aiNodeLocations.forEach(
+          aiNodeLocation ->
+              nodeStatus.putIfAbsent(
+                  aiNodeLocation.getAiNodeId(), NodeStatus.Unknown.toString()));
+
       return new TShowClusterResp()
           .setStatus(status)
           .setConfigNodeList(configNodeLocations)
           .setDataNodeList(dataNodeLocations)
+          .setAiNodeList(aiNodeLocations)
           .setNodeStatus(nodeStatus)
           .setNodeVersionInfo(nodeVersionInfo);
     } else {
@@ -510,6 +579,7 @@ public TShowClusterResp showCluster() {
           .setStatus(status)
           .setConfigNodeList(Collections.emptyList())
           .setDataNodeList(Collections.emptyList())
+          .setAiNodeList(Collections.emptyList())
           .setNodeStatus(Collections.emptyMap())
           .setNodeVersionInfo(Collections.emptyMap());
     }
@@ -1081,6 +1151,11 @@ public TriggerManager getTriggerManager() {
     return triggerManager;
   }
 
+  @Override
+  public ModelManager getModelManager() {
+    return modelManager;
+  }
+
   @Override
   public PipeManager getPipeManager() {
     return pipeManager;
@@ -1659,6 +1734,18 @@ public RegionInfoListResp showRegion(GetRegionInfoListPlan getRegionInfoListPlan
     }
   }
 
+  @Override
+  public TShowAINodesResp showAINodes() {
+    TSStatus status = confirmLeader();
+    TShowAINodesResp resp = new TShowAINodesResp();
+    if (status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+      return resp.setAiNodesInfoList(nodeManager.getRegisteredAINodeInfoList())
+          .setStatus(StatusUtils.OK);
+    } else {
+      return resp.setStatus(status);
+    }
+  }
+
   @Override
   public TShowDataNodesResp showDataNodes() {
     TSStatus status = confirmLeader();
@@ -2325,6 +2412,42 @@ public TSStatus transfer(List<TDataNodeLocation> newUnknownDataList) {
     return transferResult;
   }
 
+  @Override
+  public TSStatus createModel(TCreateModelReq req) {
+    TSStatus status = confirmLeader();
+    if (nodeManager.getRegisteredAINodes().isEmpty()) {
+      return new TSStatus(TSStatusCode.NO_REGISTERED_AI_NODE_ERROR.getStatusCode())
+          .setMessage("There is no available AINode! Try to start one.");
+    }
+    return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
+        ? modelManager.createModel(req)
+        : status;
+  }
+
+  @Override
+  public TSStatus dropModel(TDropModelReq req) {
+    TSStatus status = confirmLeader();
+    return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
+        ? modelManager.dropModel(req)
+        : status;
+  }
+
+  @Override
+  public TShowModelResp showModel(TShowModelReq req) {
+    TSStatus status = confirmLeader();
+    return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
+        ? modelManager.showModel(req)
+        : new TShowModelResp(status, Collections.emptyList());
+  }
+
+  @Override
+  public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) {
+    TSStatus status = confirmLeader();
+    return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
+        ? modelManager.getModelInfo(req)
+        : new TGetModelInfoResp(status);
+  }
+
   @Override
   public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) {
     TSStatus status = confirmLeader();
@@ -2367,4 +2490,19 @@ public TThrottleQuotaResp getThrottleQuota() {
         ?
clusterQuotaManager.getThrottleQuota() : new TThrottleQuotaResp(status); } + + @Override + public DataSet registerAINode(TAINodeRegisterReq req) { + TSStatus status = confirmLeader(); + if (status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + status = ClusterNodeStartUtils.confirmAINodeRegistration(req, this); + if (status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + return nodeManager.registerAINode(req); + } + } + AINodeRegisterResp resp = new AINodeRegisterResp(); + resp.setStatus(status); + resp.setConfigNodeList(getNodeManager().getRegisteredConfigNodes()); + return resp; + } } diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java index 366be81c9a34..336ecd5694dc 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java @@ -30,6 +30,7 @@ import org.apache.iotdb.commons.path.PartialPath; import org.apache.iotdb.commons.path.PathPatternTree; import org.apache.iotdb.confignode.consensus.request.auth.AuthorPlan; +import org.apache.iotdb.confignode.consensus.request.read.ainode.GetAINodeConfigurationPlan; import org.apache.iotdb.confignode.consensus.request.read.database.CountDatabasePlan; import org.apache.iotdb.confignode.consensus.request.read.database.GetDatabasePlan; import org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan; @@ -37,6 +38,7 @@ import org.apache.iotdb.confignode.consensus.request.read.partition.GetOrCreateDataPartitionPlan; import org.apache.iotdb.confignode.consensus.request.read.region.GetRegionInfoListPlan; import org.apache.iotdb.confignode.consensus.request.read.ttl.ShowTTLPlan; +import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; import org.apache.iotdb.confignode.consensus.request.write.confignode.RemoveConfigNodePlan; import org.apache.iotdb.confignode.consensus.request.write.database.DatabaseSchemaPlan; import org.apache.iotdb.confignode.consensus.request.write.database.SetDataReplicationFactorPlan; @@ -52,6 +54,9 @@ import org.apache.iotdb.confignode.manager.pipe.coordinator.PipeManager; import org.apache.iotdb.confignode.manager.schema.ClusterSchemaManager; import org.apache.iotdb.confignode.manager.subscription.SubscriptionManager; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp; import org.apache.iotdb.confignode.rpc.thrift.TAlterLogicalViewReq; import org.apache.iotdb.confignode.rpc.thrift.TAlterPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TAlterSchemaTemplateReq; @@ -63,6 +68,7 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; +import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -77,6 +83,7 @@ import org.apache.iotdb.confignode.rpc.thrift.TDeleteLogicalViewReq; import org.apache.iotdb.confignode.rpc.thrift.TDeleteTimeSeriesReq; import 
org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; +import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropTopicReq; @@ -90,6 +97,8 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -111,11 +120,14 @@ import org.apache.iotdb.confignode.rpc.thrift.TSchemaPartitionTableResp; import org.apache.iotdb.confignode.rpc.thrift.TSetDataNodeStatusReq; import org.apache.iotdb.confignode.rpc.thrift.TSetSchemaTemplateReq; +import org.apache.iotdb.confignode.rpc.thrift.TShowAINodesResp; import org.apache.iotdb.confignode.rpc.thrift.TShowCQResp; import org.apache.iotdb.confignode.rpc.thrift.TShowClusterResp; import org.apache.iotdb.confignode.rpc.thrift.TShowConfigNodesResp; import org.apache.iotdb.confignode.rpc.thrift.TShowDataNodesResp; import org.apache.iotdb.confignode.rpc.thrift.TShowDatabaseResp; +import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq; +import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp; import org.apache.iotdb.confignode.rpc.thrift.TShowPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TShowPipeResp; import org.apache.iotdb.confignode.rpc.thrift.TShowSubscriptionReq; @@ -209,6 +221,13 @@ public interface IManager { */ CQManager getCQManager(); + /** + * Get {@link ModelManager}. + * + * @return {@link ModelManager} instance + */ + ModelManager getModelManager(); + /** * Get {@link PipeManager}. * @@ -275,6 +294,30 @@ public interface IManager { */ DataSet removeDataNode(RemoveDataNodePlan removeDataNodePlan); + /** + * Register AINode + * + * @param req TAINodeRegisterReq + * @return AINodeConfigurationDataSet + */ + DataSet registerAINode(TAINodeRegisterReq req); + + /** + * Restart AINode. + * + * @param req TAINodeRestartReq + * @return SUCCESS_STATUS if allow AINode to restart, REJECT_START otherwise + */ + TAINodeRestartResp restartAINode(TAINodeRestartReq req); + + /** + * Remove AINode. + * + * @param removeAINodePlan RemoveAINodePlan + * @return AINodeToStatusResp + */ + TSStatus removeAINode(RemoveAINodePlan removeAINodePlan); + /** * Report that the specified DataNode will be shutdown. * @@ -291,6 +334,15 @@ public interface IManager { */ DataSet getDataNodeConfiguration(GetDataNodeConfigurationPlan getDataNodeConfigurationPlan); + /** + * Get AINode info. + * + * @param getAINodeConfigurationPlan which contains specific AINode id or -1 to get all AINodes' + * configuration. + * @return AINodeConfigurationDataSet + */ + DataSet getAINodeConfiguration(GetAINodeConfigurationPlan getAINodeConfigurationPlan); + /** * Get Cluster Nodes' information. * @@ -525,6 +577,9 @@ TDataPartitionTableResp getOrCreateDataPartition( /** Show (data/schemaengine) regions. */ DataSet showRegion(GetRegionInfoListPlan getRegionInfoListPlan); + /** Show AINodes. */ + TShowAINodesResp showAINodes(); + /** Show DataNodes. 
*/ TShowDataNodesResp showDataNodes(); @@ -737,6 +792,18 @@ TDataPartitionTableResp getOrCreateDataPartition( TSStatus transfer(List newUnknownDataList); + /** Create a model. */ + TSStatus createModel(TCreateModelReq req); + + /** Drop a model. */ + TSStatus dropModel(TDropModelReq req); + + /** Return the model table. */ + TShowModelResp showModel(TShowModelReq req); + + /** Update the model state */ + TGetModelInfoResp getModelInfo(TGetModelInfoReq req); + /** Set space quota. */ TSStatus setSpaceQuota(TSetSpaceQuotaReq req); } diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java new file mode 100644 index 000000000000..aa7ef2fb6e64 --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.confignode.manager; + +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; +import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; +import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp; +import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp; +import org.apache.iotdb.confignode.persistence.ModelInfo; +import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; +import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; +import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq; +import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp; +import org.apache.iotdb.consensus.common.DataSet; +import org.apache.iotdb.consensus.exception.ConsensusException; +import org.apache.iotdb.rpc.TSStatusCode; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +public class ModelManager { + + private static final Logger LOGGER = LoggerFactory.getLogger(ModelManager.class); + + private final ConfigManager configManager; + private final ModelInfo modelInfo; + + public ModelManager(ConfigManager configManager, ModelInfo modelInfo) { + this.configManager = configManager; + this.modelInfo = modelInfo; + } + + public TSStatus createModel(TCreateModelReq req) { + if (modelInfo.contain(req.modelName)) { + return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode()) + .setMessage(String.format("Model name %s already exists", req.modelName)); + } + return 
configManager.getProcedureManager().createModel(req.modelName, req.uri); + } + + public TSStatus dropModel(TDropModelReq req) { + if (!modelInfo.contain(req.modelId)) { + return new TSStatus(TSStatusCode.MODEL_NOT_FOUND_ERROR.getStatusCode()) + .setMessage(String.format("Model name %s doesn't exist", req.modelId)); + } + return configManager.getProcedureManager().dropModel(req.getModelId()); + } + + public TShowModelResp showModel(TShowModelReq req) { + try { + DataSet response = configManager.getConsensusManager().read(new ShowModelPlan(req)); + return ((ModelTableResp) response).convertToThriftResponse(); + } catch (ConsensusException e) { + LOGGER.warn( + String.format("Unexpected error happened while showing model %s: ", req.getModelId()), e); + // consensus layer related errors + TSStatus res = new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()); + res.setMessage(e.getMessage()); + return new TShowModelResp(res, Collections.emptyList()); + } catch (IOException e) { + LOGGER.warn("Fail to get ModelTable", e); + return new TShowModelResp( + new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()) + .setMessage(e.getMessage()), + Collections.emptyList()); + } + } + + public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { + try { + GetModelInfoResp response = + (GetModelInfoResp) configManager.getConsensusManager().read(new GetModelInfoPlan(req)); + if (response.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + return new TGetModelInfoResp(response.getStatus()); + } + int aiNodeId = response.getTargetAINodeId(); + if (aiNodeId != 0) { + response.setTargetAINodeAddress( + configManager.getNodeManager().getRegisteredAINode(aiNodeId)); + } else { + if (configManager.getNodeManager().getRegisteredAINodes().isEmpty()) { + return new TGetModelInfoResp( + new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()) + .setMessage("There is no AINode available")); + } + response.setTargetAINodeAddress( + configManager.getNodeManager().getRegisteredAINodes().get(0)); + } + return response.convertToThriftResponse(); + } catch (ConsensusException e) { + LOGGER.warn("Unexpected error happened while getting model: ", e); + // consensus layer related errors + TSStatus res = new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()); + res.setMessage(e.getMessage()); + return new TGetModelInfoResp(res); + } + } + + public List getModelDistributions(String modelName) { + return modelInfo.getNodeIds(modelName); + } +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java index c77af9837578..fccdb66127ea 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java @@ -41,6 +41,7 @@ import org.apache.iotdb.confignode.conf.ConfigNodeConfig; import org.apache.iotdb.confignode.conf.ConfigNodeDescriptor; import org.apache.iotdb.confignode.consensus.request.auth.AuthorPlan; +import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; import org.apache.iotdb.confignode.consensus.request.write.confignode.RemoveConfigNodePlan; import org.apache.iotdb.confignode.consensus.request.write.database.SetTTLPlan; import org.apache.iotdb.confignode.consensus.request.write.datanode.RemoveDataNodePlan; @@ -54,7 +55,10 @@ import
org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; import org.apache.iotdb.confignode.procedure.env.RegionMaintainHandler; import org.apache.iotdb.confignode.procedure.impl.cq.CreateCQProcedure; +import org.apache.iotdb.confignode.procedure.impl.model.CreateModelProcedure; +import org.apache.iotdb.confignode.procedure.impl.model.DropModelProcedure; import org.apache.iotdb.confignode.procedure.impl.node.AddConfigNodeProcedure; +import org.apache.iotdb.confignode.procedure.impl.node.RemoveAINodeProcedure; import org.apache.iotdb.confignode.procedure.impl.node.RemoveConfigNodeProcedure; import org.apache.iotdb.confignode.procedure.impl.node.RemoveDataNodeProcedure; import org.apache.iotdb.confignode.procedure.impl.pipe.plugin.CreatePipePluginProcedure; @@ -550,6 +554,13 @@ public boolean removeDataNode(RemoveDataNodePlan removeDataNodePlan) { return true; } + public boolean removeAINode(RemoveAINodePlan removeAINodePlan) { + this.executor.submitProcedure(new RemoveAINodeProcedure(removeAINodePlan.getAINodeLocation())); + LOGGER.info( + "Submit RemoveAINodeProcedure successfully, {}", removeAINodePlan.getAINodeLocation()); + return true; + } + // region region migration private TSStatus checkRegionMigrate( TMigrateRegionReq migrateRegionReq, @@ -800,6 +811,25 @@ public TSStatus createCQ(TCreateCQReq req, ScheduledExecutorService scheduledExe return statusList.get(0); } + public TSStatus createModel(String modelName, String uri) { + long procedureId = executor.submitProcedure(new CreateModelProcedure(modelName, uri)); + LOGGER.info("CreateModelProcedure was submitted, procedureId: {}.", procedureId); + return RpcUtils.SUCCESS_STATUS; + } + + public TSStatus dropModel(String modelId) { + long procedureId = executor.submitProcedure(new DropModelProcedure(modelId)); + List statusList = new ArrayList<>(); + boolean isSucceed = + waitingProcedureFinished(Collections.singletonList(procedureId), statusList); + if (isSucceed) { + return RpcUtils.SUCCESS_STATUS; + } else { + return new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode()) + .setMessage(statusList.get(0).getMessage()); + } + } + public TSStatus createPipePlugin( PipePluginMeta pipePluginMeta, byte[] jarFile, boolean isSetIfNotExistsCondition) { final CreatePipePluginProcedure createPipePluginProcedure = diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/LoadManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/LoadManager.java index 3260c3679d17..31ef0c17c3a1 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/LoadManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/LoadManager.java @@ -262,9 +262,13 @@ public void forceUpdateNodeCache( loadCache.cacheConfigNodeHeartbeatSample(nodeId, heartbeatSample); break; case DataNode: - default: loadCache.cacheDataNodeHeartbeatSample(nodeId, heartbeatSample); break; + case AINode: + loadCache.cacheAINodeHeartbeatSample(nodeId, heartbeatSample); + break; + default: + break; } loadCache.updateNodeStatistics(); eventService.checkAndBroadcastNodeStatisticsChangeEventIfNecessary(); diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/LoadCache.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/LoadCache.java index e51d69073874..8fe7233809d3 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/LoadCache.java 
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/LoadCache.java @@ -34,6 +34,7 @@ import org.apache.iotdb.confignode.manager.load.cache.consensus.ConsensusGroupCache; import org.apache.iotdb.confignode.manager.load.cache.consensus.ConsensusGroupHeartbeatSample; import org.apache.iotdb.confignode.manager.load.cache.consensus.ConsensusGroupStatistics; +import org.apache.iotdb.confignode.manager.load.cache.node.AINodeHeartbeatCache; import org.apache.iotdb.confignode.manager.load.cache.node.BaseNodeCache; import org.apache.iotdb.confignode.manager.load.cache.node.ConfigNodeHeartbeatCache; import org.apache.iotdb.confignode.manager.load.cache.node.DataNodeHeartbeatCache; @@ -192,9 +193,11 @@ public void createNodeHeartbeatCache(NodeType nodeType, int nodeId) { nodeCacheMap.put(nodeId, new ConfigNodeHeartbeatCache(nodeId)); break; case DataNode: - default: nodeCacheMap.put(nodeId, new DataNodeHeartbeatCache(nodeId)); break; + case AINode: + nodeCacheMap.put(nodeId, new AINodeHeartbeatCache(nodeId)); + break; } heartbeatProcessingMap.put(nodeId, new AtomicBoolean(false)); } @@ -225,6 +228,19 @@ public void cacheDataNodeHeartbeatSample(int nodeId, NodeHeartbeatSample sample) Optional.ofNullable(heartbeatProcessingMap.get(nodeId)).ifPresent(node -> node.set(false)); } + /** + * Cache the latest heartbeat sample of a AINode. + * + * @param nodeId the id of the AINode + * @param sample the latest heartbeat sample + */ + public void cacheAINodeHeartbeatSample(int nodeId, NodeHeartbeatSample sample) { + nodeCacheMap + .computeIfAbsent(nodeId, empty -> new AINodeHeartbeatCache(nodeId)) + .cacheHeartbeatSample(sample); + heartbeatProcessingMap.get(nodeId).set(false); + } + public void resetHeartbeatProcessing(int nodeId) { heartbeatProcessingMap.get(nodeId).set(false); } diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/node/AINodeHeartbeatCache.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/node/AINodeHeartbeatCache.java new file mode 100644 index 000000000000..f6c845f30e62 --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/node/AINodeHeartbeatCache.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
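// Editor's sketch, not part of this patch: the LoadCache changes above give AINodes the same
// heartbeat-cache lifecycle as DataNodes -- a per-node cache is created at registration, and
// cacheAINodeHeartbeatSample falls back to computeIfAbsent so a sample that races ahead of the
// cache creation is still recorded. The SketchSample/SketchNodeCache types below are hypothetical
// stand-ins for NodeHeartbeatSample/BaseNodeCache; the sketch also guards the processing flag
// with computeIfAbsent, the way the DataNode path guards it with Optional.ofNullable.
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;

class HeartbeatCacheSketch {
  static final class SketchSample {
    final long timestamp;
    SketchSample(long timestamp) { this.timestamp = timestamp; }
  }

  static final class SketchNodeCache {
    final int nodeId;
    volatile SketchSample lastSample;
    SketchNodeCache(int nodeId) { this.nodeId = nodeId; }
    void cacheHeartbeatSample(SketchSample sample) { this.lastSample = sample; }
  }

  private final Map<Integer, SketchNodeCache> nodeCacheMap = new ConcurrentHashMap<>();
  private final Map<Integer, AtomicBoolean> heartbeatProcessingMap = new ConcurrentHashMap<>();

  // Called when an AINode registers.
  void createAINodeHeartbeatCache(int nodeId) {
    nodeCacheMap.put(nodeId, new SketchNodeCache(nodeId));
    heartbeatProcessingMap.put(nodeId, new AtomicBoolean(false));
  }

  // Called when a heartbeat response arrives; tolerates a missing cache entry.
  void cacheAINodeHeartbeatSample(int nodeId, SketchSample sample) {
    nodeCacheMap.computeIfAbsent(nodeId, SketchNodeCache::new).cacheHeartbeatSample(sample);
    heartbeatProcessingMap.computeIfAbsent(nodeId, id -> new AtomicBoolean()).set(false);
  }
}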
+ */ + +package org.apache.iotdb.confignode.manager.load.cache.node; + +import org.apache.iotdb.common.rpc.thrift.TLoadSample; +import org.apache.iotdb.commons.cluster.NodeStatus; + +import java.util.concurrent.atomic.AtomicReference; + +public class AINodeHeartbeatCache extends BaseNodeCache { + + private final AtomicReference latestLoadSample; + + public AINodeHeartbeatCache(int aiNodeId) { + super(aiNodeId); + this.latestLoadSample = new AtomicReference<>(new TLoadSample()); + } + + @Override + public void updateCurrentStatistics() { + NodeHeartbeatSample lastSample = null; + synchronized (slidingWindow) { + if (!slidingWindow.isEmpty()) { + lastSample = (NodeHeartbeatSample) getLastSample(); + } + } + long lastSendTime = lastSample == null ? 0 : lastSample.getSampleLogicalTimestamp(); + + /* Update load sample */ + if (lastSample != null && lastSample.isSetLoadSample()) { + latestLoadSample.set((lastSample.getLoadSample())); + } + + /* Update Node status */ + NodeStatus status = null; + String statusReason = null; + long currentNanoTime = System.nanoTime(); + if (lastSample != null && NodeStatus.Removing.equals(lastSample.getStatus())) { + status = NodeStatus.Removing; + } else if (currentNanoTime - lastSendTime > HEARTBEAT_TIMEOUT_TIME_IN_NS) { + status = NodeStatus.Unknown; + } else if (lastSample != null) { + status = lastSample.getStatus(); + statusReason = lastSample.getStatusReason(); + } + + long loadScore = NodeStatus.isNormalStatus(status) ? 0 : Long.MAX_VALUE; + + NodeStatistics newStatistics = + new NodeStatistics(currentNanoTime, status, statusReason, loadScore); + if (!currentStatistics.get().equals(newStatistics)) { + // Update the current NodeStatistics if necessary + currentStatistics.set(newStatistics); + } + } +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/node/DataNodeHeartbeatCache.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/node/DataNodeHeartbeatCache.java index 8948384efa88..39454a4ccb41 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/node/DataNodeHeartbeatCache.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/node/DataNodeHeartbeatCache.java @@ -19,8 +19,8 @@ package org.apache.iotdb.confignode.manager.load.cache.node; +import org.apache.iotdb.common.rpc.thrift.TLoadSample; import org.apache.iotdb.commons.cluster.NodeStatus; -import org.apache.iotdb.mpp.rpc.thrift.TLoadSample; import java.util.concurrent.atomic.AtomicReference; diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/node/NodeHeartbeatSample.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/node/NodeHeartbeatSample.java index 0c8f2c23cb20..8217593f5d67 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/node/NodeHeartbeatSample.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/cache/node/NodeHeartbeatSample.java @@ -19,11 +19,12 @@ package org.apache.iotdb.confignode.manager.load.cache.node; +import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatResp; +import org.apache.iotdb.common.rpc.thrift.TLoadSample; import org.apache.iotdb.commons.cluster.NodeStatus; import org.apache.iotdb.confignode.manager.load.cache.AbstractHeartbeatSample; import org.apache.iotdb.confignode.rpc.thrift.TConfigNodeHeartbeatResp; import 
org.apache.iotdb.mpp.rpc.thrift.TDataNodeHeartbeatResp; -import org.apache.iotdb.mpp.rpc.thrift.TLoadSample; /** NodeHeartbeatSample records the heartbeat sample of a Node. */ public class NodeHeartbeatSample extends AbstractHeartbeatSample { @@ -58,6 +59,18 @@ public NodeHeartbeatSample(TDataNodeHeartbeatResp heartbeatResp) { this.loadSample = heartbeatResp.isSetLoadSample() ? heartbeatResp.getLoadSample() : null; } + /** Constructor for AINode sample. */ + public NodeHeartbeatSample(TAIHeartbeatResp heartbeatResp) { + super(heartbeatResp.getHeartbeatTimestamp()); + this.status = NodeStatus.parse(heartbeatResp.getStatus()); + this.statusReason = heartbeatResp.isSetStatusReason() ? heartbeatResp.getStatusReason() : null; + if (heartbeatResp.isSetLoadSample()) { + this.loadSample = heartbeatResp.getLoadSample(); + } else { + this.loadSample = null; + } + } + /** Constructor for ConfigNode sample. */ public NodeHeartbeatSample(TConfigNodeHeartbeatResp heartbeatResp) { super(heartbeatResp.getTimestamp()); diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/service/HeartbeatService.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/service/HeartbeatService.java index 7bd24c2710b8..c713b0eb0a39 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/service/HeartbeatService.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/load/service/HeartbeatService.java @@ -19,6 +19,8 @@ package org.apache.iotdb.confignode.manager.load.service; +import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatReq; +import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration; import org.apache.iotdb.common.rpc.thrift.TEndPoint; @@ -26,8 +28,10 @@ import org.apache.iotdb.commons.concurrent.ThreadName; import org.apache.iotdb.commons.concurrent.threadpool.ScheduledExecutorUtil; import org.apache.iotdb.commons.pipe.config.PipeConfig; +import org.apache.iotdb.confignode.client.async.AsyncAINodeHeartbeatClientPool; import org.apache.iotdb.confignode.client.async.AsyncConfigNodeHeartbeatClientPool; import org.apache.iotdb.confignode.client.async.AsyncDataNodeHeartbeatClientPool; +import org.apache.iotdb.confignode.client.async.handlers.heartbeat.AINodeHeartbeatHandler; import org.apache.iotdb.confignode.client.async.handlers.heartbeat.ConfigNodeHeartbeatHandler; import org.apache.iotdb.confignode.client.async.handlers.heartbeat.DataNodeHeartbeatHandler; import org.apache.iotdb.confignode.conf.ConfigNodeDescriptor; @@ -126,6 +130,8 @@ private void heartbeatLoopBody() { // Send heartbeat requests to all the registered DataNodes pingRegisteredDataNodes( genHeartbeatReq(), getNodeManager().getRegisteredDataNodes()); + // Send heartbeat requests to all the registered AINodes + pingRegisteredAINodes(genMLHeartbeatReq(), getNodeManager().getRegisteredAINodes()); } }); } @@ -187,6 +193,17 @@ private TConfigNodeHeartbeatReq genConfigNodeHeartbeatReq() { return req; } + private TAIHeartbeatReq genMLHeartbeatReq() { + /* Generate heartbeat request */ + TAIHeartbeatReq heartbeatReq = new TAIHeartbeatReq(); + heartbeatReq.setHeartbeatTimestamp(System.nanoTime()); + + // We sample AINode's load in every 10 heartbeat loop + heartbeatReq.setNeedSamplingLoad(heartbeatCounter.get() % 10 == 0); + + return heartbeatReq; + } + /** * Send heartbeat requests to all the Registered 
ConfigNodes. * @@ -245,6 +262,24 @@ private void pingRegisteredDataNodes( } } + /** + * Send heartbeat requests to all the Registered AINodes. + * + * @param registeredAINodes DataNodes that registered in cluster + */ + private void pingRegisteredAINodes( + TAIHeartbeatReq heartbeatReq, List registeredAINodes) { + // Send heartbeat requests + for (TAINodeConfiguration aiNodeInfo : registeredAINodes) { + AINodeHeartbeatHandler handler = + new AINodeHeartbeatHandler( + aiNodeInfo.getLocation().getAiNodeId(), configManager.getLoadManager()); + AsyncAINodeHeartbeatClientPool.getInstance() + .getAINodeHeartBeat( + aiNodeInfo.getLocation().getInternalEndPoint(), heartbeatReq, handler); + } + } + private ConsensusManager getConsensusManager() { return configManager.getConsensusManager(); } diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/node/ClusterNodeStartUtils.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/node/ClusterNodeStartUtils.java index 035517172f35..df6e2d80305f 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/node/ClusterNodeStartUtils.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/node/ClusterNodeStartUtils.java @@ -19,6 +19,8 @@ package org.apache.iotdb.confignode.manager.node; +import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; +import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration; import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; @@ -30,6 +32,7 @@ import org.apache.iotdb.confignode.conf.ConfigNodeConfig; import org.apache.iotdb.confignode.conf.ConfigNodeDescriptor; import org.apache.iotdb.confignode.manager.ConfigManager; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq; import org.apache.iotdb.confignode.rpc.thrift.TConfigNodeRegisterReq; import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRegisterReq; import org.apache.iotdb.rpc.TSStatusCode; @@ -159,6 +162,52 @@ public static TSStatus confirmConfigNodeRegistration( return ACCEPT_NODE_REGISTRATION; } + public static TSStatus confirmAINodeRegistration( + TAINodeRegisterReq req, ConfigManager configManager) { + // Confirm cluster name + TSStatus status = confirmClusterName(NodeType.AINode, req.getClusterName()); + if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + return status; + } + // Confirm end point conflicts + List conflictEndPoints = + checkConflictTEndPointForNewAINode( + req.getAiNodeConfiguration().getLocation(), + configManager.getNodeManager().getRegisteredAINodes()); + if (!conflictEndPoints.isEmpty()) { + return rejectRegistrationBecauseConflictEndPoints(NodeType.AINode, conflictEndPoints); + } + // Confirm whether cluster id has been generated + status = confirmClusterId(configManager); + if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + return status; + } + // Success + return ACCEPT_NODE_REGISTRATION; + } + + /** + * Check if there exist conflict TEndPoints on the DataNode to be registered. 
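// Editor's sketch, not part of this patch: the AINode heartbeat loop above (genMLHeartbeatReq /
// pingRegisteredAINodes) piggybacks a load-sampling flag on every 10th heartbeat rather than
// sampling load on each beat. A self-contained model of that cadence; the class and method names
// are illustrative, and the real request is a TAIHeartbeatReq fanned out asynchronously with an
// AINodeHeartbeatHandler per registered AINode.
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

class AINodeHeartbeatLoopSketch {
  private static final int LOAD_SAMPLE_EVERY_N_BEATS = 10;
  private final AtomicLong heartbeatCounter = new AtomicLong();
  private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();

  void start(long intervalMs) {
    scheduler.scheduleAtFixedRate(this::heartbeatLoopBody, 0, intervalMs, TimeUnit.MILLISECONDS);
  }

  private void heartbeatLoopBody() {
    long beat = heartbeatCounter.getAndIncrement();
    boolean needSamplingLoad = (beat % LOAD_SAMPLE_EVERY_N_BEATS == 0);
    sendHeartbeatToAllAINodes(needSamplingLoad);
  }

  private void sendHeartbeatToAllAINodes(boolean needSamplingLoad) {
    // Stand-in for the asynchronous per-AINode heartbeat RPC in the patch.
    System.out.println("heartbeat sent, needSamplingLoad=" + needSamplingLoad);
  }
}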
+ * + * @param newAINodeLocation The TDataNodeLocation of the DataNode to be registered + * @param registeredAINodes All registered DataNodes + * @return The conflict TEndPoints if exist + */ + public static List checkConflictTEndPointForNewAINode( + TAINodeLocation newAINodeLocation, List registeredAINodes) { + Set conflictEndPointSet = new HashSet<>(); + for (TAINodeConfiguration registeredAINode : registeredAINodes) { + TAINodeLocation registeredLocation = registeredAINode.getLocation(); + if (registeredLocation + .getInternalEndPoint() + .equals(newAINodeLocation.getInternalEndPoint())) { + conflictEndPointSet.add(newAINodeLocation.getInternalEndPoint()); + } + } + + return new ArrayList<>(conflictEndPointSet); + } + public static TSStatus confirmNodeRestart( NodeType nodeType, String clusterName, diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/node/NodeManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/node/NodeManager.java index 4cb7945cf9d2..0217d3920987 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/node/NodeManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/node/NodeManager.java @@ -19,6 +19,7 @@ package org.apache.iotdb.confignode.manager.node; +import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId; import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration; @@ -43,13 +44,19 @@ import org.apache.iotdb.confignode.client.sync.SyncDataNodeClientPool; import org.apache.iotdb.confignode.conf.ConfigNodeConfig; import org.apache.iotdb.confignode.conf.ConfigNodeDescriptor; +import org.apache.iotdb.confignode.consensus.request.read.ainode.GetAINodeConfigurationPlan; import org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan; +import org.apache.iotdb.confignode.consensus.request.write.ainode.RegisterAINodePlan; +import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; +import org.apache.iotdb.confignode.consensus.request.write.ainode.UpdateAINodePlan; import org.apache.iotdb.confignode.consensus.request.write.confignode.ApplyConfigNodePlan; import org.apache.iotdb.confignode.consensus.request.write.confignode.RemoveConfigNodePlan; import org.apache.iotdb.confignode.consensus.request.write.confignode.UpdateVersionInfoPlan; import org.apache.iotdb.confignode.consensus.request.write.datanode.RegisterDataNodePlan; import org.apache.iotdb.confignode.consensus.request.write.datanode.RemoveDataNodePlan; import org.apache.iotdb.confignode.consensus.request.write.datanode.UpdateDataNodePlan; +import org.apache.iotdb.confignode.consensus.response.ainode.AINodeConfigurationResp; +import org.apache.iotdb.confignode.consensus.response.ainode.AINodeRegisterResp; import org.apache.iotdb.confignode.consensus.response.datanode.ConfigurationResp; import org.apache.iotdb.confignode.consensus.response.datanode.DataNodeConfigurationResp; import org.apache.iotdb.confignode.consensus.response.datanode.DataNodeRegisterResp; @@ -68,6 +75,10 @@ import org.apache.iotdb.confignode.manager.schema.ClusterSchemaManager; import org.apache.iotdb.confignode.persistence.node.NodeInfo; import org.apache.iotdb.confignode.procedure.env.RegionMaintainHandler; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo; +import 
org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp; import org.apache.iotdb.confignode.rpc.thrift.TCQConfig; import org.apache.iotdb.confignode.rpc.thrift.TConfigNodeInfo; import org.apache.iotdb.confignode.rpc.thrift.TConfigNodeRegisterReq; @@ -91,6 +102,7 @@ import org.slf4j.LoggerFactory; import java.util.ArrayList; +import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.List; @@ -427,6 +439,148 @@ public TSStatus updateConfigNodeIfNecessary(int configNodeId, TNodeVersionInfo v return ClusterNodeStartUtils.ACCEPT_NODE_RESTART; } + public List getRegisteredAINodeInfoList() { + List aiNodeInfoList = new ArrayList<>(); + for (TAINodeConfiguration aiNodeConfiguration : getRegisteredAINodes()) { + TAINodeInfo aiNodeInfo = new TAINodeInfo(); + aiNodeInfo.setAiNodeId(aiNodeConfiguration.getLocation().getAiNodeId()); + aiNodeInfo.setStatus(getLoadManager().getNodeStatusWithReason(aiNodeInfo.getAiNodeId())); + aiNodeInfo.setInternalAddress(aiNodeConfiguration.getLocation().getInternalEndPoint().ip); + aiNodeInfo.setInternalPort(aiNodeConfiguration.getLocation().getInternalEndPoint().port); + aiNodeInfoList.add(aiNodeInfo); + } + return aiNodeInfoList; + } + + /** + * @return All registered AINodes + */ + public List getRegisteredAINodes() { + return nodeInfo.getRegisteredAINodes(); + } + + public TAINodeConfiguration getRegisteredAINode(int aiNodeId) { + return nodeInfo.getRegisteredAINode(aiNodeId); + } + + /** + * Register AINode. Use synchronized to make sure + * + * @param req TAINodeRegisterReq + * @return AINodeConfigurationDataSet. The {@link TSStatus} will be set to {@link + * TSStatusCode#SUCCESS_STATUS} when register success. + */ + public synchronized DataSet registerAINode(TAINodeRegisterReq req) { + + if (!nodeInfo.getRegisteredAINodes().isEmpty()) { + AINodeRegisterResp dataSet = new AINodeRegisterResp(); + dataSet.setConfigNodeList(Collections.emptyList()); + dataSet.setStatus( + new TSStatus(TSStatusCode.REGISTER_AI_NODE_ERROR.getStatusCode()) + .setMessage("There is already one AINode in the cluster.")); + return dataSet; + } + + int aiNodeId = nodeInfo.generateNextNodeId(); + getLoadManager().getLoadCache().createNodeHeartbeatCache(NodeType.AINode, aiNodeId); + RegisterAINodePlan registerAINodePlan = new RegisterAINodePlan(req.getAiNodeConfiguration()); + // Register new DataNode + registerAINodePlan.getAINodeConfiguration().getLocation().setAiNodeId(aiNodeId); + try { + getConsensusManager().write(registerAINodePlan); + } catch (ConsensusException e) { + LOGGER.warn(CONSENSUS_WRITE_ERROR, e); + } + + // update datanode's versionInfo + UpdateVersionInfoPlan updateVersionInfoPlan = + new UpdateVersionInfoPlan(req.getVersionInfo(), aiNodeId); + try { + getConsensusManager().write(updateVersionInfoPlan); + } catch (ConsensusException e) { + LOGGER.warn(CONSENSUS_WRITE_ERROR, e); + } + + AINodeRegisterResp resp = new AINodeRegisterResp(); + resp.setStatus(ClusterNodeStartUtils.ACCEPT_NODE_REGISTRATION); + resp.setConfigNodeList(getRegisteredConfigNodes()); + resp.setAINodeId(registerAINodePlan.getAINodeConfiguration().getLocation().getAiNodeId()); + return resp; + } + + /** + * Remove AINodes. 
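// Editor's sketch, not part of this patch: registerAINode above enforces a single-AINode cluster.
// It rejects the request when an AINode is already registered, otherwise assigns the next node id
// and persists the registration (plus version info and a heartbeat cache) via the consensus layer.
// The registry below is a hypothetical, JDK-only restatement of that guard, not an IoTDB API.
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

class AINodeRegistrySketch {
  private final AtomicInteger nextNodeId = new AtomicInteger(1);
  private final Map<Integer, String> registeredAINodes = new ConcurrentHashMap<>();

  // Returns the assigned node id, or empty if another AINode already exists.
  synchronized Optional<Integer> register(String internalEndPoint) {
    if (!registeredAINodes.isEmpty()) {
      // Mirrors the "There is already one AINode in the cluster." rejection in the patch.
      return Optional.empty();
    }
    int aiNodeId = nextNodeId.getAndIncrement();
    registeredAINodes.put(aiNodeId, internalEndPoint);
    return Optional.of(aiNodeId);
  }
}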
+ * + * @param removeAINodePlan removeDataNodePlan + */ + public TSStatus removeAINode(RemoveAINodePlan removeAINodePlan) { + LOGGER.info("NodeManager start to remove AINode {}", removeAINodePlan); + + // check if the node exists + if (!nodeInfo.containsAINode(removeAINodePlan.getAINodeLocation().getAiNodeId())) { + return new TSStatus(TSStatusCode.REMOVE_AI_NODE_ERROR.getStatusCode()) + .setMessage("AINode doesn't exist."); + } + + // Add request to queue, then return to client + boolean removeSucceed = configManager.getProcedureManager().removeAINode(removeAINodePlan); + TSStatus status; + if (removeSucceed) { + status = new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); + status.setMessage("Server accepted the request"); + } else { + status = new TSStatus(TSStatusCode.REMOVE_AI_NODE_ERROR.getStatusCode()); + status.setMessage("Server rejected the request, maybe requests are too many"); + } + + LOGGER.info( + "NodeManager submit RemoveAINodePlan finished, removeAINodePlan: {}", removeAINodePlan); + return status; + } + + public TAINodeRestartResp updateAINodeIfNecessary(TAINodeRestartReq req) { + int nodeId = req.getAiNodeConfiguration().getLocation().getAiNodeId(); + TAINodeConfiguration aiNodeConfiguration = getRegisteredAINode(nodeId); + if (!req.getAiNodeConfiguration().equals(aiNodeConfiguration)) { + // Update AINodeConfiguration when modified during restart + UpdateAINodePlan updateAINodePlan = new UpdateAINodePlan(req.getAiNodeConfiguration()); + try { + getConsensusManager().write(updateAINodePlan); + } catch (ConsensusException e) { + LOGGER.warn(CONSENSUS_WRITE_ERROR, e); + } + } + TNodeVersionInfo versionInfo = nodeInfo.getVersionInfo(nodeId); + if (!req.getVersionInfo().equals(versionInfo)) { + // Update versionInfo when modified during restart + UpdateVersionInfoPlan updateVersionInfoPlan = + new UpdateVersionInfoPlan(req.getVersionInfo(), nodeId); + try { + getConsensusManager().write(updateVersionInfoPlan); + } catch (ConsensusException e) { + LOGGER.warn(CONSENSUS_WRITE_ERROR, e); + } + } + + TAINodeRestartResp resp = new TAINodeRestartResp(); + resp.setStatus(ClusterNodeStartUtils.ACCEPT_NODE_RESTART); + resp.setConfigNodeList(getRegisteredConfigNodes()); + return resp; + } + + public AINodeConfigurationResp getAINodeConfiguration(GetAINodeConfigurationPlan req) { + try { + return (AINodeConfigurationResp) getConsensusManager().read(req); + } catch (ConsensusException e) { + LOGGER.warn("Failed in the read API executing the consensus layer due to: ", e); + TSStatus res = new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()); + res.setMessage(e.getMessage()); + AINodeConfigurationResp response = new AINodeConfigurationResp(); + response.setStatus(res); + return response; + } + } + /** * Get TDataNodeConfiguration. * diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java new file mode 100644 index 000000000000..d3407e5c5dee --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java @@ -0,0 +1,382 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
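// Editor's sketch, not part of this patch: updateAINodeIfNecessary above only issues consensus
// writes when the restarted AINode reports a configuration or version that differs from what is
// already persisted, so a restart of an unchanged node is write-free. The generic class below is
// an illustrative restatement of that reconcile step; the Consumer callbacks stand in for the
// UpdateAINodePlan / UpdateVersionInfoPlan consensus writes.
import java.util.Objects;
import java.util.function.Consumer;

class AINodeRestartReconcileSketch<C, V> {
  private C persistedConfiguration;
  private V persistedVersionInfo;

  void onRestart(C reportedConfiguration, V reportedVersionInfo,
                 Consumer<C> writeConfigurationPlan, Consumer<V> writeVersionInfoPlan) {
    if (!Objects.equals(reportedConfiguration, persistedConfiguration)) {
      writeConfigurationPlan.accept(reportedConfiguration); // UpdateAINodePlan in the patch
      persistedConfiguration = reportedConfiguration;
    }
    if (!Objects.equals(reportedVersionInfo, persistedVersionInfo)) {
      writeVersionInfoPlan.accept(reportedVersionInfo); // UpdateVersionInfoPlan in the patch
      persistedVersionInfo = reportedVersionInfo;
    }
  }
}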
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.confignode.persistence; + +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.model.ModelInformation; +import org.apache.iotdb.commons.model.ModelStatus; +import org.apache.iotdb.commons.model.ModelTable; +import org.apache.iotdb.commons.model.ModelType; +import org.apache.iotdb.commons.snapshot.SnapshotProcessor; +import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; +import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; +import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; +import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; +import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp; +import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp; +import org.apache.iotdb.rpc.TSStatusCode; + +import org.apache.thrift.TException; +import org.apache.tsfile.utils.PublicBAOS; +import org.apache.tsfile.utils.ReadWriteIOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.ThreadSafe; + +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +@ThreadSafe +public class ModelInfo implements SnapshotProcessor { + + private static final Logger LOGGER = LoggerFactory.getLogger(ModelInfo.class); + + private static final String SNAPSHOT_FILENAME = "model_info.snapshot"; + + private ModelTable modelTable; + + private final Map> modelNameToNodes; + + private final ReadWriteLock modelTableLock = new ReentrantReadWriteLock(); + + private static final Set builtInForecastModel = new HashSet<>(); + + private static final Set builtInAnomalyDetectionModel = new HashSet<>(); + + static { + builtInForecastModel.add("_ARIMA"); + builtInForecastModel.add("_NaiveForecaster"); + builtInForecastModel.add("_STLForecaster"); + builtInForecastModel.add("_ExponentialSmoothing"); + builtInAnomalyDetectionModel.add("_GaussianHMM"); + builtInAnomalyDetectionModel.add("_GMMHMM"); + builtInAnomalyDetectionModel.add("_Stray"); + } + + public ModelInfo() { + this.modelTable = new ModelTable(); + this.modelNameToNodes = new HashMap<>(); + } + + public boolean contain(String modelName) { + return modelTable.containsModel(modelName); + } + + public void acquireModelTableReadLock() { + LOGGER.info("acquire ModelTableReadLock"); + modelTableLock.readLock().lock(); + } + + public void releaseModelTableReadLock() { + LOGGER.info("release ModelTableReadLock"); + 
modelTableLock.readLock().unlock(); + } + + public void acquireModelTableWriteLock() { + LOGGER.info("acquire ModelTableWriteLock"); + modelTableLock.writeLock().lock(); + } + + public void releaseModelTableWriteLock() { + LOGGER.info("release ModelTableWriteLock"); + modelTableLock.writeLock().unlock(); + } + + // Init the model entry in ModelInfo; it does not fill in the detailed model information yet + public TSStatus createModel(CreateModelPlan plan) { + try { + acquireModelTableWriteLock(); + String modelName = plan.getModelName(); + if (modelTable.containsModel(modelName)) { + return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode()) + .setMessage(String.format("model [%s] has already been created.", modelName)); + } else { + modelTable.addModel(new ModelInformation(modelName, ModelStatus.LOADING)); + return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); + } + } catch (Exception e) { + final String errorMessage = + String.format( + "Failed to add model [%s] in ModelTable on Config Nodes, because of %s", + plan.getModelName(), e); + LOGGER.warn(errorMessage, e); + return new TSStatus(TSStatusCode.CREATE_MODEL_ERROR.getStatusCode()).setMessage(errorMessage); + } finally { + releaseModelTableWriteLock(); + } + } + + public TSStatus dropModelInNode(int aiNodeId) { + acquireModelTableWriteLock(); + try { + // iterate over a copy of the entry set so that removing drained entries from the map below is safe + for (Map.Entry> entry : new HashMap<>(modelNameToNodes).entrySet()) { + entry.getValue().remove(Integer.valueOf(aiNodeId)); + // if list is empty, remove this model totally + if (entry.getValue().isEmpty()) { + modelTable.removeModel(entry.getKey()); + modelNameToNodes.remove(entry.getKey()); + } + } + // currently, we only have one AINode at a time, so we can just clear failed model. + modelTable.clearFailedModel(); + return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); + } finally { + releaseModelTableWriteLock(); + } + } + + public TSStatus dropModel(String modelName) { + acquireModelTableWriteLock(); + TSStatus status; + if (modelTable.containsModel(modelName)) { + modelTable.removeModel(modelName); + modelNameToNodes.remove(modelName); + status = new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); + } else { + status = + new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode()) + .setMessage(String.format("model [%s] has not been created.", modelName)); + } + releaseModelTableWriteLock(); + return status; + } + + public List getNodeIds(String modelName) { + return modelNameToNodes.getOrDefault(modelName, Collections.emptyList()); + } + + private ModelInformation getModelByName(String modelName) { + ModelType modelType = checkModelType(modelName); + if (modelType != ModelType.USER_DEFINED) { + if (modelType == ModelType.BUILT_IN_FORECAST && builtInForecastModel.contains(modelName)) { + return new ModelInformation(ModelType.BUILT_IN_FORECAST, modelName); + } else if (modelType == ModelType.BUILT_IN_ANOMALY_DETECTION + && builtInAnomalyDetectionModel.contains(modelName)) { + return new ModelInformation(ModelType.BUILT_IN_ANOMALY_DETECTION, modelName); + } + } else { + return modelTable.getModelInformationById(modelName); + } + return null; + } + + public ModelTableResp showModel(ShowModelPlan plan) { + acquireModelTableReadLock(); + try { + ModelTableResp modelTableResp = + new ModelTableResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); + if (plan.isSetModelName()) { + ModelInformation modelInformation = getModelByName(plan.getModelName()); + if (modelInformation != null) { + modelTableResp.addModelInformation(modelInformation); + } + } else {
modelTableResp.addModelInformation(modelTable.getAllModelInformation()); + for (String modelName : builtInForecastModel) { + modelTableResp.addModelInformation( + new ModelInformation(ModelType.BUILT_IN_FORECAST, modelName)); + } + for (String modelName : builtInAnomalyDetectionModel) { + modelTableResp.addModelInformation( + new ModelInformation(ModelType.BUILT_IN_ANOMALY_DETECTION, modelName)); + } + } + return modelTableResp; + } catch (IOException e) { + LOGGER.warn("Fail to get ModelTable", e); + return new ModelTableResp( + new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()) + .setMessage(e.getMessage())); + } finally { + releaseModelTableReadLock(); + } + } + + private boolean containsBuiltInModelName(Set builtInModelSet, String modelName) { + // ignore the case + for (String builtInModelName : builtInModelSet) { + if (builtInModelName.equalsIgnoreCase(modelName)) { + return true; + } + } + return false; + } + + private ModelType checkModelType(String modelName) { + if (containsBuiltInModelName(builtInForecastModel, modelName)) { + return ModelType.BUILT_IN_FORECAST; + } else if (containsBuiltInModelName(builtInAnomalyDetectionModel, modelName)) { + return ModelType.BUILT_IN_ANOMALY_DETECTION; + } else { + return ModelType.USER_DEFINED; + } + } + + private int getAvailableAINodeForModel(String modelName, ModelType modelType) { + if (modelType == ModelType.USER_DEFINED) { + List aiNodeIds = modelNameToNodes.get(modelName); + if (aiNodeIds != null) { + return aiNodeIds.get(0); + } + } else { + // any AINode is fine for built-in model + // 0 is always the nodeId for configNode, so it's fine to use 0 as special value + return 0; + } + return -1; + } + + // This method will be used by dataNode to get schema of the model for inference + public GetModelInfoResp getModelInfo(GetModelInfoPlan plan) { + acquireModelTableReadLock(); + try { + String modelName = plan.getModelId(); + GetModelInfoResp getModelInfoResp; + ModelInformation modelInformation; + ModelType modelType; + // check if it's a built-in model + if ((modelType = checkModelType(modelName)) != ModelType.USER_DEFINED) { + modelInformation = new ModelInformation(modelType, modelName); + } else { + modelInformation = modelTable.getModelInformationById(modelName); + } + + if (modelInformation != null) { + getModelInfoResp = + new GetModelInfoResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); + } else { + TSStatus errorStatus = new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()); + errorStatus.setMessage(String.format("model [%s] has not been created.", modelName)); + getModelInfoResp = new GetModelInfoResp(errorStatus); + return getModelInfoResp; + } + PublicBAOS buffer = new PublicBAOS(); + DataOutputStream stream = new DataOutputStream(buffer); + modelInformation.serialize(stream); + getModelInfoResp.setModelInfo(ByteBuffer.wrap(buffer.getBuf(), 0, buffer.size())); + // select the nodeId to process the task, currently we default use the first one. 
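// Editor's sketch, not part of this patch: the selection just below relies on two sentinel values
// from getAvailableAINodeForModel -- 0 means "any AINode will do" (built-in models are available
// everywhere, and 0 is never an AINode id because it is reserved for the ConfigNode), while -1
// means no AINode holds the requested user-defined model. A standalone restatement of that
// contract; the class and constant names are illustrative, and the sketch also guards the
// empty-list case.
import java.util.List;
import java.util.Map;

final class ModelNodeSelectionSketch {
  static final int ANY_AI_NODE = 0;
  static final int NO_AI_NODE = -1;

  // modelNameToNodes mirrors the map kept by ModelInfo; isBuiltIn mirrors checkModelType().
  static int selectNode(String modelName, boolean isBuiltIn, Map<String, List<Integer>> modelNameToNodes) {
    if (isBuiltIn) {
      return ANY_AI_NODE;
    }
    List<Integer> nodeIds = modelNameToNodes.get(modelName);
    // Default to the first AINode that registered the model, as the patch does.
    return (nodeIds == null || nodeIds.isEmpty()) ? NO_AI_NODE : nodeIds.get(0);
  }
}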
+ int aiNodeId = getAvailableAINodeForModel(modelName, modelType); + if (aiNodeId == -1) { + TSStatus errorStatus = new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()); + errorStatus.setMessage(String.format("There is no AINode with %s available", modelName)); + getModelInfoResp = new GetModelInfoResp(errorStatus); + return getModelInfoResp; + } else { + getModelInfoResp.setTargetAINodeId(aiNodeId); + } + return getModelInfoResp; + } catch (IOException e) { + LOGGER.warn("Fail to get model info", e); + return new GetModelInfoResp( + new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()) + .setMessage(e.getMessage())); + } finally { + releaseModelTableReadLock(); + } + } + + public TSStatus updateModelInfo(UpdateModelInfoPlan plan) { + acquireModelTableWriteLock(); + try { + String modelName = plan.getModelName(); + if (modelTable.containsModel(modelName)) { + modelTable.updateModel(modelName, plan.getModelInformation()); + } + if (!plan.getNodeIds().isEmpty()) { + // only used in model registration, so we can just put the nodeIds in the map without + // checking + modelNameToNodes.put(modelName, plan.getNodeIds()); + } + return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); + } finally { + releaseModelTableWriteLock(); + } + } + + @Override + public boolean processTakeSnapshot(File snapshotDir) throws TException, IOException { + File snapshotFile = new File(snapshotDir, SNAPSHOT_FILENAME); + if (snapshotFile.exists() && snapshotFile.isFile()) { + LOGGER.error( + "Failed to take snapshot of ModelInfo, because snapshot file [{}] is already exist.", + snapshotFile.getAbsolutePath()); + return false; + } + + acquireModelTableReadLock(); + try (FileOutputStream fileOutputStream = new FileOutputStream(snapshotFile)) { + modelTable.serialize(fileOutputStream); + ReadWriteIOUtils.write(modelNameToNodes.size(), fileOutputStream); + for (Map.Entry> entry : modelNameToNodes.entrySet()) { + ReadWriteIOUtils.write(entry.getKey(), fileOutputStream); + ReadWriteIOUtils.write(entry.getValue().size(), fileOutputStream); + for (Integer nodeId : entry.getValue()) { + ReadWriteIOUtils.write(nodeId, fileOutputStream); + } + } + fileOutputStream.getFD().sync(); + return true; + } finally { + releaseModelTableReadLock(); + } + } + + @Override + public void processLoadSnapshot(File snapshotDir) throws TException, IOException { + File snapshotFile = new File(snapshotDir, SNAPSHOT_FILENAME); + if (!snapshotFile.exists() || !snapshotFile.isFile()) { + LOGGER.error( + "Failed to load snapshot of ModelInfo, snapshot file [{}] does not exist.", + snapshotFile.getAbsolutePath()); + return; + } + acquireModelTableWriteLock(); + try (FileInputStream fileInputStream = new FileInputStream(snapshotFile)) { + modelTable.clear(); + modelTable = ModelTable.deserialize(fileInputStream); + int size = ReadWriteIOUtils.readInt(fileInputStream); + for (int i = 0; i < size; i++) { + String modelName = ReadWriteIOUtils.readString(fileInputStream); + int nodeSize = ReadWriteIOUtils.readInt(fileInputStream); + List nodes = new LinkedList<>(); + for (int j = 0; j < nodeSize; j++) { + nodes.add(ReadWriteIOUtils.readInt(fileInputStream)); + } + modelNameToNodes.put(modelName, nodes); + } + } finally { + releaseModelTableWriteLock(); + } + } +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java index dc1ffbde7203..e2a3cd1d9323 
100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java @@ -28,10 +28,13 @@ import org.apache.iotdb.commons.snapshot.SnapshotProcessor; import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; import org.apache.iotdb.confignode.consensus.request.auth.AuthorPlan; +import org.apache.iotdb.confignode.consensus.request.read.ainode.GetAINodeConfigurationPlan; import org.apache.iotdb.confignode.consensus.request.read.database.CountDatabasePlan; import org.apache.iotdb.confignode.consensus.request.read.database.GetDatabasePlan; import org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan; import org.apache.iotdb.confignode.consensus.request.read.function.GetUDFJarPlan; +import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; +import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; import org.apache.iotdb.confignode.consensus.request.read.partition.CountTimeSlotListPlan; import org.apache.iotdb.confignode.consensus.request.read.partition.GetDataPartitionPlan; import org.apache.iotdb.confignode.consensus.request.read.partition.GetNodePathsPartitionPlan; @@ -49,6 +52,9 @@ import org.apache.iotdb.confignode.consensus.request.read.trigger.GetTriggerLocationPlan; import org.apache.iotdb.confignode.consensus.request.read.trigger.GetTriggerTablePlan; import org.apache.iotdb.confignode.consensus.request.read.ttl.ShowTTLPlan; +import org.apache.iotdb.confignode.consensus.request.write.ainode.RegisterAINodePlan; +import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; +import org.apache.iotdb.confignode.consensus.request.write.ainode.UpdateAINodePlan; import org.apache.iotdb.confignode.consensus.request.write.confignode.ApplyConfigNodePlan; import org.apache.iotdb.confignode.consensus.request.write.confignode.RemoveConfigNodePlan; import org.apache.iotdb.confignode.consensus.request.write.confignode.UpdateClusterIdPlan; @@ -70,6 +76,10 @@ import org.apache.iotdb.confignode.consensus.request.write.datanode.UpdateDataNodePlan; import org.apache.iotdb.confignode.consensus.request.write.function.CreateFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.DropFunctionPlan; +import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; +import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan; +import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; +import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.AddRegionLocationPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.CreateDataPartitionPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.CreateSchemaPartitionPlan; @@ -118,6 +128,7 @@ import org.apache.iotdb.confignode.manager.pipe.agent.PipeConfigNodeAgent; import org.apache.iotdb.confignode.persistence.AuthorInfo; import org.apache.iotdb.confignode.persistence.ClusterInfo; +import org.apache.iotdb.confignode.persistence.ModelInfo; import org.apache.iotdb.confignode.persistence.ProcedureInfo; import org.apache.iotdb.confignode.persistence.TTLInfo; import org.apache.iotdb.confignode.persistence.TriggerInfo; @@ -176,6 +187,8 @@ public class ConfigPlanExecutor { 
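// Editor's sketch, not part of this patch: ModelInfo's processTakeSnapshot/processLoadSnapshot
// above persist the model table first and then the modelName -> AINode-id map as
// (entry count, then per entry: name, id count, ids), reading the stream back in the same order.
// The map portion of that layout, round-tripped with plain java.io; note that writeUTF is only a
// stand-in here, since the patch serializes strings with ReadWriteIOUtils and a different byte
// encoding.
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

final class ModelSnapshotLayoutSketch {
  static void writeMap(Map<String, List<Integer>> modelNameToNodes, DataOutputStream out) throws IOException {
    out.writeInt(modelNameToNodes.size());
    for (Map.Entry<String, List<Integer>> entry : modelNameToNodes.entrySet()) {
      out.writeUTF(entry.getKey());
      out.writeInt(entry.getValue().size());
      for (int nodeId : entry.getValue()) {
        out.writeInt(nodeId);
      }
    }
  }

  static Map<String, List<Integer>> readMap(DataInputStream in) throws IOException {
    Map<String, List<Integer>> modelNameToNodes = new HashMap<>();
    int size = in.readInt();
    for (int i = 0; i < size; i++) {
      String modelName = in.readUTF();
      int nodeCount = in.readInt();
      List<Integer> nodes = new LinkedList<>();
      for (int j = 0; j < nodeCount; j++) {
        nodes.add(in.readInt());
      }
      modelNameToNodes.put(modelName, nodes);
    }
    return modelNameToNodes;
  }
}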
private final CQInfo cqInfo; + private final ModelInfo modelInfo; + private final PipeInfo pipeInfo; private final SubscriptionInfo subscriptionInfo; @@ -194,6 +207,7 @@ public ConfigPlanExecutor( UDFInfo udfInfo, TriggerInfo triggerInfo, CQInfo cqInfo, + ModelInfo modelInfo, PipeInfo pipeInfo, SubscriptionInfo subscriptionInfo, QuotaInfo quotaInfo, @@ -225,6 +239,9 @@ public ConfigPlanExecutor( this.cqInfo = cqInfo; this.snapshotProcessorList.add(cqInfo); + this.modelInfo = modelInfo; + this.snapshotProcessorList.add(modelInfo); + this.pipeInfo = pipeInfo; this.snapshotProcessorList.add(pipeInfo); @@ -248,6 +265,8 @@ public DataSet executeQueryPlan(ConfigPhysicalPlan req) switch (req.getType()) { case GetDataNodeConfiguration: return nodeInfo.getDataNodeConfiguration((GetDataNodeConfigurationPlan) req); + case GetAINodeConfiguration: + return nodeInfo.getAINodeConfiguration((GetAINodeConfigurationPlan) req); case CountDatabase: return clusterSchemaInfo.countMatchedDatabases((CountDatabasePlan) req); case GetDatabase: @@ -304,6 +323,10 @@ public DataSet executeQueryPlan(ConfigPhysicalPlan req) return udfInfo.getUDFTable(); case GetFunctionJar: return udfInfo.getUDFJar((GetUDFJarPlan) req); + case ShowModel: + return modelInfo.showModel((ShowModelPlan) req); + case GetModelInfo: + return modelInfo.getModelInfo((GetModelInfoPlan) req); case GetPipePluginTable: return pipeInfo.getPipePluginInfo().showPipePlugins(); case GetPipePluginJar: @@ -335,6 +358,12 @@ public TSStatus executeNonQueryPlan(ConfigPhysicalPlan physicalPlan) return status; } return partitionInfo.updateDataNode((UpdateDataNodePlan) physicalPlan); + case RegisterAINode: + return nodeInfo.registerAINode((RegisterAINodePlan) physicalPlan); + case UpdateAINodeConfiguration: + return nodeInfo.updateAINode((UpdateAINodePlan) physicalPlan); + case RemoveAINode: + return nodeInfo.removeAINode((RemoveAINodePlan) physicalPlan); case CreateDatabase: status = clusterSchemaInfo.createDatabase((DatabaseSchemaPlan) physicalPlan); if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { @@ -496,10 +525,22 @@ public TSStatus executeNonQueryPlan(ConfigPhysicalPlan physicalPlan) return cqInfo.activeCQ((ActiveCQPlan) physicalPlan); case UPDATE_CQ_LAST_EXEC_TIME: return cqInfo.updateCQLastExecutionTime((UpdateCQLastExecTimePlan) physicalPlan); + case CreateModel: + return modelInfo.createModel((CreateModelPlan) physicalPlan); + case UpdateModelInfo: + return modelInfo.updateModelInfo((UpdateModelInfoPlan) physicalPlan); + case DropModel: + return modelInfo.dropModel(((DropModelPlan) physicalPlan).getModelName()); + case DropModelInNode: + return modelInfo.dropModelInNode(((DropModelInNodePlan) physicalPlan).getNodeId()); case CreatePipePlugin: return pipeInfo.getPipePluginInfo().createPipePlugin((CreatePipePluginPlan) physicalPlan); case DropPipePlugin: return pipeInfo.getPipePluginInfo().dropPipePlugin((DropPipePluginPlan) physicalPlan); + case setSpaceQuota: + return quotaInfo.setSpaceQuota((SetSpaceQuotaPlan) physicalPlan); + case setThrottleQuota: + return quotaInfo.setThrottleQuota((SetThrottleQuotaPlan) physicalPlan); case CreatePipeSinkV1: case DropPipeV1: case DropPipeSinkV1: @@ -509,10 +550,6 @@ public TSStatus executeNonQueryPlan(ConfigPhysicalPlan physicalPlan) case SetPipeStatusV1: case ShowPipeV1: return new TSStatus(TSStatusCode.INCOMPATIBLE_VERSION.getStatusCode()); - case setSpaceQuota: - return quotaInfo.setSpaceQuota((SetSpaceQuotaPlan) physicalPlan); - case setThrottleQuota: - return 
quotaInfo.setThrottleQuota((SetThrottleQuotaPlan) physicalPlan); case PipeEnriched: return executeNonQueryPlan(((PipeEnrichedPlan) physicalPlan).getInnerPlan()); case PipeDeleteTimeSeries: diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/node/NodeInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/node/NodeInfo.java index 4f1c8819d8fe..7dd4af334a72 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/node/NodeInfo.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/node/NodeInfo.java @@ -19,19 +19,26 @@ package org.apache.iotdb.confignode.persistence.node; +import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; +import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration; import org.apache.iotdb.common.rpc.thrift.TSStatus; import org.apache.iotdb.commons.snapshot.SnapshotProcessor; import org.apache.iotdb.confignode.conf.ConfigNodeDescriptor; import org.apache.iotdb.confignode.conf.SystemPropertiesUtils; +import org.apache.iotdb.confignode.consensus.request.read.ainode.GetAINodeConfigurationPlan; import org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan; +import org.apache.iotdb.confignode.consensus.request.write.ainode.RegisterAINodePlan; +import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; +import org.apache.iotdb.confignode.consensus.request.write.ainode.UpdateAINodePlan; import org.apache.iotdb.confignode.consensus.request.write.confignode.ApplyConfigNodePlan; import org.apache.iotdb.confignode.consensus.request.write.confignode.RemoveConfigNodePlan; import org.apache.iotdb.confignode.consensus.request.write.confignode.UpdateVersionInfoPlan; import org.apache.iotdb.confignode.consensus.request.write.datanode.RegisterDataNodePlan; import org.apache.iotdb.confignode.consensus.request.write.datanode.RemoveDataNodePlan; import org.apache.iotdb.confignode.consensus.request.write.datanode.UpdateDataNodePlan; +import org.apache.iotdb.confignode.consensus.response.ainode.AINodeConfigurationResp; import org.apache.iotdb.confignode.consensus.response.datanode.DataNodeConfigurationResp; import org.apache.iotdb.confignode.rpc.thrift.TNodeVersionInfo; import org.apache.iotdb.rpc.TSStatusCode; @@ -62,6 +69,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.ReentrantReadWriteLock; +import static org.apache.iotdb.confignode.conf.ConfigNodeConstant.REMOVE_AINODE_PROCESS; import static org.apache.iotdb.confignode.conf.ConfigNodeConstant.REMOVE_DATANODE_PROCESS; /** @@ -91,6 +99,9 @@ public class NodeInfo implements SnapshotProcessor { private final Map registeredDataNodes; private final ReentrantReadWriteLock dataNodeInfoReadWriteLock; + private final Map registeredAINodes; + private final ReentrantReadWriteLock aiNodeInfoReadWriteLock; + private final Map nodeVersionInfo; private final ReentrantReadWriteLock versionInfoReadWriteLock; @@ -103,6 +114,9 @@ public NodeInfo() { this.dataNodeInfoReadWriteLock = new ReentrantReadWriteLock(); this.registeredDataNodes = new ConcurrentHashMap<>(); + this.aiNodeInfoReadWriteLock = new ReentrantReadWriteLock(); + this.registeredAINodes = new ConcurrentHashMap<>(); + this.nodeVersionInfo = new ConcurrentHashMap<>(); this.versionInfoReadWriteLock = new ReentrantReadWriteLock(); } @@ 
-223,6 +237,28 @@ public DataNodeConfigurationResp getDataNodeConfiguration( return result; } + public AINodeConfigurationResp getAINodeConfiguration( + GetAINodeConfigurationPlan getAINodeConfigurationPlan) { + AINodeConfigurationResp result = new AINodeConfigurationResp(); + result.setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); + + int aiNodeId = getAINodeConfigurationPlan.getAiNodeId(); + aiNodeInfoReadWriteLock.readLock().lock(); + try { + if (aiNodeId == -1) { + result.setAiNodeConfigurationMap(new HashMap<>(registeredAINodes)); + } else { + result.setAiNodeConfigurationMap( + registeredAINodes.get(aiNodeId) == null + ? new HashMap<>(0) + : Collections.singletonMap(aiNodeId, registeredAINodes.get(aiNodeId))); + } + } finally { + aiNodeInfoReadWriteLock.readLock().unlock(); + } + return result; + } + /** Return the number of registered Nodes. */ public int getRegisteredNodeCount() { int result; @@ -321,6 +357,47 @@ public List getRegisteredDataNodes(List dataNod return result; } + public List getRegisteredAINodes() { + List result; + aiNodeInfoReadWriteLock.readLock().lock(); + try { + result = new ArrayList<>(registeredAINodes.values()); + } finally { + aiNodeInfoReadWriteLock.readLock().unlock(); + } + return result; + } + + public TAINodeConfiguration getRegisteredAINode(int aiNodeId) { + aiNodeInfoReadWriteLock.readLock().lock(); + try { + return registeredAINodes.getOrDefault(aiNodeId, new TAINodeConfiguration()).deepCopy(); + } finally { + aiNodeInfoReadWriteLock.readLock().unlock(); + } + } + + /** Return the number of registered DataNodes. */ + public int getRegisteredAINodeCount() { + int result; + aiNodeInfoReadWriteLock.readLock().lock(); + try { + result = registeredAINodes.size(); + } finally { + aiNodeInfoReadWriteLock.readLock().unlock(); + } + return result; + } + + public boolean containsAINode(int aiNodeId) { + aiNodeInfoReadWriteLock.readLock().lock(); + try { + return registeredAINodes.containsKey(aiNodeId); + } finally { + aiNodeInfoReadWriteLock.readLock().unlock(); + } + } + /** * Update ConfigNodeList both in memory and confignode-system{@literal .}properties file. * @@ -391,6 +468,77 @@ public TSStatus removeConfigNode(RemoveConfigNodePlan removeConfigNodePlan) { return status; } + /** + * Persist AINode info. + * + * @param registerAINodePlan RegisterAINodePlan + * @return {@link TSStatusCode#SUCCESS_STATUS} + */ + public TSStatus registerAINode(RegisterAINodePlan registerAINodePlan) { + TSStatus result; + TAINodeConfiguration info = registerAINodePlan.getAINodeConfiguration(); + aiNodeInfoReadWriteLock.writeLock().lock(); + try { + synchronized (nextNodeId) { + if (nextNodeId.get() < info.getLocation().getAiNodeId()) { + nextNodeId.set(info.getLocation().getAiNodeId()); + } + } + registeredAINodes.put(info.getLocation().getAiNodeId(), info); + result = new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); + } finally { + aiNodeInfoReadWriteLock.writeLock().unlock(); + } + return result; + } + + /** + * Update the specified AINode‘s location. + * + * @param updateAINodePlan UpdateAINodePlan + * @return {@link TSStatusCode#SUCCESS_STATUS} if update AINode info successfully. 
+ */ + public TSStatus updateAINode(UpdateAINodePlan updateAINodePlan) { + dataNodeInfoReadWriteLock.writeLock().lock(); + try { + TAINodeConfiguration newConfiguration = updateAINodePlan.getAINodeConfiguration(); + registeredAINodes.replace(newConfiguration.getLocation().getAiNodeId(), newConfiguration); + } finally { + dataNodeInfoReadWriteLock.writeLock().unlock(); + } + return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); + } + + /** + * Persist Information about remove dataNode. + * + * @param req RemoveDataNodePlan + * @return {@link TSStatus} + */ + public TSStatus removeAINode(RemoveAINodePlan req) { + LOGGER.info( + "{}, There are {} AI nodes in cluster before executed RemoveAINodePlan", + REMOVE_AINODE_PROCESS, + registeredAINodes.size()); + + aiNodeInfoReadWriteLock.writeLock().lock(); + versionInfoReadWriteLock.writeLock().lock(); + TAINodeLocation removedAINode = req.getAINodeLocation(); + try { + registeredAINodes.remove(removedAINode.getAiNodeId()); + nodeVersionInfo.remove(removedAINode.getAiNodeId()); + LOGGER.info("Removed the AINode {} from cluster", removedAINode); + } finally { + versionInfoReadWriteLock.writeLock().unlock(); + aiNodeInfoReadWriteLock.writeLock().unlock(); + } + LOGGER.info( + "{}, There are {} AI nodes in cluster after executed RemoveAINodePlan", + REMOVE_AINODE_PROCESS, + registeredAINodes.size()); + return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); + } + /** * Update the specified Node‘s versionInfo. * @@ -482,6 +630,7 @@ public boolean processTakeSnapshot(File snapshotDir) throws IOException, TExcept File tmpFile = new File(snapshotFile.getAbsolutePath() + "-" + UUID.randomUUID()); configNodeInfoReadWriteLock.readLock().lock(); dataNodeInfoReadWriteLock.readLock().lock(); + aiNodeInfoReadWriteLock.readLock().lock(); versionInfoReadWriteLock.readLock().lock(); try (FileOutputStream fileOutputStream = new FileOutputStream(tmpFile); TIOStreamTransport tioStreamTransport = new TIOStreamTransport(fileOutputStream)) { @@ -494,6 +643,8 @@ public boolean processTakeSnapshot(File snapshotDir) throws IOException, TExcept serializeRegisteredDataNode(fileOutputStream, protocol); + serializeRegisteredAINode(fileOutputStream, protocol); + serializeVersionInfo(fileOutputStream); tioStreamTransport.flush(); @@ -505,6 +656,7 @@ public boolean processTakeSnapshot(File snapshotDir) throws IOException, TExcept return tmpFile.renameTo(snapshotFile); } finally { versionInfoReadWriteLock.readLock().unlock(); + aiNodeInfoReadWriteLock.readLock().unlock(); dataNodeInfoReadWriteLock.readLock().unlock(); configNodeInfoReadWriteLock.readLock().unlock(); for (int retry = 0; retry < 5; retry++) { @@ -536,6 +688,15 @@ private void serializeRegisteredDataNode(OutputStream outputStream, TProtocol pr } } + private void serializeRegisteredAINode(OutputStream outputStream, TProtocol protocol) + throws IOException, TException { + ReadWriteIOUtils.write(registeredAINodes.size(), outputStream); + for (Entry entry : registeredAINodes.entrySet()) { + ReadWriteIOUtils.write(entry.getKey(), outputStream); + entry.getValue().write(protocol); + } + } + private void serializeVersionInfo(OutputStream outputStream) throws IOException { ReadWriteIOUtils.write(nodeVersionInfo.size(), outputStream); for (Entry entry : nodeVersionInfo.entrySet()) { @@ -558,6 +719,7 @@ public void processLoadSnapshot(File snapshotDir) throws IOException, TException configNodeInfoReadWriteLock.writeLock().lock(); dataNodeInfoReadWriteLock.writeLock().lock(); + 
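// Restoring the AINode registry mutates shared NodeInfo state, so its write lock is taken together with the ConfigNode/DataNode/version locks for the whole snapshot load. +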
aiNodeInfoReadWriteLock.writeLock().lock(); versionInfoReadWriteLock.writeLock().lock(); try (FileInputStream fileInputStream = new FileInputStream(snapshotFile); @@ -572,10 +734,13 @@ public void processLoadSnapshot(File snapshotDir) throws IOException, TException deserializeRegisteredDataNode(fileInputStream, protocol); + deserializeRegisteredAINode(fileInputStream, protocol); + deserializeBuildInfo(fileInputStream); } finally { versionInfoReadWriteLock.writeLock().unlock(); + aiNodeInfoReadWriteLock.writeLock().unlock(); dataNodeInfoReadWriteLock.writeLock().unlock(); configNodeInfoReadWriteLock.writeLock().unlock(); } @@ -605,6 +770,18 @@ private void deserializeRegisteredDataNode(InputStream inputStream, TProtocol pr } } + private void deserializeRegisteredAINode(InputStream inputStream, TProtocol protocol) + throws IOException, TException { + int size = ReadWriteIOUtils.readInt(inputStream); + while (size > 0) { + int aiNodeId = ReadWriteIOUtils.readInt(inputStream); + TAINodeConfiguration aiNodeInfo = new TAINodeConfiguration(); + aiNodeInfo.read(protocol); + registeredAINodes.put(aiNodeId, aiNodeInfo); + size--; + } + } + private void deserializeBuildInfo(InputStream inputStream) throws IOException { // old version may not have build info, // thus we need to check inputStream before deserialize. @@ -643,6 +820,7 @@ public boolean equals(Object o) { return registeredConfigNodes.equals(nodeInfo.registeredConfigNodes) && nextNodeId.get() == nodeInfo.nextNodeId.get() && registeredDataNodes.equals(nodeInfo.registeredDataNodes) + && registeredAINodes.equals(nodeInfo.registeredAINodes) && nodeVersionInfo.equals(nodeInfo.nodeVersionInfo); } diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java new file mode 100644 index 000000000000..8282608466d6 --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.confignode.procedure.impl.model; + +import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.client.ainode.AINodeClient; +import org.apache.iotdb.commons.client.ainode.AINodeClientManager; +import org.apache.iotdb.commons.exception.ainode.LoadModelException; +import org.apache.iotdb.commons.model.ModelInformation; +import org.apache.iotdb.commons.model.ModelStatus; +import org.apache.iotdb.commons.model.exception.ModelManagementException; +import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; +import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; +import org.apache.iotdb.confignode.manager.ConfigManager; +import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; +import org.apache.iotdb.confignode.procedure.exception.ProcedureException; +import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure; +import org.apache.iotdb.confignode.procedure.state.model.CreateModelState; +import org.apache.iotdb.confignode.procedure.store.ProcedureType; +import org.apache.iotdb.consensus.exception.ConsensusException; +import org.apache.iotdb.rpc.TSStatusCode; + +import org.apache.tsfile.utils.ReadWriteIOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +public class CreateModelProcedure extends AbstractNodeProcedure { + + private static final Logger LOGGER = LoggerFactory.getLogger(CreateModelProcedure.class); + private static final int RETRY_THRESHOLD = 0; + + private String modelName; + + private String uri; + + private ModelInformation modelInformation = null; + + private List aiNodeIds; + + private String loadErrorMsg = ""; + + public CreateModelProcedure() { + super(); + } + + public CreateModelProcedure(String modelName, String uri) { + super(); + this.modelName = modelName; + this.uri = uri; + this.aiNodeIds = new ArrayList<>(); + } + + @Override + protected Flow executeFromState(ConfigNodeProcedureEnv env, CreateModelState state) { + if (modelName == null || uri == null) { + return Flow.NO_MORE_STATE; + } + try { + switch (state) { + case LOADING: + initModel(env); + loadModel(env); + setNextState(CreateModelState.ACTIVE); + break; + case ACTIVE: + modelInformation.updateStatus(ModelStatus.ACTIVE); + updateModel(env); + return Flow.NO_MORE_STATE; + default: + throw new UnsupportedOperationException( + String.format("Unknown state during executing createModelProcedure, %s", state)); + } + } catch (Exception e) { + if (isRollbackSupported(state)) { + LOGGER.error("Fail in CreateModelProcedure", e); + setFailure(new ProcedureException(e.getMessage())); + } else { + LOGGER.error( + "Retrievable error trying to create model [{}], state [{}]", modelName, state, e); + if (getCycles() > RETRY_THRESHOLD) { + modelInformation = new ModelInformation(modelName, ModelStatus.UNAVAILABLE); + modelInformation.setAttribute(loadErrorMsg); + updateModel(env); + setFailure( + new ProcedureException( + String.format("Fail to create model [%s] at STATE [%s]", modelName, state))); + } + } + } + return Flow.HAS_MORE_STATE; + } + + private void initModel(ConfigNodeProcedureEnv env) throws ConsensusException { + LOGGER.info("Start to add model [{}]", modelName); + + ConfigManager configManager = env.getConfigManager(); + 
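// Register the model in the ConfigNode's ModelTable through the consensus layer before any AINode is asked to load it; a non-SUCCESS status is surfaced as a ModelManagementException. +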
TSStatus response = configManager.getConsensusManager().write(new CreateModelPlan(modelName)); + if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + throw new ModelManagementException( + String.format( + "Failed to add model [%s] in ModelTable on Config Nodes: %s", + modelName, response.getMessage())); + } + } + + private void checkModelInformationEquals(ModelInformation receiveModelInfo) { + if (modelInformation == null) { + modelInformation = receiveModelInfo; + } else { + if (!modelInformation.equals(receiveModelInfo)) { + throw new ModelManagementException( + String.format( + "Failed to load model [%s] on AI Nodes, model information is not equal in different nodes", + modelName)); + } + } + } + + private void loadModel(ConfigNodeProcedureEnv env) { + for (TAINodeConfiguration curNodeConfig : + env.getConfigManager().getNodeManager().getRegisteredAINodes()) { + try (AINodeClient client = + AINodeClientManager.getInstance() + .borrowClient(curNodeConfig.getLocation().getInternalEndPoint())) { + ModelInformation resp = client.registerModel(modelName, uri); + checkModelInformationEquals(resp); + aiNodeIds.add(curNodeConfig.getLocation().aiNodeId); + } catch (LoadModelException e) { + LOGGER.warn(e.getMessage()); + loadErrorMsg = e.getMessage(); + } catch (Exception e) { + LOGGER.warn( + "Failed to load model on AINode {} from ConfigNode", + curNodeConfig.getLocation().getInternalEndPoint()); + loadErrorMsg = e.getMessage(); + } + } + + if (aiNodeIds.isEmpty()) { + throw new ModelManagementException( + String.format("CREATE MODEL [%s] failed on all AINodes:[%s]", modelName, loadErrorMsg)); + } + } + + private void updateModel(ConfigNodeProcedureEnv env) { + LOGGER.info("Start to update model [{}]", modelName); + + ConfigManager configManager = env.getConfigManager(); + try { + TSStatus response = + configManager + .getConsensusManager() + .write(new UpdateModelInfoPlan(modelName, modelInformation, aiNodeIds)); + if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + throw new ModelManagementException( + String.format( + "Failed to update model [%s] in ModelTable on Config Nodes: %s", + modelName, response.getMessage())); + } + } catch (Exception e) { + throw new ModelManagementException( + String.format( + "Failed to update model [%s] in ModelTable on Config Nodes: %s", + modelName, e.getMessage())); + } + } + + @Override + protected void rollbackState(ConfigNodeProcedureEnv env, CreateModelState state) + throws IOException, InterruptedException, ProcedureException { + // do nothing + } + + @Override + protected boolean isRollbackSupported(CreateModelState state) { + return false; + } + + @Override + protected CreateModelState getState(int stateId) { + return CreateModelState.values()[stateId]; + } + + @Override + protected int getStateId(CreateModelState createModelState) { + return createModelState.ordinal(); + } + + @Override + protected CreateModelState getInitialState() { + return CreateModelState.LOADING; + } + + @Override + public void serialize(DataOutputStream stream) throws IOException { + stream.writeShort(ProcedureType.CREATE_MODEL_PROCEDURE.getTypeCode()); + super.serialize(stream); + ReadWriteIOUtils.write(modelName, stream); + ReadWriteIOUtils.write(uri, stream); + } + + @Override + public void deserialize(ByteBuffer byteBuffer) { + super.deserialize(byteBuffer); + modelName = ReadWriteIOUtils.readString(byteBuffer); + uri = ReadWriteIOUtils.readString(byteBuffer); + } + + @Override + public boolean equals(Object that) { + if (that 
instanceof CreateModelProcedure) { + CreateModelProcedure thatProc = (CreateModelProcedure) that; + return thatProc.getProcId() == this.getProcId() + && thatProc.getState() == this.getState() + && Objects.equals(thatProc.modelName, this.modelName) + && Objects.equals(thatProc.uri, this.uri); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(getProcId(), getState(), modelName, uri); + } +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java new file mode 100644 index 000000000000..5a8f83254031 --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.confignode.procedure.impl.model; + +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.client.ainode.AINodeClient; +import org.apache.iotdb.commons.client.ainode.AINodeClientManager; +import org.apache.iotdb.commons.model.exception.ModelManagementException; +import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; +import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; +import org.apache.iotdb.confignode.procedure.exception.ProcedureException; +import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure; +import org.apache.iotdb.confignode.procedure.state.model.DropModelState; +import org.apache.iotdb.confignode.procedure.store.ProcedureType; +import org.apache.iotdb.rpc.TSStatusCode; + +import org.apache.thrift.TException; +import org.apache.tsfile.utils.ReadWriteIOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Objects; + +import static org.apache.iotdb.confignode.procedure.state.model.DropModelState.CONFIG_NODE_DROPPED; + +public class DropModelProcedure extends AbstractNodeProcedure { + + private static final Logger LOGGER = LoggerFactory.getLogger(DropModelProcedure.class); + private static final int RETRY_THRESHOLD = 1; + + private String modelName; + + public DropModelProcedure() { + super(); + } + + public DropModelProcedure(String modelName) { + super(); + this.modelName = modelName; + } + + @Override + protected Flow executeFromState(ConfigNodeProcedureEnv env, DropModelState state) { + if (modelName == null) { + return Flow.NO_MORE_STATE; + } + try { + switch (state) { + case AI_NODE_DROPPED: + LOGGER.info("Start to drop model [{}] on AI Nodes", modelName); + 
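// Deletion on the AINodes is best effort: per-node failures are only logged as warnings, and the procedure still advances to drop the model metadata on the ConfigNode. +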
dropModelOnAINode(env); + setNextState(CONFIG_NODE_DROPPED); + break; + case CONFIG_NODE_DROPPED: + dropModelOnConfigNode(env); + return Flow.NO_MORE_STATE; + default: + throw new UnsupportedOperationException( + String.format("Unknown state during executing dropModelProcedure, %s", state)); + } + } catch (Exception e) { + if (isRollbackSupported(state)) { + LOGGER.error("Fail in DropModelProcedure", e); + setFailure(new ProcedureException(e.getMessage())); + } else { + LOGGER.error( + "Retrievable error trying to drop model [{}], state [{}]", modelName, state, e); + if (getCycles() > RETRY_THRESHOLD) { + setFailure( + new ProcedureException( + String.format( + "Fail to drop model [%s] at STATE [%s], %s", + modelName, state, e.getMessage()))); + } + } + } + return Flow.HAS_MORE_STATE; + } + + private void dropModelOnAINode(ConfigNodeProcedureEnv env) { + LOGGER.info("Start to drop model file [{}] on AI Node", modelName); + + List nodeIds = + env.getConfigManager().getModelManager().getModelDistributions(modelName); + for (Integer nodeId : nodeIds) { + try (AINodeClient client = + AINodeClientManager.getInstance() + .borrowClient( + env.getConfigManager() + .getNodeManager() + .getRegisteredAINode(nodeId) + .getLocation() + .getInternalEndPoint())) { + TSStatus status = client.deleteModel(modelName); + if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + LOGGER.warn( + "Failed to drop model [{}] on AINode [{}], status: {}", + modelName, + nodeId, + status.getMessage()); + } + } catch (Exception e) { + LOGGER.warn( + "Failed to drop model [{}] on AINode [{}], status: {}", + modelName, + nodeId, + e.getMessage()); + } + } + } + + private void dropModelOnConfigNode(ConfigNodeProcedureEnv env) { + try { + TSStatus response = + env.getConfigManager().getConsensusManager().write(new DropModelPlan(modelName)); + if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + throw new TException(response.getMessage()); + } + } catch (Exception e) { + throw new ModelManagementException( + String.format( + "Fail to start training model [%s] on AI Node: %s", modelName, e.getMessage())); + } + } + + @Override + protected void rollbackState(ConfigNodeProcedureEnv env, DropModelState state) + throws IOException, InterruptedException, ProcedureException { + // no need to rollback + } + + @Override + protected DropModelState getState(int stateId) { + return DropModelState.values()[stateId]; + } + + @Override + protected int getStateId(DropModelState dropModelState) { + return dropModelState.ordinal(); + } + + @Override + protected DropModelState getInitialState() { + return DropModelState.AI_NODE_DROPPED; + } + + @Override + public void serialize(DataOutputStream stream) throws IOException { + stream.writeShort(ProcedureType.DROP_MODEL_PROCEDURE.getTypeCode()); + super.serialize(stream); + ReadWriteIOUtils.write(modelName, stream); + } + + @Override + public void deserialize(ByteBuffer byteBuffer) { + super.deserialize(byteBuffer); + modelName = ReadWriteIOUtils.readString(byteBuffer); + } + + @Override + public boolean equals(Object that) { + if (that instanceof DropModelProcedure) { + DropModelProcedure thatProc = (DropModelProcedure) that; + return thatProc.getProcId() == this.getProcId() + && thatProc.getState() == this.getState() + && (thatProc.modelName).equals(this.modelName); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(getProcId(), getState(), modelName); + } +} diff --git 
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java new file mode 100644 index 000000000000..03ea4b13a279 --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.confignode.procedure.impl.node; + +import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.utils.ThriftCommonsSerDeUtils; +import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; +import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan; +import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; +import org.apache.iotdb.confignode.procedure.exception.ProcedureException; +import org.apache.iotdb.confignode.procedure.exception.ProcedureSuspendedException; +import org.apache.iotdb.confignode.procedure.exception.ProcedureYieldException; +import org.apache.iotdb.confignode.procedure.state.RemoveAINodeState; +import org.apache.iotdb.confignode.procedure.store.ProcedureType; +import org.apache.iotdb.rpc.TSStatusCode; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Objects; + +public class RemoveAINodeProcedure extends AbstractNodeProcedure { + + private static final Logger LOGGER = LoggerFactory.getLogger(RemoveAINodeProcedure.class); + private static final int RETRY_THRESHOLD = 5; + + private TAINodeLocation removedAINode; + + public RemoveAINodeProcedure(TAINodeLocation removedAINode) { + super(); + this.removedAINode = removedAINode; + } + + public RemoveAINodeProcedure() { + super(); + } + + @Override + protected Flow executeFromState(ConfigNodeProcedureEnv env, RemoveAINodeState state) + throws ProcedureSuspendedException, ProcedureYieldException, InterruptedException { + if (removedAINode == null) { + return Flow.NO_MORE_STATE; + } + + try { + switch (state) { + case MODEL_DELETE: + env.getConfigManager() + .getConsensusManager() + .write(new DropModelInNodePlan(removedAINode.aiNodeId)); + // Cause the AINode is removed, so we don't need to remove the model file. 
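+ // In other words, only the ConfigNode-side model metadata for this AINode is cleaned up here; the model files on disk go away with the removed node itself.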
+ setNextState(RemoveAINodeState.NODE_REMOVE); + break; + case NODE_REMOVE: + TSStatus response = + env.getConfigManager() + .getConsensusManager() + .write(new RemoveAINodePlan(removedAINode)); + + if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + throw new ProcedureException( + String.format( + "Fail to remove [%s] AINode on Config Nodes [%s]", + removedAINode, response.getMessage())); + } + return Flow.NO_MORE_STATE; + default: + throw new UnsupportedOperationException( + String.format("Unknown state during executing removeAINodeProcedure, %s", state)); + } + } catch (Exception e) { + if (isRollbackSupported(state)) { + setFailure(new ProcedureException(e.getMessage())); + } else { + LOGGER.error( + "Retrievable error trying to remove AINode [{}], state [{}]", removedAINode, state, e); + if (getCycles() > RETRY_THRESHOLD) { + setFailure( + new ProcedureException( + String.format( + "Fail to remove AINode [%s] at STATE [%s], %s", + removedAINode, state, e.getMessage()))); + } + } + } + return Flow.HAS_MORE_STATE; + } + + @Override + protected void rollbackState( + ConfigNodeProcedureEnv configNodeProcedureEnv, RemoveAINodeState removeAINodeState) + throws IOException, InterruptedException, ProcedureException { + // no need to rollback + } + + @Override + protected RemoveAINodeState getState(int stateId) { + return RemoveAINodeState.values()[stateId]; + } + + @Override + protected int getStateId(RemoveAINodeState removeAINodeState) { + return removeAINodeState.ordinal(); + } + + @Override + protected RemoveAINodeState getInitialState() { + return RemoveAINodeState.MODEL_DELETE; + } + + @Override + public void serialize(DataOutputStream stream) throws IOException { + stream.writeShort(ProcedureType.REMOVE_AI_NODE_PROCEDURE.getTypeCode()); + super.serialize(stream); + ThriftCommonsSerDeUtils.serializeTAINodeLocation(removedAINode, stream); + } + + @Override + public void deserialize(ByteBuffer byteBuffer) { + super.deserialize(byteBuffer); + removedAINode = ThriftCommonsSerDeUtils.deserializeTAINodeLocation(byteBuffer); + } + + @Override + public boolean equals(Object that) { + if (that instanceof RemoveAINodeProcedure) { + RemoveAINodeProcedure thatProc = (RemoveAINodeProcedure) that; + return thatProc.getProcId() == this.getProcId() + && thatProc.getState() == this.getState() + && (thatProc.removedAINode).equals(this.removedAINode); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(getProcId(), getState(), removedAINode); + } +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java new file mode 100644 index 000000000000..eecb5a4d9d98 --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.confignode.procedure.state; + +public enum RemoveAINodeState { + MODEL_DELETE, + NODE_REMOVE +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/CreateModelState.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/CreateModelState.java new file mode 100644 index 000000000000..9bf9d347afb2 --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/CreateModelState.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.confignode.procedure.state.model; + +public enum CreateModelState { + LOADING, + ACTIVE +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java new file mode 100644 index 000000000000..a06c19cc7046 --- /dev/null +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.confignode.procedure.state.model; + +public enum DropModelState { + AI_NODE_DROPPED, + CONFIG_NODE_DROPPED +} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java index 7297e2ead0de..33cdfabd85e5 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java @@ -22,7 +22,10 @@ import org.apache.iotdb.commons.exception.runtime.ThriftSerDeException; import org.apache.iotdb.confignode.procedure.Procedure; import org.apache.iotdb.confignode.procedure.impl.cq.CreateCQProcedure; +import org.apache.iotdb.confignode.procedure.impl.model.CreateModelProcedure; +import org.apache.iotdb.confignode.procedure.impl.model.DropModelProcedure; import org.apache.iotdb.confignode.procedure.impl.node.AddConfigNodeProcedure; +import org.apache.iotdb.confignode.procedure.impl.node.RemoveAINodeProcedure; import org.apache.iotdb.confignode.procedure.impl.node.RemoveConfigNodeProcedure; import org.apache.iotdb.confignode.procedure.impl.node.RemoveDataNodeProcedure; import org.apache.iotdb.confignode.procedure.impl.pipe.plugin.CreatePipePluginProcedure; @@ -188,6 +191,12 @@ public Procedure create(ByteBuffer buffer) throws IOException { case DROP_PIPE_PLUGIN_PROCEDURE: procedure = new DropPipePluginProcedure(); break; + case CREATE_MODEL_PROCEDURE: + procedure = new CreateModelProcedure(); + break; + case DROP_MODEL_PROCEDURE: + procedure = new DropModelProcedure(); + break; case AUTH_OPERATE_PROCEDURE: procedure = new AuthOperationProcedure(false); break; @@ -221,6 +230,9 @@ public Procedure create(ByteBuffer buffer) throws IOException { case PIPE_ENRICHED_AUTH_OPERATE_PROCEDURE: procedure = new AuthOperationProcedure(true); break; + case REMOVE_AI_NODE_PROCEDURE: + procedure = new RemoveAINodeProcedure(); + break; case PIPE_ENRICHED_SET_TTL_PROCEDURE: procedure = new SetTTLProcedure(true); break; @@ -289,6 +301,8 @@ public static ProcedureType getProcedureType(Procedure procedure) { return ProcedureType.REMOVE_CONFIG_NODE_PROCEDURE; } else if (procedure instanceof RemoveDataNodeProcedure) { return ProcedureType.REMOVE_DATA_NODE_PROCEDURE; + } else if (procedure instanceof RemoveAINodeProcedure) { + return ProcedureType.REMOVE_AI_NODE_PROCEDURE; } else if (procedure instanceof RegionMigrateProcedure) { return ProcedureType.REGION_MIGRATE_PROCEDURE; } else if (procedure instanceof AddRegionPeerProcedure) { @@ -323,6 +337,10 @@ public static ProcedureType getProcedureType(Procedure procedure) { return ProcedureType.CREATE_PIPE_PLUGIN_PROCEDURE; } else if (procedure instanceof DropPipePluginProcedure) { return ProcedureType.DROP_PIPE_PLUGIN_PROCEDURE; + } else if (procedure instanceof CreateModelProcedure) { + return ProcedureType.CREATE_MODEL_PROCEDURE; + } else if (procedure instanceof DropModelProcedure) { + return ProcedureType.DROP_MODEL_PROCEDURE; } else if (procedure instanceof CreatePipeProcedureV2) { return ProcedureType.CREATE_PIPE_PROCEDURE_V2; } else if (procedure instanceof StartPipeProcedureV2) { diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java index 683365b1dfc5..093be00baf04 100644 --- 
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java @@ -62,6 +62,11 @@ public enum ProcedureType { UNSET_TEMPLATE_PROCEDURE((short) 701), SET_TEMPLATE_PROCEDURE((short) 702), + /** AI Model */ + CREATE_MODEL_PROCEDURE((short) 800), + DROP_MODEL_PROCEDURE((short) 801), + REMOVE_AI_NODE_PROCEDURE((short) 802), + // ProcedureId 800-899 is used by IoTDB-Ml /** Pipe Plugin */ diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java index f1a09ad088a0..d7fc4527a0d9 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java @@ -46,6 +46,7 @@ import org.apache.iotdb.confignode.conf.SystemPropertiesUtils; import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; import org.apache.iotdb.confignode.consensus.request.auth.AuthorPlan; +import org.apache.iotdb.confignode.consensus.request.read.ainode.GetAINodeConfigurationPlan; import org.apache.iotdb.confignode.consensus.request.read.database.CountDatabasePlan; import org.apache.iotdb.confignode.consensus.request.read.database.GetDatabasePlan; import org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan; @@ -53,6 +54,7 @@ import org.apache.iotdb.confignode.consensus.request.read.partition.GetOrCreateDataPartitionPlan; import org.apache.iotdb.confignode.consensus.request.read.region.GetRegionInfoListPlan; import org.apache.iotdb.confignode.consensus.request.read.ttl.ShowTTLPlan; +import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; import org.apache.iotdb.confignode.consensus.request.write.confignode.RemoveConfigNodePlan; import org.apache.iotdb.confignode.consensus.request.write.database.DatabaseSchemaPlan; import org.apache.iotdb.confignode.consensus.request.write.database.SetDataReplicationFactorPlan; @@ -60,6 +62,8 @@ import org.apache.iotdb.confignode.consensus.request.write.database.SetTTLPlan; import org.apache.iotdb.confignode.consensus.request.write.database.SetTimePartitionIntervalPlan; import org.apache.iotdb.confignode.consensus.request.write.datanode.RemoveDataNodePlan; +import org.apache.iotdb.confignode.consensus.response.ainode.AINodeConfigurationResp; +import org.apache.iotdb.confignode.consensus.response.ainode.AINodeRegisterResp; import org.apache.iotdb.confignode.consensus.response.auth.PermissionInfoResp; import org.apache.iotdb.confignode.consensus.response.database.CountDatabaseResp; import org.apache.iotdb.confignode.consensus.response.database.DatabaseSchemaResp; @@ -72,6 +76,12 @@ import org.apache.iotdb.confignode.manager.ConfigManager; import org.apache.iotdb.confignode.manager.consensus.ConsensusManager; import org.apache.iotdb.confignode.rpc.thrift.IConfigNodeRPCService; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeConfigurationResp; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterResp; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRemoveReq; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq; +import 
org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp; import org.apache.iotdb.confignode.rpc.thrift.TAddConsensusGroupReq; import org.apache.iotdb.confignode.rpc.thrift.TAlterLogicalViewReq; import org.apache.iotdb.confignode.rpc.thrift.TAlterPipeReq; @@ -91,6 +101,7 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; +import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -114,6 +125,7 @@ import org.apache.iotdb.confignode.rpc.thrift.TDeleteTimeSeriesReq; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; +import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropTopicReq; @@ -128,6 +140,8 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -155,11 +169,14 @@ import org.apache.iotdb.confignode.rpc.thrift.TSetSchemaReplicationFactorReq; import org.apache.iotdb.confignode.rpc.thrift.TSetSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TSetTimePartitionIntervalReq; +import org.apache.iotdb.confignode.rpc.thrift.TShowAINodesResp; import org.apache.iotdb.confignode.rpc.thrift.TShowCQResp; import org.apache.iotdb.confignode.rpc.thrift.TShowClusterResp; import org.apache.iotdb.confignode.rpc.thrift.TShowConfigNodesResp; import org.apache.iotdb.confignode.rpc.thrift.TShowDataNodesResp; import org.apache.iotdb.confignode.rpc.thrift.TShowDatabaseResp; +import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq; +import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp; import org.apache.iotdb.confignode.rpc.thrift.TShowPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TShowPipeResp; import org.apache.iotdb.confignode.rpc.thrift.TShowRegionReq; @@ -279,6 +296,47 @@ public TDataNodeRestartResp restartDataNode(TDataNodeRestartReq req) { return resp; } + @Override + public TAINodeRegisterResp registerAINode(TAINodeRegisterReq req) { + TAINodeRegisterResp resp = + ((AINodeRegisterResp) configManager.registerAINode(req)).convertToAINodeRegisterResp(); + LOGGER.info("Execute RegisterAINodeRequest {} with result {}", req, resp); + return resp; + } + + @Override + public TAINodeRestartResp restartAINode(TAINodeRestartReq req) { + TAINodeRestartResp resp = configManager.restartAINode(req); + LOGGER.info("Execute RestartAINodeRequest {} with result {}", req, resp); + return resp; + } + + @Override + public TSStatus removeAINode(TAINodeRemoveReq req) { + LOGGER.info("ConfigNode RPC Service start to remove AINode, req: {}", req); + RemoveAINodePlan removeAINodePlan = new 
RemoveAINodePlan(req.getAiNodeLocation()); + TSStatus status = configManager.removeAINode(removeAINodePlan); + LOGGER.info( + "ConfigNode RPC Service finished to remove AINode, req: {}, result: {}", req, status); + return status; + } + + @Override + public TShowAINodesResp showAINodes() throws TException { + return configManager.showAINodes(); + } + + @Override + public TAINodeConfigurationResp getAINodeConfiguration(int aiNodeId) throws TException { + GetAINodeConfigurationPlan getAINodeConfigurationPlan = + new GetAINodeConfigurationPlan(aiNodeId); + AINodeConfigurationResp aiNodeConfigurationResp = + (AINodeConfigurationResp) configManager.getAINodeConfiguration(getAINodeConfigurationPlan); + TAINodeConfigurationResp resp = new TAINodeConfigurationResp(); + aiNodeConfigurationResp.convertToRpcAINodeLocationResp(resp); + return resp; + } + @Override public TDataNodeRemoveResp removeDataNode(TDataNodeRemoveReq req) { LOGGER.info("ConfigNode RPC Service start to remove DataNode, req: {}", req); @@ -1165,6 +1223,26 @@ public TShowCQResp showCQ() { return configManager.showCQ(); } + @Override + public TSStatus createModel(TCreateModelReq req) { + return configManager.createModel(req); + } + + @Override + public TSStatus dropModel(TDropModelReq req) { + return configManager.dropModel(req); + } + + @Override + public TShowModelResp showModel(TShowModelReq req) { + return configManager.showModel(req); + } + + @Override + public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { + return configManager.getModelInfo(req); + } + @Override public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) throws TException { return configManager.setSpaceQuota(req); diff --git a/iotdb-core/datanode/pom.xml b/iotdb-core/datanode/pom.xml index 3c13cbc4fe7f..fa90a07cac88 100644 --- a/iotdb-core/datanode/pom.xml +++ b/iotdb-core/datanode/pom.xml @@ -114,6 +114,11 @@ iotdb-thrift-confignode 1.3.3-SNAPSHOT + + org.apache.iotdb + iotdb-thrift-ainode + 1.3.3-SNAPSHOT + org.apache.iotdb pipe-api diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java index 1b291eb7c335..9f6355b99ebf 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java @@ -432,6 +432,9 @@ public class IoTDBConfig { /** Compact the unsequence files into the overlapped sequence files */ private volatile boolean enableCrossSpaceCompaction = true; + /** Enable the service for AINode */ + private boolean enableAINodeService = false; + /** The buffer for sort operation */ private long sortBufferSize = 1024 * 1024L; @@ -910,6 +913,9 @@ public class IoTDBConfig { /** Internal port for coordinator */ private int internalPort = 10730; + /** Port for AINode */ + private int aiNodePort = 10780; + /** Internal port for dataRegion consensus protocol */ private int dataRegionConsensusPort = 10760; @@ -2848,6 +2854,14 @@ public void setEnableCrossSpaceCompaction(boolean enableCrossSpaceCompaction) { this.enableCrossSpaceCompaction = enableCrossSpaceCompaction; } + public boolean isEnableAINodeService() { + return enableAINodeService; + } + + public void setEnableAINodeService(boolean enableAINodeService) { + this.enableAINodeService = enableAINodeService; + } + public InnerSequenceCompactionSelector getInnerSequenceCompactionSelector() { return innerSequenceCompactionSelector; } @@ -3126,6 +3140,14 @@ public void setInternalPort(int 
internalPort) { this.internalPort = internalPort; } + public int getAINodePort() { + return aiNodePort; + } + + public void setAINodePort(int aiNodePort) { + this.aiNodePort = aiNodePort; + } + public int getDataRegionConsensusPort() { return dataRegionConsensusPort; } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java index d8da4478ac30..5eaa88e3ee5d 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java @@ -287,6 +287,19 @@ public void loadProperties(Properties properties) throws BadNodeUrlException, IO .getProperty(IoTDBConstant.DN_RPC_PORT, Integer.toString(conf.getRpcPort())) .trim())); + conf.setEnableAINodeService( + Boolean.parseBoolean( + properties + .getProperty( + "enable_ainode_rpc_service", Boolean.toString(conf.isEnableAINodeService())) + .trim())); + + conf.setAINodePort( + Integer.parseInt( + properties + .getProperty("ainode_rpc_port", Integer.toString(conf.getAINodePort())) + .trim())); + conf.setBufferedArraysMemoryProportion( Double.parseDouble( properties diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java new file mode 100644 index 000000000000..03402d30c643 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.exception.ainode; + +import org.apache.iotdb.rpc.TSStatusCode; + +public class GetModelInfoException extends ModelException { + public GetModelInfoException(String message) { + super(message, TSStatusCode.GET_MODEL_INFO_ERROR); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java new file mode 100644 index 000000000000..4a007e7048ce --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.exception.ainode; + +import org.apache.iotdb.rpc.TSStatusCode; + +public class ModelException extends RuntimeException { + TSStatusCode statusCode; + + public ModelException(String message, TSStatusCode code) { + super(message); + this.statusCode = code; + } + + public TSStatusCode getStatusCode() { + return statusCode; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelNotFoundException.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelNotFoundException.java new file mode 100644 index 000000000000..38a5105cded1 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelNotFoundException.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.exception.ainode; + +import org.apache.iotdb.rpc.TSStatusCode; + +public class ModelNotFoundException extends ModelException { + public ModelNotFoundException(String message) { + super(message, TSStatusCode.MODEL_NOT_FOUND_ERROR); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/runtime/ModelInferenceProcessException.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/runtime/ModelInferenceProcessException.java new file mode 100644 index 000000000000..586c624a8d33 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/runtime/ModelInferenceProcessException.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.exception.runtime; + +public class ModelInferenceProcessException extends RuntimeException { + + public ModelInferenceProcessException(String message) { + super(message); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java index 122ad5d7c37e..d146712ecf1b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java @@ -40,6 +40,12 @@ import org.apache.iotdb.commons.client.sync.SyncThriftClientWithErrorHandler; import org.apache.iotdb.commons.consensus.ConfigRegionId; import org.apache.iotdb.confignode.rpc.thrift.IConfigNodeRPCService; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeConfigurationResp; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterResp; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRemoveReq; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq; +import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp; import org.apache.iotdb.confignode.rpc.thrift.TAddConsensusGroupReq; import org.apache.iotdb.confignode.rpc.thrift.TAlterLogicalViewReq; import org.apache.iotdb.confignode.rpc.thrift.TAlterPipeReq; @@ -59,6 +65,7 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; +import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -82,6 +89,7 @@ import org.apache.iotdb.confignode.rpc.thrift.TDeleteTimeSeriesReq; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; +import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropTopicReq; @@ -96,6 +104,8 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -123,11 +133,14 @@ import org.apache.iotdb.confignode.rpc.thrift.TSetSchemaReplicationFactorReq; import org.apache.iotdb.confignode.rpc.thrift.TSetSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TSetTimePartitionIntervalReq; +import org.apache.iotdb.confignode.rpc.thrift.TShowAINodesResp; import org.apache.iotdb.confignode.rpc.thrift.TShowCQResp; import org.apache.iotdb.confignode.rpc.thrift.TShowClusterResp; import org.apache.iotdb.confignode.rpc.thrift.TShowConfigNodesResp; import org.apache.iotdb.confignode.rpc.thrift.TShowDataNodesResp; import 
org.apache.iotdb.confignode.rpc.thrift.TShowDatabaseResp; +import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq; +import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp; import org.apache.iotdb.confignode.rpc.thrift.TShowPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TShowPipeResp; import org.apache.iotdb.confignode.rpc.thrift.TShowRegionReq; @@ -447,6 +460,33 @@ public TDataNodeRestartResp restartDataNode(TDataNodeRestartReq req) throws TExc () -> client.restartDataNode(req), resp -> !updateConfigNodeLeader(resp.status)); } + @Override + public TAINodeRegisterResp registerAINode(TAINodeRegisterReq req) throws TException { + throw new UnsupportedOperationException("RegisterAINode method is not supported in datanode"); + } + + @Override + public TAINodeRestartResp restartAINode(TAINodeRestartReq req) throws TException { + throw new UnsupportedOperationException("RestartAINode method is not supported in datanode"); + } + + @Override + public TSStatus removeAINode(TAINodeRemoveReq req) throws TException { + throw new UnsupportedOperationException("RemoveAINode method is not supported in datanode"); + } + + @Override + public TShowAINodesResp showAINodes() throws TException { + return executeRemoteCallWithRetry( + () -> client.showAINodes(), resp -> !updateConfigNodeLeader(resp.status)); + } + + @Override + public TAINodeConfigurationResp getAINodeConfiguration(int aiNodeId) throws TException { + throw new UnsupportedOperationException( + "GetAINodeConfiguration method is not supported in datanode"); + } + @Override public TDataNodeRemoveResp removeDataNode(TDataNodeRemoveReq req) throws TException { return executeRemoteCallWithRetry( @@ -1139,6 +1179,30 @@ public TShowCQResp showCQ() throws TException { () -> client.showCQ(), resp -> !updateConfigNodeLeader(resp.status)); } + @Override + public TSStatus createModel(TCreateModelReq req) throws TException { + return executeRemoteCallWithRetry( + () -> client.createModel(req), status -> !updateConfigNodeLeader(status)); + } + + @Override + public TSStatus dropModel(TDropModelReq req) throws TException { + return executeRemoteCallWithRetry( + () -> client.dropModel(req), status -> !updateConfigNodeLeader(status)); + } + + @Override + public TShowModelResp showModel(TShowModelReq req) throws TException { + return executeRemoteCallWithRetry( + () -> client.showModel(req), resp -> !updateConfigNodeLeader(resp.status)); + } + + @Override + public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) throws TException { + return executeRemoteCallWithRetry( + () -> client.getModelInfo(req), resp -> !updateConfigNodeLeader(resp.getStatus())); + } + @Override public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) throws TException { return executeRemoteCallWithRetry( diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/handler/AINodeRPCServiceThriftHandler.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/handler/AINodeRPCServiceThriftHandler.java new file mode 100644 index 000000000000..c5969f8678a3 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/handler/AINodeRPCServiceThriftHandler.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + *
Unless required by applicable law or agreed to in writing, software distributed under the + * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.iotdb.db.protocol.thrift.handler; + +import org.apache.iotdb.db.protocol.thrift.impl.IAINodeRPCServiceWithHandler; + +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.server.ServerContext; +import org.apache.thrift.server.TServerEventHandler; +import org.apache.thrift.transport.TTransport; + +import java.util.concurrent.atomic.AtomicLong; + +public class AINodeRPCServiceThriftHandler implements TServerEventHandler { + + private final AtomicLong thriftConnectionNumber = new AtomicLong(0); + private final IAINodeRPCServiceWithHandler eventHandler; + + public AINodeRPCServiceThriftHandler(IAINodeRPCServiceWithHandler eventHandler) { + this.eventHandler = eventHandler; + } + + @Override + public ServerContext createContext(TProtocol in, TProtocol out) { + thriftConnectionNumber.incrementAndGet(); + return null; + } + + @Override + public void deleteContext(ServerContext arg0, TProtocol in, TProtocol out) { + thriftConnectionNumber.decrementAndGet(); + eventHandler.handleExit(); + } + + @Override + public void preServe() { + // do nothing + } + + @Override + public void processContext( + ServerContext serverContext, TTransport tTransport, TTransport tTransport1) { + // do nothing + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/AINodeRPCServiceImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/AINodeRPCServiceImpl.java new file mode 100644 index 000000000000..68e492cdf462 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/AINodeRPCServiceImpl.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.protocol.thrift.impl; + +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.conf.IoTDBConstant.ClientVersion; +import org.apache.iotdb.db.protocol.session.IClientSession; +import org.apache.iotdb.db.protocol.session.InternalClientSession; +import org.apache.iotdb.db.protocol.session.SessionManager; +import org.apache.iotdb.db.protocol.thrift.OperationType; +import org.apache.iotdb.db.queryengine.common.header.DatasetHeader; +import org.apache.iotdb.db.queryengine.plan.Coordinator; +import org.apache.iotdb.db.queryengine.plan.analyze.ClusterPartitionFetcher; +import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; +import org.apache.iotdb.db.queryengine.plan.analyze.schema.ClusterSchemaFetcher; +import org.apache.iotdb.db.queryengine.plan.analyze.schema.ISchemaFetcher; +import org.apache.iotdb.db.queryengine.plan.execution.ExecutionResult; +import org.apache.iotdb.db.queryengine.plan.execution.IQueryExecution; +import org.apache.iotdb.db.queryengine.plan.parser.StatementGenerator; +import org.apache.iotdb.db.queryengine.plan.statement.Statement; +import org.apache.iotdb.db.utils.ErrorHandlingUtils; +import org.apache.iotdb.db.utils.QueryDataSetUtils; +import org.apache.iotdb.db.utils.SetThreadName; +import org.apache.iotdb.mpp.rpc.thrift.TFetchMoreDataReq; +import org.apache.iotdb.mpp.rpc.thrift.TFetchMoreDataResp; +import org.apache.iotdb.mpp.rpc.thrift.TFetchTimeseriesReq; +import org.apache.iotdb.mpp.rpc.thrift.TFetchTimeseriesResp; +import org.apache.iotdb.rpc.RpcUtils; +import org.apache.iotdb.rpc.TSStatusCode; + +import org.apache.thrift.TException; +import org.apache.tsfile.utils.Pair; + +import java.nio.ByteBuffer; +import java.time.ZoneId; +import java.util.List; + +public class AINodeRPCServiceImpl implements IAINodeRPCServiceWithHandler { + + public static final String AI_METRICS_PATH_PREFIX = "root.__system.AI.exp"; + + private static final SessionManager SESSION_MANAGER = SessionManager.getInstance(); + + private static final Coordinator COORDINATOR = Coordinator.getInstance(); + + private final IPartitionFetcher partitionFetcher; + + private final ISchemaFetcher schemaFetcher; + + private final IClientSession session; + + public AINodeRPCServiceImpl() { + super(); + partitionFetcher = ClusterPartitionFetcher.getInstance(); + schemaFetcher = ClusterSchemaFetcher.getInstance(); + session = new InternalClientSession("AINodeService"); + SESSION_MANAGER.registerSession(session); + SESSION_MANAGER.supplySession(session, "AINode", ZoneId.systemDefault(), ClientVersion.V_1_0); + } + + @Override + public TFetchTimeseriesResp fetchTimeseries(TFetchTimeseriesReq req) throws TException { + boolean finished = false; + TFetchTimeseriesResp resp = new TFetchTimeseriesResp(); + Throwable t = null; + try { + + Statement s = StatementGenerator.createStatement(req, session.getZoneId()); + + if (s == null) { + resp.setStatus( + RpcUtils.getStatus( + TSStatusCode.SQL_PARSE_ERROR, "This operation type is not supported")); + return resp; + } + + long queryId = + SESSION_MANAGER.requestQueryId(session, SESSION_MANAGER.requestStatementId(session)); + ExecutionResult result = + COORDINATOR.executeForTreeModel( + s, + queryId, + SESSION_MANAGER.getSessionInfo(session), + "", + partitionFetcher, + schemaFetcher, + req.getTimeout()); + + if (result.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode() + && result.status.code != TSStatusCode.REDIRECTION_RECOMMEND.getStatusCode()) { + resp.setStatus(result.status); + 
return resp; + } + + IQueryExecution queryExecution = COORDINATOR.getQueryExecution(queryId); + + try (SetThreadName threadName = new SetThreadName(result.queryId.getId())) { + + DatasetHeader header = queryExecution.getDatasetHeader(); + resp.setStatus(result.status); + resp.setColumnNameList(header.getRespColumns()); + resp.setColumnTypeList(header.getRespDataTypeList()); + resp.setColumnNameIndexMap(header.getColumnNameIndexMap()); + resp.setQueryId(queryId); + + Pair, Boolean> pair = + QueryDataSetUtils.convertQueryResultByFetchSize(queryExecution, req.fetchSize); + resp.setTsDataset(pair.left); + finished = pair.right; + resp.setHasMoreData(!finished); + return resp; + } + } catch (Exception e) { + finished = true; + t = e; + resp.setStatus(ErrorHandlingUtils.onQueryException(e, OperationType.EXECUTE_STATEMENT)); + return resp; + } catch (Error error) { + t = error; + throw error; + } finally { + if (finished) { + COORDINATOR.cleanupQueryExecution(resp.queryId, req, t); + } + } + } + + @Override + public TFetchMoreDataResp fetchMoreData(TFetchMoreDataReq req) throws TException { + TFetchMoreDataResp resp = new TFetchMoreDataResp(); + boolean finished = false; + Throwable t = null; + try { + IQueryExecution queryExecution = COORDINATOR.getQueryExecution(req.queryId); + resp.setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); + + if (queryExecution == null) { + resp.setHasMoreData(false); + return resp; + } + + try (SetThreadName queryName = new SetThreadName(queryExecution.getQueryId())) { + Pair, Boolean> pair = + QueryDataSetUtils.convertQueryResultByFetchSize(queryExecution, req.fetchSize); + List result = pair.left; + finished = pair.right; + resp.setTsDataset(result); + resp.setHasMoreData(!finished); + return resp; + } + } catch (Exception e) { + finished = true; + t = e; + resp.setStatus(ErrorHandlingUtils.onQueryException(e, OperationType.FETCH_RESULTS)); + return resp; + } catch (Error error) { + t = error; + throw error; + } finally { + if (finished) { + COORDINATOR.cleanupQueryExecution(req.queryId, req, t); + } + } + } + + @Override + public void handleExit() { + SESSION_MANAGER.closeSession(session, COORDINATOR::cleanupQueryExecution); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/DataNodeInternalRPCServiceImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/DataNodeInternalRPCServiceImpl.java index 7755bb99e31c..3604f0c00ddf 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/DataNodeInternalRPCServiceImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/DataNodeInternalRPCServiceImpl.java @@ -24,6 +24,7 @@ import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.common.rpc.thrift.TFlushReq; +import org.apache.iotdb.common.rpc.thrift.TLoadSample; import org.apache.iotdb.common.rpc.thrift.TNodeLocations; import org.apache.iotdb.common.rpc.thrift.TSStatus; import org.apache.iotdb.common.rpc.thrift.TSender; @@ -204,7 +205,6 @@ import org.apache.iotdb.mpp.rpc.thrift.TInvalidatePermissionCacheReq; import org.apache.iotdb.mpp.rpc.thrift.TLoadCommandReq; import org.apache.iotdb.mpp.rpc.thrift.TLoadResp; -import org.apache.iotdb.mpp.rpc.thrift.TLoadSample; import org.apache.iotdb.mpp.rpc.thrift.TMaintainPeerReq; import org.apache.iotdb.mpp.rpc.thrift.TPipeHeartbeatReq; import org.apache.iotdb.mpp.rpc.thrift.TPipeHeartbeatResp; 
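Taken together, the fetchTimeseries and fetchMoreData RPCs in AINodeRPCServiceImpl above form a simple paging contract: the first call plans and executes the statement and returns the initial batch of serialized TsBlocks together with a queryId, and subsequent calls page with that queryId until hasMoreData is false, at which point the handler cleans up the query execution. A minimal consumer sketch (not part of this change) that drives the handler directly rather than through a Thrift client; it assumes the TFetchTimeseriesReq has already been populated with its statement payload and fetch size, and it omits status checking for brevity:

import org.apache.iotdb.db.protocol.thrift.impl.AINodeRPCServiceImpl;
import org.apache.iotdb.mpp.rpc.thrift.TFetchMoreDataReq;
import org.apache.iotdb.mpp.rpc.thrift.TFetchMoreDataResp;
import org.apache.iotdb.mpp.rpc.thrift.TFetchTimeseriesReq;
import org.apache.iotdb.mpp.rpc.thrift.TFetchTimeseriesResp;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

public class AINodeFetchPagingSketch {

  /** Collects every serialized TsBlock produced for the given (pre-filled) request. */
  public static List<ByteBuffer> fetchAll(AINodeRPCServiceImpl handler, TFetchTimeseriesReq req)
      throws Exception {
    List<ByteBuffer> blocks = new ArrayList<>();

    // First call: plans the statement, starts the query and returns the first batch.
    TFetchTimeseriesResp first = handler.fetchTimeseries(req);
    if (first.tsDataset != null) {
      blocks.addAll(first.tsDataset);
    }

    // Follow-up calls: page with the queryId until the handler reports no more data.
    boolean hasMore = first.hasMoreData;
    while (hasMore) {
      TFetchMoreDataReq more = new TFetchMoreDataReq();
      more.setQueryId(first.queryId);
      more.setFetchSize(req.fetchSize);
      TFetchMoreDataResp next = handler.fetchMoreData(more);
      if (next.tsDataset != null) {
        blocks.addAll(next.tsDataset);
      }
      hasMore = next.hasMoreData;
    }
    return blocks;
  }
}

On the AINode side the same loop runs over Thrift; when the connection drops, AINodeRPCServiceThriftHandler.deleteContext calls handleExit(), which closes the internal session and releases any outstanding query executions.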
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/IAINodeRPCServiceWithHandler.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/IAINodeRPCServiceWithHandler.java new file mode 100644 index 000000000000..7d9df50ac3ce --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/IAINodeRPCServiceWithHandler.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.protocol.thrift.impl; + +import org.apache.iotdb.mpp.rpc.thrift.IAINodeInternalRPCService; + +public interface IAINodeRPCServiceWithHandler extends IAINodeInternalRPCService.Iface { + void handleExit(); +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/header/ColumnHeaderConstant.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/header/ColumnHeaderConstant.java index f4fd39bf504d..088dcdabde80 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/header/ColumnHeaderConstant.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/header/ColumnHeaderConstant.java @@ -46,6 +46,7 @@ private ColumnHeaderConstant() { public static final String COMPRESSION = "Compression"; public static final String TAGS = "Tags"; public static final String ATTRIBUTES = "Attributes"; + public static final String NOTES = "Notes"; public static final String DEADBAND = "Deadband"; public static final String DEADBAND_PARAMETERS = "DeadbandParameters"; public static final String IS_ALIGNED = "IsAligned"; @@ -109,6 +110,8 @@ private ColumnHeaderConstant() { public static final String TRIGGER_NAME = "TriggerName"; public static final String EVENT = "Event"; public static final String STATE = "State"; + public static final String MODEL_TYPE = "ModelType"; + public static final String CONFIGS = "Configs"; public static final String PATH_PATTERN = "PathPattern"; public static final String CLASS_NAME = "ClassName"; @@ -128,6 +131,7 @@ private ColumnHeaderConstant() { // show cluster status public static final String NODE_TYPE_CONFIG_NODE = "ConfigNode"; public static final String NODE_TYPE_DATA_NODE = "DataNode"; + public static final String NODE_TYPE_AI_NODE = "AINode"; public static final String COLUMN_CLUSTER_NAME = "ClusterName"; public static final String CONFIG_NODE_CONSENSUS_PROTOCOL_CLASS = "ConfigNodeConsensusProtocolClass"; @@ -200,6 +204,9 @@ private ColumnHeaderConstant() { public static final String USER = "User"; public static final String READ_WRITE = "Read/Write"; + // column names for show models/trials + public static final String MODEL_ID = "ModelId"; + // column names for views (e.g. 
logical view) public static final String VIEW_TYPE = "ViewType"; public static final String SOURCE = "Source"; @@ -329,6 +336,13 @@ private ColumnHeaderConstant() { new ColumnHeader(ROLE, TSDataType.TEXT), new ColumnHeader(CREATE_TIME, TSDataType.TEXT)); + public static final List showAINodesColumnHeaders = + ImmutableList.of( + new ColumnHeader(NODE_ID, TSDataType.INT32), + new ColumnHeader(STATUS, TSDataType.TEXT), + new ColumnHeader(RPC_ADDRESS, TSDataType.TEXT), + new ColumnHeader(RPC_PORT, TSDataType.INT32)); + public static final List showDataNodesColumnHeaders = ImmutableList.of( new ColumnHeader(NODE_ID, TSDataType.INT32), @@ -497,6 +511,14 @@ private ColumnHeaderConstant() { new ColumnHeader(LIMIT, TSDataType.TEXT), new ColumnHeader(READ_WRITE, TSDataType.TEXT)); + public static final List showModelsColumnHeaders = + ImmutableList.of( + new ColumnHeader(MODEL_ID, TSDataType.TEXT), + new ColumnHeader(MODEL_TYPE, TSDataType.TEXT), + new ColumnHeader(STATE, TSDataType.TEXT), + new ColumnHeader(CONFIGS, TSDataType.TEXT), + new ColumnHeader(NOTES, TSDataType.TEXT)); + public static final List showLogicalViewColumnHeaders = ImmutableList.of( new ColumnHeader(TIMESERIES, TSDataType.TEXT), diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/header/DatasetHeaderFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/header/DatasetHeaderFactory.java index 9c0f89b19a57..132dafd246d9 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/header/DatasetHeaderFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/header/DatasetHeaderFactory.java @@ -119,6 +119,10 @@ public static DatasetHeader getShowRegionHeader() { return new DatasetHeader(ColumnHeaderConstant.showRegionColumnHeaders, true); } + public static DatasetHeader getShowAINodesHeader() { + return new DatasetHeader(ColumnHeaderConstant.showAINodesColumnHeaders, true); + } + public static DatasetHeader getShowDataNodesHeader() { return new DatasetHeader(ColumnHeaderConstant.showDataNodesColumnHeaders, true); } @@ -201,6 +205,10 @@ public static DatasetHeader getShowThrottleQuotaHeader() { return new DatasetHeader(ColumnHeaderConstant.showThrottleQuotaColumnHeaders, true); } + public static DatasetHeader getShowModelsHeader() { + return new DatasetHeader(ColumnHeaderConstant.showModelsColumnHeaders, true); + } + public static DatasetHeader getShowLogicalViewHeader() { return new DatasetHeader(ColumnHeaderConstant.showLogicalViewColumnHeaders, true); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/fragment/FragmentInstanceManager.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/fragment/FragmentInstanceManager.java index 40a120da7250..a82e5a6d92eb 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/fragment/FragmentInstanceManager.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/fragment/FragmentInstanceManager.java @@ -23,6 +23,7 @@ import org.apache.iotdb.commons.concurrent.IoTDBThreadPoolFactory; import org.apache.iotdb.commons.concurrent.ThreadName; import org.apache.iotdb.commons.concurrent.threadpool.ScheduledExecutorUtil; +import org.apache.iotdb.commons.conf.CommonDescriptor; import org.apache.iotdb.commons.exception.IoTDBException; import org.apache.iotdb.db.conf.IoTDBDescriptor; import org.apache.iotdb.db.queryengine.common.FragmentInstanceId; @@ 
-82,6 +83,7 @@ public class FragmentInstanceManager { private final Duration infoCacheTime; private final ExecutorService intoOperationExecutor; + private final ExecutorService modelInferenceExecutor; private final MPPDataExchangeManager exchangeManager = MPPDataExchangeService.getInstance().getMPPDataExchangeManager(); @@ -119,6 +121,11 @@ private FragmentInstanceManager() { IoTDBThreadPoolFactory.newFixedThreadPool( IoTDBDescriptor.getInstance().getConfig().getIntoOperationExecutionThreadCount(), "into-operation-executor"); + + this.modelInferenceExecutor = + IoTDBThreadPoolFactory.newFixedThreadPool( + CommonDescriptor.getInstance().getConfig().getModelInferenceExecutionThreadCount(), + "model-inference-executor"); } @SuppressWarnings("squid:S1181") @@ -424,6 +431,10 @@ public ExecutorService getIntoOperationExecutor() { return intoOperationExecutor; } + public ExecutorService getModelInferenceExecutor() { + return modelInferenceExecutor; + } + private static class InstanceHolder { private InstanceHolder() {} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/AI/InferenceOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/AI/InferenceOperator.java new file mode 100644 index 000000000000..ab8cf811af7e --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/AI/InferenceOperator.java @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.execution.operator.process.AI; + +import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; +import org.apache.iotdb.ainode.rpc.thrift.TWindowParams; +import org.apache.iotdb.commons.client.ainode.AINodeClient; +import org.apache.iotdb.commons.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.exception.runtime.ModelInferenceProcessException; +import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper; +import org.apache.iotdb.db.queryengine.execution.operator.Operator; +import org.apache.iotdb.db.queryengine.execution.operator.OperatorContext; +import org.apache.iotdb.db.queryengine.execution.operator.process.ProcessOperator; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.BottomInferenceWindowParameter; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindowParameter; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowType; +import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; +import org.apache.iotdb.rpc.TSStatusCode; + +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.TsBlock; +import org.apache.tsfile.read.common.block.TsBlockBuilder; +import org.apache.tsfile.read.common.block.column.TimeColumnBuilder; +import org.apache.tsfile.read.common.block.column.TsBlockSerde; +import org.apache.tsfile.utils.RamUsageEstimator; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.stream.Collectors; + +import static com.google.common.util.concurrent.Futures.successfulAsList; + +public class InferenceOperator implements ProcessOperator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(InferenceOperator.class); + + private final OperatorContext operatorContext; + private final Operator child; + private final ModelInferenceDescriptor modelInferenceDescriptor; + + private final TsBlockBuilder inputTsBlockBuilder; + + private final ExecutorService modelInferenceExecutor; + private ListenableFuture inferenceExecutionFuture; + + private boolean finished = false; + + private final long maxRetainedSize; + private final long maxReturnSize; + private final List inputColumnNames; + private final List targetColumnNames; + private long totalRow; + private int resultIndex = 0; + private List results; + private final TsBlockSerde serde = new TsBlockSerde(); + private InferenceWindowType windowType = null; + + public InferenceOperator( + OperatorContext operatorContext, + Operator child, + ModelInferenceDescriptor modelInferenceDescriptor, + ExecutorService modelInferenceExecutor, + List targetColumnNames, + List inputColumnNames, + long maxRetainedSize, + long maxReturnSize) { + this.operatorContext = operatorContext; + this.child = child; + this.modelInferenceDescriptor = modelInferenceDescriptor; + this.inputTsBlockBuilder = + new TsBlockBuilder( + Arrays.asList(modelInferenceDescriptor.getModelInformation().getInputDataType())); + this.modelInferenceExecutor = modelInferenceExecutor; + this.targetColumnNames = targetColumnNames; + this.inputColumnNames = 
inputColumnNames; + this.maxRetainedSize = maxRetainedSize; + this.maxReturnSize = maxReturnSize; + this.totalRow = 0; + + if (modelInferenceDescriptor.getInferenceWindowParameter() != null) { + windowType = modelInferenceDescriptor.getInferenceWindowParameter().getWindowType(); + } + } + + @Override + public OperatorContext getOperatorContext() { + return operatorContext; + } + + @Override + public ListenableFuture isBlocked() { + ListenableFuture childBlocked = child.isBlocked(); + boolean executionDone = forecastExecutionDone(); + if (executionDone && childBlocked.isDone()) { + return NOT_BLOCKED; + } else if (childBlocked.isDone()) { + return inferenceExecutionFuture; + } else if (executionDone) { + return childBlocked; + } else { + return successfulAsList(Arrays.asList(inferenceExecutionFuture, childBlocked)); + } + } + + private boolean forecastExecutionDone() { + if (inferenceExecutionFuture == null) { + return true; + } + return inferenceExecutionFuture.isDone(); + } + + @Override + public boolean hasNext() throws Exception { + return !finished || (results != null && results.size() != resultIndex); + } + + @Override + public TsBlock next() throws Exception { + if (inferenceExecutionFuture == null) { + if (child.hasNextWithTimer()) { + TsBlock inputTsBlock = child.nextWithTimer(); + if (inputTsBlock != null) { + appendTsBlockToBuilder(inputTsBlock); + } + } else { + submitInferenceTask(); + } + return null; + } else { + + if (results != null && resultIndex != results.size()) { + TsBlock tsBlock = serde.deserialize(results.get(resultIndex)); + resultIndex++; + return tsBlock; + } + + try { + if (!inferenceExecutionFuture.isDone()) { + throw new IllegalStateException( + "The operator cannot continue until the forecast execution is done."); + } + + TInferenceResp inferenceResp = inferenceExecutionFuture.get(); + if (inferenceResp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + String message = + String.format( + "Error occurred while executing inference:[%s]", + inferenceResp.getStatus().getMessage()); + throw new ModelInferenceProcessException(message); + } + + finished = true; + TsBlock resultTsBlock = serde.deserialize(inferenceResp.inferenceResult.get(0)); + results = inferenceResp.inferenceResult; + resultIndex++; + return resultTsBlock; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new ModelInferenceProcessException(e.getMessage()); + } catch (ExecutionException e) { + throw new ModelInferenceProcessException(e.getMessage()); + } + } + } + + private void appendTsBlockToBuilder(TsBlock inputTsBlock) { + TimeColumnBuilder timeColumnBuilder = inputTsBlockBuilder.getTimeColumnBuilder(); + ColumnBuilder[] columnBuilders = inputTsBlockBuilder.getValueColumnBuilders(); + totalRow += inputTsBlock.getPositionCount(); + for (int i = 0; i < inputTsBlock.getPositionCount(); i++) { + timeColumnBuilder.writeLong(inputTsBlock.getTimeByIndex(i)); + for (int columnIndex = 0; columnIndex < inputTsBlock.getValueColumnCount(); columnIndex++) { + columnBuilders[columnIndex].write(inputTsBlock.getColumn(columnIndex), i); + } + inputTsBlockBuilder.declarePosition(); + } + } + + private TWindowParams getWindowParams() { + TWindowParams windowParams; + if (windowType == null) { + return null; + } + if (windowType == InferenceWindowType.COUNT) { + CountInferenceWindowParameter countInferenceWindowParameter = + (CountInferenceWindowParameter) modelInferenceDescriptor.getInferenceWindowParameter(); + windowParams = new TWindowParams(); + 
windowParams.setWindowInterval((int) countInferenceWindowParameter.getInterval()); + windowParams.setWindowStep((int) countInferenceWindowParameter.getStep()); + } else { + windowParams = null; + } + return windowParams; + } + + private TsBlock preProcess(TsBlock inputTsBlock) { + boolean notBuiltIn = !modelInferenceDescriptor.getModelInformation().isBuiltIn(); + if (windowType == null || windowType == InferenceWindowType.HEAD) { + if (notBuiltIn + && totalRow != modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { + throw new ModelInferenceProcessException( + String.format( + "The number of rows %s in the input data does not match the model input %s. Try to use LIMIT in SQL or WINDOW in CALL INFERENCE", + totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); + } + return inputTsBlock; + } else if (windowType == InferenceWindowType.COUNT) { + if (notBuiltIn + && totalRow < modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { + throw new ModelInferenceProcessException( + String.format( + "The number of rows %s in the input data is less than the model input %s. ", + totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); + } + } else if (windowType == InferenceWindowType.TAIL) { + if (notBuiltIn + && totalRow < modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { + throw new ModelInferenceProcessException( + String.format( + "The number of rows %s in the input data is less than the model input %s. ", + totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); + } + // Tail window logic: get the latest data for inference + long windowSize = + (int) + ((BottomInferenceWindowParameter) + modelInferenceDescriptor.getInferenceWindowParameter()) + .getWindowSize(); + return inputTsBlock.subTsBlock((int) (totalRow - windowSize)); + } + return inputTsBlock; + } + + private void submitInferenceTask() { + + TsBlock inputTsBlock = inputTsBlockBuilder.build(); + + TsBlock finalInputTsBlock = preProcess(inputTsBlock); + TWindowParams windowParams = getWindowParams(); + + Map columnNameIndexMap = new HashMap<>(); + + for (int i = 0; i < inputColumnNames.size(); i++) { + columnNameIndexMap.put(inputColumnNames.get(i), i); + } + + inferenceExecutionFuture = + Futures.submit( + () -> { + try (AINodeClient client = + AINodeClientManager.getInstance() + .borrowClient(modelInferenceDescriptor.getTargetAINode())) { + return client.inference( + modelInferenceDescriptor.getModelName(), + targetColumnNames, + Arrays.stream(modelInferenceDescriptor.getModelInformation().getInputDataType()) + .map(TSDataType::toString) + .collect(Collectors.toList()), + columnNameIndexMap, + finalInputTsBlock, + modelInferenceDescriptor.getInferenceAttributes(), + windowParams); + } catch (Exception e) { + throw new ModelInferenceProcessException(e.getMessage()); + } + }, + modelInferenceExecutor); + } + + @Override + public boolean isFinished() throws Exception { + return finished && !hasNext(); + } + + @Override + public void close() throws Exception { + if (inferenceExecutionFuture != null) { + inferenceExecutionFuture.cancel(true); + } + child.close(); + } + + @Override + public long calculateMaxPeekMemory() { + return maxReturnSize + maxRetainedSize + child.calculateMaxPeekMemory(); + } + + @Override + public long calculateMaxReturnSize() { + return maxReturnSize; + } + + @Override + public long calculateRetainedSizeAfterCallingNext() { + return maxRetainedSize + child.calculateRetainedSizeAfterCallingNext(); + } 
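The window handling above relies on the parameter classes added later in this diff: a COUNT window is forwarded to the AINode as TWindowParams(interval, step), while a TAIL window is applied locally by keeping only the last windowSize rows via inputTsBlock.subTsBlock((int) (totalRow - windowSize)). A small round-trip sketch (not part of this change) of the serialization format those parameter classes use, i.e. the InferenceWindowType ordinal as an int followed by the type-specific longs; the window values and buffer size are illustrative:

import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindowParameter;
import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowParameter;
import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowType;

import java.nio.ByteBuffer;

public class InferenceWindowParameterRoundTrip {
  public static void main(String[] args) {
    // Slide a 96-row window over the input, advancing 24 rows at a time.
    CountInferenceWindowParameter original = new CountInferenceWindowParameter(96, 24);

    // serialize() writes the window type ordinal (4 bytes) plus interval and step (8 bytes each).
    ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES + 2 * Long.BYTES);
    original.serialize(buffer);
    buffer.flip();

    // deserialize() reads the ordinal back and dispatches to the matching parameter class.
    InferenceWindowParameter copy = InferenceWindowParameter.deserialize(buffer);
    System.out.println(copy.getWindowType() == InferenceWindowType.COUNT); // true
    System.out.println(copy.equals(original)); // true
  }
}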
+ + @Override + public long ramBytesUsed() { + return INSTANCE_SIZE + + MemoryEstimationHelper.getEstimatedSizeOfAccountableObject(child) + + MemoryEstimationHelper.getEstimatedSizeOfAccountableObject(operatorContext) + + inputTsBlockBuilder.getRetainedSizeInBytes() + + (inputColumnNames == null + ? 0 + : inputColumnNames.stream().mapToLong(RamUsageEstimator::sizeOf).sum()) + + (targetColumnNames == null + ? 0 + : targetColumnNames.stream().mapToLong(RamUsageEstimator::sizeOf).sum()); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/BottomInferenceWindowParameter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/BottomInferenceWindowParameter.java new file mode 100644 index 000000000000..77953b122f4b --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/BottomInferenceWindowParameter.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.execution.operator.window.ainode; + +import org.apache.tsfile.utils.ReadWriteIOUtils; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; + +public class BottomInferenceWindowParameter extends InferenceWindowParameter { + + long windowSize; + + public BottomInferenceWindowParameter(long windowSize) { + this.windowSize = windowSize; + this.windowType = InferenceWindowType.TAIL; + } + + public long getWindowSize() { + return windowSize; + } + + @Override + public void serializeAttributes(ByteBuffer buffer) { + ReadWriteIOUtils.write(windowSize, buffer); + } + + @Override + public void serializeAttributes(DataOutputStream stream) throws IOException { + ReadWriteIOUtils.write(windowSize, stream); + } + + public static BottomInferenceWindowParameter deserialize(ByteBuffer byteBuffer) { + long windowSize = byteBuffer.getLong(); + return new BottomInferenceWindowParameter(windowSize); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof BottomInferenceWindowParameter)) { + return false; + } + BottomInferenceWindowParameter parameter = (BottomInferenceWindowParameter) obj; + return windowSize == parameter.windowSize; + } + + @Override + public int hashCode() { + return Long.hashCode(windowSize); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/CountInferenceWindow.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/CountInferenceWindow.java new file mode 100644 index 000000000000..723e87593464 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/CountInferenceWindow.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.execution.operator.window.ainode; + +public class CountInferenceWindow extends InferenceWindow { + + private final long interval; + private final long step; + + public CountInferenceWindow(long interval, long step) { + super(InferenceWindowType.COUNT); + this.interval = interval; + this.step = step; + } + + public long getInterval() { + return interval; + } + + public long getStep() { + return step; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/CountInferenceWindowParameter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/CountInferenceWindowParameter.java new file mode 100644 index 000000000000..6a6371c4a70a --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/CountInferenceWindowParameter.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.execution.operator.window.ainode; + +import org.apache.tsfile.utils.ReadWriteIOUtils; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Objects; + +public class CountInferenceWindowParameter extends InferenceWindowParameter { + + private final long interval; + private final long step; + + public CountInferenceWindowParameter(long interval, long step) { + this.windowType = InferenceWindowType.COUNT; + this.interval = interval; + this.step = step; + } + + public long getInterval() { + return interval; + } + + public long getStep() { + return step; + } + + @Override + public void serializeAttributes(ByteBuffer buffer) { + ReadWriteIOUtils.write(interval, buffer); + ReadWriteIOUtils.write(step, buffer); + } + + @Override + public void serializeAttributes(DataOutputStream stream) throws IOException { + ReadWriteIOUtils.write(interval, stream); + ReadWriteIOUtils.write(step, stream); + } + + public static CountInferenceWindowParameter deserialize(ByteBuffer byteBuffer) { + long interval = ReadWriteIOUtils.readLong(byteBuffer); + long step = ReadWriteIOUtils.readLong(byteBuffer); + return new CountInferenceWindowParameter(interval, step); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof CountInferenceWindowParameter)) { + return false; + } + CountInferenceWindowParameter parameter = (CountInferenceWindowParameter) obj; + return interval == parameter.interval && step == parameter.step; + } + + @Override + public int hashCode() { + return Objects.hash(interval, step); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/HeadInferenceWindow.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/HeadInferenceWindow.java new file mode 100644 index 000000000000..8e4f2cc65cb8 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/HeadInferenceWindow.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.execution.operator.window.ainode; + +public class HeadInferenceWindow extends InferenceWindow { + private final long windowSize; + + public HeadInferenceWindow(long windowSize) { + super(InferenceWindowType.HEAD); + this.windowSize = windowSize; + } + + public long getWindowSize() { + return windowSize; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/InferenceWindow.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/InferenceWindow.java new file mode 100644 index 000000000000..e5c00f910d0c --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/InferenceWindow.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.window.ainode; + +public class InferenceWindow { + private final InferenceWindowType type; + + public InferenceWindow(InferenceWindowType type) { + this.type = type; + } + + public InferenceWindowType getType() { + return type; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/InferenceWindowParameter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/InferenceWindowParameter.java new file mode 100644 index 000000000000..b9ab1343c3b8 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/InferenceWindowParameter.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.execution.operator.window.ainode; + +import org.apache.iotdb.db.exception.sql.SemanticException; + +import org.apache.tsfile.utils.ReadWriteIOUtils; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; + +public abstract class InferenceWindowParameter { + + protected InferenceWindowType windowType; + + public InferenceWindowType getWindowType() { + return windowType; + } + + public abstract void serializeAttributes(ByteBuffer buffer); + + public abstract void serializeAttributes(DataOutputStream stream) throws IOException; + + public void serialize(ByteBuffer buffer) { + ReadWriteIOUtils.write(windowType.ordinal(), buffer); + serializeAttributes(buffer); + } + + public void serialize(DataOutputStream stream) throws IOException { + ReadWriteIOUtils.write(windowType.ordinal(), stream); + serializeAttributes(stream); + } + + public static InferenceWindowParameter deserialize(ByteBuffer byteBuffer) { + InferenceWindowType windowType = + InferenceWindowType.values()[ReadWriteIOUtils.readInt(byteBuffer)]; + if (windowType == InferenceWindowType.TAIL) { + return BottomInferenceWindowParameter.deserialize(byteBuffer); + } else if (windowType == InferenceWindowType.COUNT) { + return CountInferenceWindowParameter.deserialize(byteBuffer); + } else { + throw new SemanticException("Unsupported inference window type: " + windowType); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/InferenceWindowType.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/InferenceWindowType.java new file mode 100644 index 000000000000..f792327396f9 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/InferenceWindowType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iotdb.db.queryengine.execution.operator.window.ainode; + +public enum InferenceWindowType { + HEAD, + TAIL, + COUNT +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/TailInferenceWindow.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/TailInferenceWindow.java new file mode 100644 index 000000000000..3bbc568cac2b --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/window/ainode/TailInferenceWindow.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.window.ainode; + +public class TailInferenceWindow extends InferenceWindow { + + private final long windowSize; + + public TailInferenceWindow(long windowSize) { + super(InferenceWindowType.TAIL); + this.windowSize = windowSize; + } + + public long getWindowSize() { + return windowSize; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/Analysis.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/Analysis.java index 94a1e518b74b..93e386b9bd52 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/Analysis.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/Analysis.java @@ -25,6 +25,7 @@ import org.apache.iotdb.common.rpc.thrift.TSStatus; import org.apache.iotdb.common.rpc.thrift.TSchemaNode; import org.apache.iotdb.common.rpc.thrift.TTimePartitionSlot; +import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.commons.partition.DataPartition; import org.apache.iotdb.commons.partition.SchemaPartition; import org.apache.iotdb.commons.path.PartialPath; @@ -49,6 +50,7 @@ import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.GroupByTimeParameter; import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.IntoPathDescriptor; import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.OrderByParameter; +import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.db.queryengine.plan.statement.Statement; import org.apache.iotdb.db.queryengine.plan.statement.StatementType; import org.apache.iotdb.db.queryengine.plan.statement.component.Ordering; @@ -237,6 +239,8 @@ aggregation results last_value(temperature) and last_value(status), whereas buck // indicate whether the Nodes produce source data are VirtualSourceNodes private boolean isVirtualSource = false; + private ModelInferenceDescriptor modelInferenceDescriptor; + ///////////////////////////////////////////////////////////////////////////////////////////////// // SELECT INTO Analysis ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -892,6 +896,21 @@ public Map getOutputDeviceToQueriedDevicesMap() { return outputDeviceToQueriedDevicesMap; } + public ModelInferenceDescriptor getModelInferenceDescriptor() { + return modelInferenceDescriptor; + } + + public ModelInformation getModelInformation() { + if (modelInferenceDescriptor == null) { + return null; + } + return modelInferenceDescriptor.getModelInformation(); + } + + public void setModelInferenceDescriptor(ModelInferenceDescriptor modelInferenceDescriptor) { + this.modelInferenceDescriptor = modelInferenceDescriptor; + } + 
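Downstream consumers (for example, the planner that eventually builds the InferenceOperator above) are expected to read the model-inference state through these new Analysis accessors. A minimal, hypothetical helper showing the intended null-safe check; only the accessors and ModelInformation.available() from this diff are assumed:

import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.db.queryengine.plan.analyze.Analysis;
import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor;

public final class ModelInferenceChecks {
  private ModelInferenceChecks() {}

  /** Returns true when the analyzed query carries a usable model inference descriptor. */
  public static boolean hasUsableModel(Analysis analysis) {
    ModelInferenceDescriptor descriptor = analysis.getModelInferenceDescriptor();
    if (descriptor == null) {
      return false;
    }
    // Analysis.getModelInformation() already null-guards, but the descriptor is in hand here.
    ModelInformation info = descriptor.getModelInformation();
    return info != null && info.available();
  }
}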
public void setOutputDeviceToQueriedDevicesMap( Map outputDeviceToQueriedDevicesMap) { this.outputDeviceToQueriedDevicesMap = outputDeviceToQueriedDevicesMap; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java index 605fb4c95dcb..b01ee8babed1 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java @@ -20,12 +20,14 @@ package org.apache.iotdb.db.queryengine.plan.analyze; import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; +import org.apache.iotdb.common.rpc.thrift.TSStatus; import org.apache.iotdb.common.rpc.thrift.TTimePartitionSlot; import org.apache.iotdb.commons.client.exception.ClientManagerException; import org.apache.iotdb.commons.conf.IoTDBConstant; import org.apache.iotdb.commons.exception.IllegalPathException; import org.apache.iotdb.commons.exception.IoTDBException; import org.apache.iotdb.commons.exception.MetadataException; +import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.commons.partition.DataPartition; import org.apache.iotdb.commons.partition.DataPartitionQueryParam; import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition; @@ -41,6 +43,7 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetDataNodeLocationsResp; import org.apache.iotdb.db.conf.IoTDBConfig; import org.apache.iotdb.db.conf.IoTDBDescriptor; +import org.apache.iotdb.db.exception.ainode.GetModelInfoException; import org.apache.iotdb.db.exception.metadata.template.TemplateIncompatibleException; import org.apache.iotdb.db.exception.metadata.view.UnsupportedViewException; import org.apache.iotdb.db.exception.sql.SemanticException; @@ -59,6 +62,14 @@ import org.apache.iotdb.db.queryengine.common.schematree.IMeasurementSchemaInfo; import org.apache.iotdb.db.queryengine.common.schematree.ISchemaTree; import org.apache.iotdb.db.queryengine.execution.operator.window.WindowType; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.BottomInferenceWindowParameter; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindow; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindowParameter; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.HeadInferenceWindow; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindow; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowParameter; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowType; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.TailInferenceWindow; import org.apache.iotdb.db.queryengine.metric.QueryPlanCostMetricSet; import org.apache.iotdb.db.queryengine.plan.analyze.lock.DataNodeSchemaLockManager; import org.apache.iotdb.db.queryengine.plan.analyze.lock.SchemaLockType; @@ -217,11 +228,14 @@ public class AnalyzeVisitor extends StatementVisitor public static final Expression END_TIME_EXPRESSION = TimeSeriesOperand.constructColumnHeaderExpression(ENDTIME, TSDataType.INT64); + private static final String INFERENCE_COLUMN_NAME = "output"; + private final List lastQueryColumnNames = new ArrayList<>(Arrays.asList("TIME", "TIMESERIES", "VALUE", "DATATYPE")); 
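+  // Note on INFERENCE_COLUMN_NAME: when a query carries model inference, analyzeOutput() names
+  // the result columns by appending the column index to this prefix (output0, output1, ...),
+  // one column per entry of the model's second output dimension, typed with the model's
+  // declared output data types.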
private final IPartitionFetcher partitionFetcher; private final ISchemaFetcher schemaFetcher; + private final IModelFetcher modelFetcher; private static final PerformanceOverviewMetrics PERFORMANCE_OVERVIEW_METRICS = PerformanceOverviewMetrics.getInstance(); @@ -229,6 +243,7 @@ public class AnalyzeVisitor extends StatementVisitor public AnalyzeVisitor(IPartitionFetcher partitionFetcher, ISchemaFetcher schemaFetcher) { this.partitionFetcher = partitionFetcher; this.schemaFetcher = schemaFetcher; + this.modelFetcher = ModelFetcher.getInstance(); } @Override @@ -268,6 +283,9 @@ public Analysis visitQuery(QueryStatement queryStatement, MPPQueryContext contex // check for semantic errors queryStatement.semanticCheck(); + // fetch model inference information and check + analyzeModelInference(analysis, queryStatement); + ISchemaTree schemaTree = analyzeSchema(queryStatement, analysis, context); // If there is no leaf node in the schema tree, the query should be completed immediately @@ -397,6 +415,77 @@ public Analysis visitQuery(QueryStatement queryStatement, MPPQueryContext contex return analysis; } + // check if there is proper model to inference for MODEL_NAME, there is no need to do the + // following analyze if there isn't. + private void analyzeModelInference(Analysis analysis, QueryStatement queryStatement) { + if (!queryStatement.hasModelInference()) { + return; + } + + // Get model metadata from configNode and do some check + String modelId = queryStatement.getModelName(); + TSStatus status = modelFetcher.fetchModel(modelId, analysis); + if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + throw new GetModelInfoException(status.getMessage()); + } + ModelInformation modelInformation = analysis.getModelInformation(); + if (modelInformation == null || !modelInformation.available()) { + throw new SemanticException("Model " + modelId + " is not active"); + } + + // set inference window if there is + if (queryStatement.isSetInferenceWindow()) { + InferenceWindow window = queryStatement.getInferenceWindow(); + if (InferenceWindowType.HEAD == window.getType()) { + long windowSize = ((HeadInferenceWindow) window).getWindowSize(); + checkWindowSize(windowSize, modelInformation); + if (queryStatement.hasLimit() && queryStatement.getRowLimit() < windowSize) { + throw new SemanticException( + "Limit in Sql should be larger than window size in inference"); + } + // optimize head window by limitNode + queryStatement.setRowLimit(windowSize); + } else if (InferenceWindowType.TAIL == window.getType()) { + long windowSize = ((TailInferenceWindow) window).getWindowSize(); + checkWindowSize(windowSize, modelInformation); + InferenceWindowParameter inferenceWindowParameter = + new BottomInferenceWindowParameter(windowSize); + analysis + .getModelInferenceDescriptor() + .setInferenceWindowParameter(inferenceWindowParameter); + } else if (InferenceWindowType.COUNT == window.getType()) { + CountInferenceWindow countInferenceWindow = (CountInferenceWindow) window; + checkWindowSize(countInferenceWindow.getInterval(), modelInformation); + InferenceWindowParameter inferenceWindowParameter = + new CountInferenceWindowParameter( + countInferenceWindow.getInterval(), countInferenceWindow.getStep()); + analysis + .getModelInferenceDescriptor() + .setInferenceWindowParameter(inferenceWindowParameter); + } + } + + // set inference attributes if there is + if (queryStatement.hasInferenceAttributes()) { + analysis + .getModelInferenceDescriptor() + 
.setInferenceAttributes(queryStatement.getInferenceAttributes()); + } + } + + private void checkWindowSize(long windowSize, ModelInformation modelInformation) { + if (modelInformation.isBuiltIn()) { + return; + } + + if (modelInformation.getInputShape()[0] != windowSize) { + throw new SemanticException( + String.format( + "Window output %d is not equal to input size of model %d", + windowSize, modelInformation.getInputShape()[0])); + } + } + private ISchemaTree analyzeSchema( QueryStatement queryStatement, Analysis analysis, MPPQueryContext context) { // concat path and construct path pattern tree @@ -1574,6 +1663,27 @@ static void analyzeOutput( return; } + if (queryStatement.hasModelInference()) { + ModelInformation modelInformation = analysis.getModelInformation(); + // check input + checkInputShape(modelInformation, outputExpressions); + checkInputType(analysis, modelInformation, outputExpressions); + + // set output + List columnHeaders = new ArrayList<>(); + int[] outputShape = modelInformation.getOutputShape(); + TSDataType[] outputDataType = modelInformation.getOutputDataType(); + for (int i = 0; i < outputShape[1]; i++) { + columnHeaders.add(new ColumnHeader(INFERENCE_COLUMN_NAME + i, outputDataType[i])); + } + analysis + .getModelInferenceDescriptor() + .setOutputColumnNames( + columnHeaders.stream().map(ColumnHeader::getColumnName).collect(Collectors.toList())); + analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, true)); + return; + } + boolean isIgnoreTimestamp = queryStatement.isAggregationQuery() && !queryStatement.isGroupBy(); List columnHeaders = new ArrayList<>(); if (queryStatement.isAlignByDevice()) { @@ -1592,6 +1702,72 @@ static void analyzeOutput( analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, isIgnoreTimestamp)); } + // check if the result of SQL matches the input of model + private static void checkInputShape( + ModelInformation modelInformation, List> outputExpressions) { + if (modelInformation.isBuiltIn()) { + modelInformation.setInputColumnSize(outputExpressions.size()); + return; + } + + // check inputShape + int[] inputShape = modelInformation.getInputShape(); + if (inputShape.length != 2) { + throw new SemanticException( + String.format( + "The input shape of model is not correct, the dimension of input shape should be 2, actual dimension is %d", + inputShape.length)); + } + int columnNumber = inputShape[1]; + if (columnNumber != outputExpressions.size()) { + throw new SemanticException( + String.format( + "The column number of SQL result does not match the number of model input [%d] for inference", + columnNumber)); + } + } + + private static void checkInputType( + Analysis analysis, + ModelInformation modelInformation, + List> outputExpressions) { + + if (modelInformation.isBuiltIn()) { + TSDataType[] inputType = new TSDataType[outputExpressions.size()]; + for (int i = 0; i < outputExpressions.size(); i++) { + Expression inputExpression = outputExpressions.get(i).left; + TSDataType inputDataType = analysis.getType(inputExpression); + if (!inputDataType.isNumeric()) { + throw new SemanticException( + String.format( + "The type of SQL result column [%s in %d] should be numeric when inference", + inputDataType, i)); + } + inputType[i] = inputDataType; + } + modelInformation.setInputDataType(inputType); + return; + } + + TSDataType[] inputType = modelInformation.getInputDataType(); + if (inputType.length != modelInformation.getInputShape()[1]) { + throw new SemanticException( + String.format( + "The inputType does not match the input 
shape [%d] for inference", + modelInformation.getInputShape()[1])); + } + for (int i = 0; i < inputType.length; i++) { + Expression inputExpression = outputExpressions.get(i).left; + TSDataType inputDataType = analysis.getType(inputExpression); + if (inputDataType != inputType[i]) { + throw new SemanticException( + String.format( + "The type of SQL result column [%s in %d] does not match the type of model input [%s] when inference", + inputDataType, i, inputType[i])); + } + } + } + // For last query private void analyzeLastOrderBy(Analysis analysis, QueryStatement queryStatement) { if (!queryStatement.hasOrderBy()) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java new file mode 100644 index 000000000000..1feecaefde9c --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.analyze; + +import org.apache.iotdb.common.rpc.thrift.TSStatus; + +public interface IModelFetcher { + /** Get model information by model id from configNode. */ + TSStatus fetchModel(String modelId, Analysis analysis); +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java new file mode 100644 index 000000000000..8cefb5e0cf3c --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.plan.analyze; + +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.client.IClientManager; +import org.apache.iotdb.commons.client.exception.ClientManagerException; +import org.apache.iotdb.commons.consensus.ConfigRegionId; +import org.apache.iotdb.commons.model.ModelInformation; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; +import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; +import org.apache.iotdb.db.exception.ainode.ModelNotFoundException; +import org.apache.iotdb.db.exception.sql.StatementAnalyzeException; +import org.apache.iotdb.db.protocol.client.ConfigNodeClient; +import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; +import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; +import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; +import org.apache.iotdb.rpc.TSStatusCode; + +import org.apache.thrift.TException; + +public class ModelFetcher implements IModelFetcher { + + private final IClientManager configNodeClientManager = + ConfigNodeClientManager.getInstance(); + + private static final class ModelFetcherHolder { + + private static final ModelFetcher INSTANCE = new ModelFetcher(); + + private ModelFetcherHolder() {} + } + + public static ModelFetcher getInstance() { + return ModelFetcherHolder.INSTANCE; + } + + private ModelFetcher() {} + + @Override + public TSStatus fetchModel(String modelName, Analysis analysis) { + try (ConfigNodeClient client = + configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { + TGetModelInfoResp getModelInfoResp = client.getModelInfo(new TGetModelInfoReq(modelName)); + if (getModelInfoResp.getStatus().getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + if (getModelInfoResp.modelInfo != null && getModelInfoResp.isSetAiNodeAddress()) { + analysis.setModelInferenceDescriptor( + new ModelInferenceDescriptor( + getModelInfoResp.aiNodeAddress, + ModelInformation.deserialize(getModelInfoResp.modelInfo))); + return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); + } else { + TSStatus status = new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()); + status.setMessage(String.format("model [%s] is not available", modelName)); + return status; + } + } else { + throw new ModelNotFoundException(getModelInfoResp.getStatus().getMessage()); + } + } catch (ClientManagerException | TException e) { + throw new StatementAnalyzeException(e.getMessage()); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/ConfigTaskVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/ConfigTaskVisitor.java index 6a04b32c0ae6..b928d4722344 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/ConfigTaskVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/ConfigTaskVisitor.java @@ -52,6 +52,10 @@ import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowTriggersTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowVariablesTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.UnSetTTLTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model.CreateModelTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model.DropModelTask; +import 
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model.ShowAINodesTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model.ShowModelsTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.template.AlterSchemaTemplateTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.template.CreateSchemaTemplateTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.template.DeactivateSchemaTemplateTask; @@ -120,6 +124,10 @@ import org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowTriggersStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowVariablesStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.UnSetTTLStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.CreateModelStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.DropModelStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowAINodesStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowModelsStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.AlterPipeStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.CreatePipePluginStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.CreatePipeStatement; @@ -585,4 +593,29 @@ public IConfigTask visitShowThrottleQuota( ShowThrottleQuotaStatement showThrottleQuotaStatement, MPPQueryContext context) { return new ShowThrottleQuotaTask(showThrottleQuotaStatement); } + + /** AI Model Management */ + @Override + public IConfigTask visitCreateModel( + CreateModelStatement createModelStatement, MPPQueryContext context) { + return new CreateModelTask(createModelStatement, context); + } + + @Override + public IConfigTask visitDropModel( + DropModelStatement dropModelStatement, MPPQueryContext context) { + return new DropModelTask(dropModelStatement.getModelName()); + } + + @Override + public IConfigTask visitShowModels( + ShowModelsStatement showModelsStatement, MPPQueryContext context) { + return new ShowModelsTask(showModelsStatement.getModelName()); + } + + @Override + public IConfigTask visitShowAINodes( + ShowAINodesStatement showAINodesStatement, MPPQueryContext context) { + return new ShowAINodesTask(showAINodesStatement); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java index 8b6d53582ebe..7ac6b553e5ec 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java @@ -63,6 +63,7 @@ import org.apache.iotdb.confignode.rpc.thrift.TCountTimeSlotListResp; import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; +import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTopicReq; @@ -74,6 +75,7 @@ import org.apache.iotdb.confignode.rpc.thrift.TDeleteTimeSeriesReq; import 
org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; +import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropTopicReq; @@ -92,11 +94,14 @@ import org.apache.iotdb.confignode.rpc.thrift.TPipeConfigTransferReq; import org.apache.iotdb.confignode.rpc.thrift.TPipeConfigTransferResp; import org.apache.iotdb.confignode.rpc.thrift.TRegionInfo; +import org.apache.iotdb.confignode.rpc.thrift.TShowAINodesResp; import org.apache.iotdb.confignode.rpc.thrift.TShowCQResp; import org.apache.iotdb.confignode.rpc.thrift.TShowClusterResp; import org.apache.iotdb.confignode.rpc.thrift.TShowConfigNodesResp; import org.apache.iotdb.confignode.rpc.thrift.TShowDataNodesResp; import org.apache.iotdb.confignode.rpc.thrift.TShowDatabaseResp; +import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq; +import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp; import org.apache.iotdb.confignode.rpc.thrift.TShowPipeInfo; import org.apache.iotdb.confignode.rpc.thrift.TShowPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TShowRegionReq; @@ -150,6 +155,8 @@ import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowTTLTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowTriggersTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowVariablesTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model.ShowAINodesTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model.ShowModelsTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.template.ShowNodesInSchemaTemplateTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.template.ShowPathSetTemplateTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.template.ShowSchemaTemplateTask; @@ -180,6 +187,8 @@ import org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowDatabaseStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowRegionStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowTTLStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.CreateModelStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowAINodesStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.AlterPipeStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.CreatePipePluginStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.CreatePipeStatement; @@ -2645,6 +2654,90 @@ public SettableFuture showContinuousQueries() { return future; } + @Override + public SettableFuture createModel( + CreateModelStatement createModelStatement, MPPQueryContext context) { + SettableFuture future = SettableFuture.create(); + try (ConfigNodeClient client = + CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { + TCreateModelReq req = + new TCreateModelReq(createModelStatement.getModelName(), createModelStatement.getUri()); + final TSStatus status = client.createModel(req); + if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != status.getCode()) { + LOGGER.warn( + "[{}] Failed to create model {}. 
TSStatus is {}", + status, + createModelStatement.getModelName(), + status.message); + future.setException(new IoTDBException(status.message, status.code)); + } else { + future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS)); + } + } catch (ClientManagerException | TException e) { + future.setException(e); + } + return future; + } + + @Override + public SettableFuture dropModel(String modelName) { + SettableFuture future = SettableFuture.create(); + try (ConfigNodeClient client = + CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { + final TSStatus executionStatus = client.dropModel(new TDropModelReq(modelName)); + if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != executionStatus.getCode()) { + LOGGER.warn("[{}] Failed to drop model {}.", executionStatus, modelName); + future.setException(new IoTDBException(executionStatus.message, executionStatus.code)); + } else { + future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS)); + } + } catch (ClientManagerException | TException e) { + future.setException(e); + } + return future; + } + + @Override + public SettableFuture showModels(String modelName) { + SettableFuture future = SettableFuture.create(); + try (ConfigNodeClient client = + CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { + TShowModelReq req = new TShowModelReq(); + if (modelName != null) { + req.setModelId(modelName); + } + TShowModelResp showModelResp = client.showModel(req); + if (showModelResp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + future.setException( + new IoTDBException(showModelResp.getStatus().message, showModelResp.getStatus().code)); + return future; + } + // convert model info list and buildTsBlock + ShowModelsTask.buildTsBlock(showModelResp.getModelInfoList(), future); + } catch (ClientManagerException | TException e) { + future.setException(e); + } + return future; + } + + @Override + public SettableFuture showAINodes(ShowAINodesStatement showAINodesStatement) { + SettableFuture future = SettableFuture.create(); + TShowAINodesResp resp = new TShowAINodesResp(); + try (ConfigNodeClient client = + CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { + resp = client.showAINodes(); + if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + future.setException(new IoTDBException(resp.getStatus().message, resp.getStatus().code)); + return future; + } + } catch (ClientManagerException | TException e) { + future.setException(e); + } + ShowAINodesTask.buildTsBlock(resp, future); + return future; + } + @Override public SettableFuture setSpaceQuota( SetSpaceQuotaStatement setSpaceQuotaStatement) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java index fad6ec99cba8..b3911f559795 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java @@ -46,6 +46,8 @@ import org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowDatabaseStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowRegionStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowTTLStatement; +import 
org.apache.iotdb.db.queryengine.plan.statement.metadata.model.CreateModelStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowAINodesStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.AlterPipeStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.CreatePipePluginStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.CreatePipeStatement; @@ -242,6 +244,15 @@ SettableFuture showThrottleQuota( TThrottleQuotaResp getThrottleQuota(); + SettableFuture createModel( + CreateModelStatement createModelStatement, MPPQueryContext context); + + SettableFuture dropModel(String modelName); + + SettableFuture showModels(String modelName); + + SettableFuture showAINodes(ShowAINodesStatement showAINodesStatement); + TPipeTransferResp handleTransferConfigPlan(String clientId, TPipeTransferReq req); void handlePipeConfigClientExit(String clientId); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowClusterDetailsTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowClusterDetailsTask.java index 54f6c5b57be0..f57c84fe3b93 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowClusterDetailsTask.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowClusterDetailsTask.java @@ -41,6 +41,7 @@ import java.util.List; import java.util.stream.Collectors; +import static org.apache.iotdb.db.queryengine.common.header.ColumnHeaderConstant.NODE_TYPE_AI_NODE; import static org.apache.iotdb.db.queryengine.common.header.ColumnHeaderConstant.NODE_TYPE_CONFIG_NODE; import static org.apache.iotdb.db.queryengine.common.header.ColumnHeaderConstant.NODE_TYPE_DATA_NODE; @@ -104,6 +105,56 @@ private static void buildConfigNodesTsBlock( builder.declarePosition(); } + private static void buildAINodeTsBlock( + TsBlockBuilder builder, + int nodeId, + String nodeStatus, + String internalAddress, + int internalPort, + TNodeVersionInfo versionInfo) { + + builder.getTimeColumnBuilder().writeLong(0L); + builder.getColumnBuilder(0).writeInt(nodeId); + builder + .getColumnBuilder(1) + .writeBinary(new Binary(NODE_TYPE_AI_NODE, TSFileConfig.STRING_CHARSET)); + if (nodeStatus == null) { + builder.getColumnBuilder(2).appendNull(); + } else { + builder.getColumnBuilder(2).writeBinary(new Binary(nodeStatus, TSFileConfig.STRING_CHARSET)); + } + + if (internalAddress == null) { + builder.getColumnBuilder(3).appendNull(); + } else { + builder + .getColumnBuilder(3) + .writeBinary(new Binary(internalAddress, TSFileConfig.STRING_CHARSET)); + } + builder.getColumnBuilder(4).writeInt(internalPort); + builder.getColumnBuilder(5).writeBinary(new Binary("", TSFileConfig.STRING_CHARSET)); + builder.getColumnBuilder(6).writeBinary(new Binary("", TSFileConfig.STRING_CHARSET)); + builder.getColumnBuilder(7).writeBinary(new Binary("", TSFileConfig.STRING_CHARSET)); + builder.getColumnBuilder(8).writeBinary(new Binary("", TSFileConfig.STRING_CHARSET)); + builder.getColumnBuilder(9).writeBinary(new Binary("", TSFileConfig.STRING_CHARSET)); + builder.getColumnBuilder(10).writeBinary(new Binary("", TSFileConfig.STRING_CHARSET)); + if (versionInfo == null || versionInfo.getVersion() == null) { + builder.getColumnBuilder(11).appendNull(); + } else { + builder + .getColumnBuilder(11) + .writeBinary(new 
Binary(versionInfo.getVersion(), TSFileConfig.STRING_CHARSET)); + } + if (versionInfo == null || versionInfo.getBuildInfo() == null) { + builder.getColumnBuilder(12).appendNull(); + } else { + builder + .getColumnBuilder(12) + .writeBinary(new Binary(versionInfo.getBuildInfo(), TSFileConfig.STRING_CHARSET)); + } + builder.declarePosition(); + } + @SuppressWarnings("squid:S107") private static void buildDataNodesTsBlock( TsBlockBuilder builder, @@ -208,6 +259,17 @@ public static void buildTSBlock( e.getSchemaRegionConsensusEndPoint().getPort(), e.getDataRegionConsensusEndPoint().getPort(), clusterNodeInfos.getNodeVersionInfo().get(e.getDataNodeId()))); + clusterNodeInfos + .getAiNodeList() + .forEach( + e -> + buildAINodeTsBlock( + builder, + e.getAiNodeId(), + clusterNodeInfos.getNodeStatus().get(e.getAiNodeId()), + e.getInternalEndPoint().getIp(), + e.getInternalEndPoint().getPort(), + clusterNodeInfos.getNodeVersionInfo().get(e.getAiNodeId()))); DatasetHeader datasetHeader = DatasetHeaderFactory.getShowClusterDetailsHeader(); future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS, builder.build(), datasetHeader)); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowClusterTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowClusterTask.java index 7d0c41f6baad..960b088293c5 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowClusterTask.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowClusterTask.java @@ -41,6 +41,7 @@ import java.util.List; import java.util.stream.Collectors; +import static org.apache.iotdb.db.queryengine.common.header.ColumnHeaderConstant.NODE_TYPE_AI_NODE; import static org.apache.iotdb.db.queryengine.common.header.ColumnHeaderConstant.NODE_TYPE_CONFIG_NODE; import static org.apache.iotdb.db.queryengine.common.header.ColumnHeaderConstant.NODE_TYPE_DATA_NODE; @@ -136,6 +137,20 @@ public static void buildTsBlock( e.getInternalEndPoint().getPort(), clusterNodeInfos.getNodeVersionInfo().get(e.getDataNodeId()))); + if (clusterNodeInfos.getAiNodeList() != null) { + clusterNodeInfos + .getAiNodeList() + .forEach( + e -> + buildTsBlock( + builder, + e.getAiNodeId(), + NODE_TYPE_AI_NODE, + clusterNodeInfos.getNodeStatus().get(e.getAiNodeId()), + e.getInternalEndPoint().getIp(), + e.getInternalEndPoint().getPort(), + clusterNodeInfos.getNodeVersionInfo().get(e.getAiNodeId()))); + } DatasetHeader datasetHeader = DatasetHeaderFactory.getShowClusterHeader(); future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS, builder.build(), datasetHeader)); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/CreateModelTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/CreateModelTask.java new file mode 100644 index 000000000000..842b558fd8a0 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/CreateModelTask.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model; + +import org.apache.iotdb.db.queryengine.common.MPPQueryContext; +import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult; +import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.executor.IConfigTaskExecutor; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.CreateModelStatement; + +import com.google.common.util.concurrent.ListenableFuture; + +public class CreateModelTask implements IConfigTask { + + private final CreateModelStatement createModelStatement; + private final MPPQueryContext context; + + public CreateModelTask(CreateModelStatement createModelStatement, MPPQueryContext context) { + this.createModelStatement = createModelStatement; + this.context = context; + } + + @Override + public ListenableFuture execute(IConfigTaskExecutor configTaskExecutor) + throws InterruptedException { + return configTaskExecutor.createModel(createModelStatement, context); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/DropModelTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/DropModelTask.java new file mode 100644 index 000000000000..f8db88790d4d --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/DropModelTask.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model; + +import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult; +import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.executor.IConfigTaskExecutor; + +import com.google.common.util.concurrent.ListenableFuture; + +public class DropModelTask implements IConfigTask { + + private final String modelName; + + public DropModelTask(String modelName) { + this.modelName = modelName; + } + + @Override + public ListenableFuture execute(IConfigTaskExecutor configTaskExecutor) + throws InterruptedException { + return configTaskExecutor.dropModel(modelName); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/ShowAINodesTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/ShowAINodesTask.java new file mode 100644 index 000000000000..f63a43aa1a36 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/ShowAINodesTask.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model; + +import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo; +import org.apache.iotdb.confignode.rpc.thrift.TShowAINodesResp; +import org.apache.iotdb.db.queryengine.common.header.ColumnHeader; +import org.apache.iotdb.db.queryengine.common.header.ColumnHeaderConstant; +import org.apache.iotdb.db.queryengine.common.header.DatasetHeader; +import org.apache.iotdb.db.queryengine.common.header.DatasetHeaderFactory; +import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult; +import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.executor.IConfigTaskExecutor; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowAINodesStatement; +import org.apache.iotdb.rpc.TSStatusCode; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.TsBlockBuilder; +import org.apache.tsfile.utils.BytesUtils; + +import java.util.List; +import java.util.stream.Collectors; + +public class ShowAINodesTask implements IConfigTask { + + private final ShowAINodesStatement showAINodesStatement; + + public ShowAINodesTask(ShowAINodesStatement showAINodesStatement) { + this.showAINodesStatement = showAINodesStatement; + } + + @Override + public ListenableFuture execute(IConfigTaskExecutor configTaskExecutor) + throws InterruptedException { + return configTaskExecutor.showAINodes(showAINodesStatement); + } + + public static void buildTsBlock( + TShowAINodesResp showAINodesResp, SettableFuture future) { + List outputDataTypes = + ColumnHeaderConstant.showAINodesColumnHeaders.stream() + .map(ColumnHeader::getColumnType) + .collect(Collectors.toList()); + TsBlockBuilder builder = new TsBlockBuilder(outputDataTypes); + if (showAINodesResp.getAiNodesInfoList() != null) { + for (TAINodeInfo aiNodeInfo : showAINodesResp.getAiNodesInfoList()) { + builder.getTimeColumnBuilder().writeLong(0); + builder.getColumnBuilder(0).writeInt(aiNodeInfo.getAiNodeId()); + builder.getColumnBuilder(1).writeBinary(BytesUtils.valueOf(aiNodeInfo.getStatus())); + builder + .getColumnBuilder(2) + .writeBinary(BytesUtils.valueOf(aiNodeInfo.getInternalAddress())); + builder.getColumnBuilder(3).writeInt(aiNodeInfo.getInternalPort()); + + builder.declarePosition(); + } + } + DatasetHeader datasetHeader = DatasetHeaderFactory.getShowAINodesHeader(); + future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS, builder.build(), datasetHeader)); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/ShowModelsTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/ShowModelsTask.java new file mode 100644 index 000000000000..0dc70f21e3ee --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/model/ShowModelsTask.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.execution.config.metadata.model; + +import org.apache.iotdb.commons.model.ModelType; +import org.apache.iotdb.db.queryengine.common.header.ColumnHeader; +import org.apache.iotdb.db.queryengine.common.header.ColumnHeaderConstant; +import org.apache.iotdb.db.queryengine.common.header.DatasetHeader; +import org.apache.iotdb.db.queryengine.common.header.DatasetHeaderFactory; +import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult; +import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.executor.IConfigTaskExecutor; +import org.apache.iotdb.rpc.TSStatusCode; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.TsBlockBuilder; +import org.apache.tsfile.utils.BytesUtils; +import org.apache.tsfile.utils.ReadWriteIOUtils; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +public class ShowModelsTask implements IConfigTask { + + private final String modelName; + + public ShowModelsTask(String modelName) { + this.modelName = modelName; + } + + private static final String INPUT_SHAPE = "inputShape:"; + private static final String OUTPUT_SHAPE = "outputShape:"; + private static final String INPUT_DATA_TYPE = "inputDataType:"; + private static final String OUTPUT_DATA_TYPE = "outputDataType:"; + private static final String EMPTY_STRING = ""; + + @Override + public ListenableFuture execute(IConfigTaskExecutor configTaskExecutor) + throws InterruptedException { + return configTaskExecutor.showModels(modelName); + } + + public static void buildTsBlock( + List modelInfoList, SettableFuture future) { + List outputDataTypes = + ColumnHeaderConstant.showModelsColumnHeaders.stream() + .map(ColumnHeader::getColumnType) + .collect(Collectors.toList()); + TsBlockBuilder builder = new TsBlockBuilder(outputDataTypes); + for (ByteBuffer modelInfo : modelInfoList) { + String modelId = ReadWriteIOUtils.readString(modelInfo); + String modelType = ReadWriteIOUtils.readString(modelInfo); + String state = ReadWriteIOUtils.readString(modelInfo); + String note; + String config; + if (Objects.equals(modelType, ModelType.USER_DEFINED.toString())) { + String inputShape = ReadWriteIOUtils.readString(modelInfo); + String outputShape = ReadWriteIOUtils.readString(modelInfo); + String inputTypes = ReadWriteIOUtils.readString(modelInfo); + String outputTypes = ReadWriteIOUtils.readString(modelInfo); + note = ReadWriteIOUtils.readString(modelInfo); + config = + INPUT_SHAPE + + inputShape + + OUTPUT_SHAPE + + outputShape + + INPUT_DATA_TYPE + + inputTypes + + OUTPUT_DATA_TYPE + + outputTypes; + } else { + config = EMPTY_STRING; + note = "Built-in model in IoTDB"; + } + + builder.getTimeColumnBuilder().writeLong(0L); + builder.getColumnBuilder(0).writeBinary(BytesUtils.valueOf(modelId)); + 
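+      // Column order mirrors ColumnHeaderConstant.showModelsColumnHeaders as written below:
+      // 0 = model id, 1 = model type, 2 = state,
+      // 3 = shape/type configuration (empty for built-in models), 4 = notes.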
builder.getColumnBuilder(1).writeBinary(BytesUtils.valueOf(modelType)); + builder.getColumnBuilder(2).writeBinary(BytesUtils.valueOf(state)); + builder.getColumnBuilder(3).writeBinary(BytesUtils.valueOf(config)); + if (note != null) { + builder.getColumnBuilder(4).writeBinary(BytesUtils.valueOf(note)); + } else { + builder.getColumnBuilder(4).writeBinary(BytesUtils.valueOf("")); + } + builder.declarePosition(); + } + DatasetHeader datasetHeader = DatasetHeaderFactory.getShowModelsHeader(); + future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS, builder.build(), datasetHeader)); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java index 7235ffbc5185..099678ffcee2 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java @@ -53,6 +53,10 @@ import org.apache.iotdb.db.qp.sql.IoTDBSqlParserBaseVisitor; import org.apache.iotdb.db.queryengine.common.header.ColumnHeaderConstant; import org.apache.iotdb.db.queryengine.execution.operator.window.WindowType; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindow; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.HeadInferenceWindow; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindow; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.TailInferenceWindow; import org.apache.iotdb.db.queryengine.plan.analyze.ExpressionAnalyzer; import org.apache.iotdb.db.queryengine.plan.expression.Expression; import org.apache.iotdb.db.queryengine.plan.expression.ExpressionType; @@ -161,6 +165,10 @@ import org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowTriggersStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowVariablesStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.UnSetTTLStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.CreateModelStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.DropModelStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowAINodesStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowModelsStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.AlterPipeStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.CreatePipePluginStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.CreatePipeStatement; @@ -1300,6 +1308,43 @@ private void parseViewSourcePaths( } } + // Create Model ===================================================================== + public static void validateModelName(String modelName) { + if (modelName.length() < 2 || modelName.length() > 64) { + throw new SemanticException("Model name should be 2-64 characters"); + } else if (modelName.startsWith("_")) { + throw new SemanticException("Model name should not start with '_'"); + } else if (!modelName.matches("^[-\\w]*$")) { + throw new SemanticException("ModelName can only contain letters, numbers, and underscores"); + } + } + + @Override + public Statement visitCreateModel(IoTDBSqlParser.CreateModelContext ctx) { + CreateModelStatement createModelStatement = new CreateModelStatement(); + String 
modelName = ctx.modelName.getText(); + validateModelName(modelName); + createModelStatement.setModelName(parseIdentifier(modelName)); + createModelStatement.setUri(ctx.modelUri.getText()); + return createModelStatement; + } + + // Drop Model ===================================================================== + @Override + public Statement visitDropModel(IoTDBSqlParser.DropModelContext ctx) { + return new DropModelStatement(parseIdentifier(ctx.modelId.getText())); + } + + // Show Models ===================================================================== + @Override + public Statement visitShowModels(IoTDBSqlParser.ShowModelsContext ctx) { + ShowModelsStatement statement = new ShowModelsStatement(); + if (ctx.modelId != null) { + statement.setModelName(parseIdentifier(ctx.modelId.getText())); + } + return statement; + } + /** Data Manipulation Language (DML). */ // Select Statement ======================================================================== @@ -3426,6 +3471,11 @@ public Statement visitShowConfigNodes(IoTDBSqlParser.ShowConfigNodesContext ctx) return new ShowConfigNodesStatement(); } + @Override + public Statement visitShowAINodes(IoTDBSqlParser.ShowAINodesContext ctx) { + return new ShowAINodesStatement(); + } + // device template @Override @@ -4320,6 +4370,57 @@ public Statement visitShowSpaceQuota(IoTDBSqlParser.ShowSpaceQuotaContext ctx) { return showSpaceQuotaStatement; } + @Override + public Statement visitCallInference(IoTDBSqlParser.CallInferenceContext ctx) { + String sql = ctx.inputSql.getText(); + QueryStatement statement = + (QueryStatement) + StatementGenerator.createStatement(sql.substring(1, sql.length() - 1), zoneId); + + statement.setModelName(parseIdentifier(ctx.modelId.getText())); + statement.setHasModelInference(true); + + if (ctx.hparamPair() != null) { + for (IoTDBSqlParser.HparamPairContext context : ctx.hparamPair()) { + IoTDBSqlParser.HparamValueContext valueContext = context.hparamValue(); + String paramKey = context.hparamKey.getText(); + if (paramKey.equalsIgnoreCase("WINDOW")) { + if (statement.isSetInferenceWindow()) { + throw new SemanticException("There should be only one window in CALL INFERENCE."); + } + if (valueContext.windowFunction().isEmpty()) { + throw new SemanticException( + "Window Function(e.g. 
HEAD, TAIL, COUNT) should be set in value when key is 'WINDOW' in CALL INFERENCE"); + } + parseWindowFunctionInInference(valueContext.windowFunction(), statement); + } else { + statement.addInferenceAttribute( + paramKey, parseAttributeValue(valueContext.attributeValue())); + } + } + } + + return statement; + } + + private void parseWindowFunctionInInference( + IoTDBSqlParser.WindowFunctionContext windowContext, QueryStatement statement) { + InferenceWindow inferenceWindow = null; + if (windowContext.TAIL() != null) { + inferenceWindow = + new TailInferenceWindow(Integer.parseInt(windowContext.windowSize.getText())); + } else if (windowContext.HEAD() != null) { + inferenceWindow = + new HeadInferenceWindow(Integer.parseInt(windowContext.windowSize.getText())); + } else if (windowContext.COUNT() != null) { + inferenceWindow = + new CountInferenceWindow( + Integer.parseInt(windowContext.interval.getText()), + Integer.parseInt(windowContext.step.getText())); + } + statement.setInferenceWindow(inferenceWindow); + } + @Override public Statement visitShowCurrentTimestamp(IoTDBSqlParser.ShowCurrentTimestampContext ctx) { return new ShowCurrentTimestampStatement(); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/StatementGenerator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/StatementGenerator.java index 08cddf94ac6e..b220309df583 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/StatementGenerator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/StatementGenerator.java @@ -68,6 +68,7 @@ import org.apache.iotdb.db.utils.QueryDataSetUtils; import org.apache.iotdb.db.utils.TimestampPrecisionUtils; import org.apache.iotdb.db.utils.constant.SqlConstant; +import org.apache.iotdb.mpp.rpc.thrift.TFetchTimeseriesReq; import org.apache.iotdb.service.rpc.thrift.TSAggregationQueryReq; import org.apache.iotdb.service.rpc.thrift.TSCreateAlignedTimeseriesReq; import org.apache.iotdb.service.rpc.thrift.TSCreateMultiTimeseriesReq; @@ -861,4 +862,8 @@ private static PartialPath parseDatabaseRawString(String database) throws Illega MetaFormatUtils.checkDatabase(database); return databasePath; } + + public static Statement createStatement(TFetchTimeseriesReq fetchTimeseriesReq, ZoneId zoneId) { + return invokeParser(fetchTimeseriesReq.getQueryBody(), zoneId); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java index 1af2387b0d51..94332ea7871f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java @@ -59,6 +59,7 @@ import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metedata.read.SeriesSchemaFetchScanNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metedata.read.TimeSeriesCountNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metedata.read.TimeSeriesSchemaScanNode; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AI.InferenceNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.ColumnInjectNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.DeviceViewIntoNode; import 
org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.DeviceViewNode; @@ -1369,6 +1370,19 @@ public LogicalPlanBuilder planOrderBy(QueryStatement queryStatement, Analysis an return this; } + public LogicalPlanBuilder planInference(Analysis analysis) { + this.root = + new InferenceNode( + context.getQueryId().genPlanNodeId(), + root, + analysis.getModelInferenceDescriptor(), + analysis.getOutputExpressions().stream() + .map(expressionStringPair -> expressionStringPair.left.getExpressionString()) + .collect(Collectors.toList())); + + return this; + } + public LogicalPlanBuilder planEndTimeColumnInject( GroupByTimeParameter groupByTimeParameter, boolean ascending) { this.root = diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java index c179db107031..5913bc9f6658 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanVisitor.java @@ -229,6 +229,10 @@ public PlanNode visitQuery(QueryStatement queryStatement, MPPQueryContext contex planBuilder = planBuilder.planLimit(queryStatement.getRowLimit()); } + if (queryStatement.hasModelInference()) { + planBuilder.planInference(analysis); + } + // plan select into if (queryStatement.isAlignByDevice()) { planBuilder = planBuilder.planDeviceViewInto(analysis.getDeviceViewIntoPathDescriptor()); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java index 05425c23b2b3..662db079241c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java @@ -21,6 +21,7 @@ import org.apache.iotdb.common.rpc.thrift.TAggregationType; import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.commons.path.AlignedPath; import org.apache.iotdb.commons.path.MeasurementPath; import org.apache.iotdb.commons.path.PartialPath; @@ -185,6 +186,7 @@ import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metedata.read.SeriesSchemaFetchScanNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metedata.read.TimeSeriesCountNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metedata.read.TimeSeriesSchemaScanNode; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AI.InferenceNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.ActiveRegionScanMergeNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AggregationMergeSortNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AggregationNode; @@ -242,6 +244,7 @@ import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.IntoPathDescriptor; import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.OutputColumn; import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.SeriesScanOptions; +import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import 
org.apache.iotdb.db.queryengine.plan.statement.component.FillPolicy; import org.apache.iotdb.db.queryengine.plan.statement.component.OrderByKey; import org.apache.iotdb.db.queryengine.plan.statement.component.Ordering; @@ -268,6 +271,7 @@ import org.apache.tsfile.file.metadata.PlainDeviceID; import org.apache.tsfile.read.TimeValuePair; import org.apache.tsfile.read.common.block.TsBlockBuilder; +import org.apache.tsfile.read.common.block.column.TimeColumn; import org.apache.tsfile.read.filter.basic.Filter; import org.apache.tsfile.read.filter.operator.TimeFilterOperators.TimeGt; import org.apache.tsfile.read.filter.operator.TimeFilterOperators.TimeGtEq; @@ -2218,6 +2222,49 @@ public Operator visitSort(SortNode node, LocalExecutionPlanContext context) { getComparator(sortItemList, sortItemIndexList, sortItemDataTypeList)); } + @Override + public Operator visitInference(InferenceNode node, LocalExecutionPlanContext context) { + Operator child = node.getChild().accept(this, context); + OperatorContext operatorContext = + context + .getDriverContext() + .addOperatorContext( + context.getNextOperatorId(), + node.getPlanNodeId(), + org.apache.iotdb.db.queryengine.execution.operator.process.AI.InferenceOperator + .class + .getSimpleName()); + + ModelInferenceDescriptor modelInferenceDescriptor = node.getModelInferenceDescriptor(); + ModelInformation modelInformation = modelInferenceDescriptor.getModelInformation(); + int[] inputShape = modelInformation.getInputShape(); + int[] outputShape = modelInformation.getOutputShape(); + TSDataType[] inputTypes = modelInferenceDescriptor.getModelInformation().getInputDataType(); + TSDataType[] outputTypes = modelInferenceDescriptor.getModelInformation().getOutputDataType(); + + long maxRetainedSize = + calculateSize(inputShape[0], inputTypes) + TimeColumn.SIZE_IN_BYTES_PER_POSITION; + long maxReturnSize = calculateSize(outputShape[0], outputTypes); + + return new org.apache.iotdb.db.queryengine.execution.operator.process.AI.InferenceOperator( + operatorContext, + child, + modelInferenceDescriptor, + FragmentInstanceManager.getInstance().getModelInferenceExecutor(), + node.getInputColumnNames(), + node.getChild().getOutputColumnNames(), + maxRetainedSize, + maxReturnSize); + } + + private long calculateSize(long rowNumber, TSDataType[] dataTypes) { + long size = 0; + for (int i = 0; i < dataTypes.length; i++) { + size += getOutputColumnSizePerLine(dataTypes[i]); + } + return size * rowNumber; + } + @Override public Operator visitInto(IntoNode node, LocalExecutionPlanContext context) { Operator child = node.getChild().accept(this, context); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/PlanNodeType.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/PlanNodeType.java index 061f22f171d5..b3e1d212041a 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/PlanNodeType.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/PlanNodeType.java @@ -62,6 +62,7 @@ import org.apache.iotdb.db.queryengine.plan.planner.plan.node.pipe.PipeEnrichedNonWritePlanNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.pipe.PipeEnrichedWritePlanNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.pipe.PipeOperateSchemaQueueNode; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AI.InferenceNode; import 
org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.ActiveRegionScanMergeNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AggregationMergeSortNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AggregationNode; @@ -197,9 +198,7 @@ public enum PlanNodeType { LOGICAL_VIEW_SCHEMA_SCAN((short) 77), ALTER_LOGICAL_VIEW((short) 78), PIPE_ENRICHED_INSERT_DATA((short) 79), - - // NodeId 80 is used by IoTDB-ML which shouldn't be used. - + INFERENCE((short) 80), LAST_QUERY_TRANSFORM((short) 81), TOP_K((short) 82), COLUMN_INJECT((short) 83), @@ -437,6 +436,8 @@ public static PlanNode deserialize(ByteBuffer buffer, short nodeType) { return AlterLogicalViewNode.deserialize(buffer); case 79: return PipeEnrichedInsertNode.deserialize(buffer); + case 80: + return InferenceNode.deserialize(buffer); case 81: return LastQueryTransformNode.deserialize(buffer); case 82: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/PlanVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/PlanVisitor.java index 2654487da54c..92261d7ad479 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/PlanVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/PlanVisitor.java @@ -59,6 +59,7 @@ import org.apache.iotdb.db.queryengine.plan.planner.plan.node.pipe.PipeEnrichedNonWritePlanNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.pipe.PipeEnrichedWritePlanNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.pipe.PipeOperateSchemaQueueNode; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AI.InferenceNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.ActiveRegionScanMergeNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AggregationMergeSortNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AggregationNode; @@ -236,6 +237,10 @@ public R visitSingleDeviceView(SingleDeviceViewNode node, C context) { return visitSingleChildProcess(node, context); } + public R visitInference(InferenceNode node, C context) { + return visitSingleChildProcess(node, context); + } + public R visitExplainAnalyze(ExplainAnalyzeNode node, C context) { return visitSingleChildProcess(node, context); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java new file mode 100644 index 000000000000..95fe3437e788 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AI; + +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNodeId; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNodeType; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanVisitor; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.SingleChildProcessNode; +import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; + +import org.apache.tsfile.utils.ReadWriteIOUtils; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Objects; + +public class InferenceNode extends SingleChildProcessNode { + + private final ModelInferenceDescriptor modelInferenceDescriptor; + + // the column order in select item which reflects the real input order + private final List<String> targetColumnNames; + + public InferenceNode( + PlanNodeId id, + PlanNode child, + ModelInferenceDescriptor modelInferenceDescriptor, + List<String> targetColumnNames) { + super(id, child); + this.modelInferenceDescriptor = modelInferenceDescriptor; + this.targetColumnNames = targetColumnNames; + } + + public InferenceNode( + PlanNodeId id, + ModelInferenceDescriptor modelInferenceDescriptor, + List<String> inputColumnNames) { + super(id); + this.modelInferenceDescriptor = modelInferenceDescriptor; + this.targetColumnNames = inputColumnNames; + } + + public ModelInferenceDescriptor getModelInferenceDescriptor() { + return modelInferenceDescriptor; + } + + public List<String> getInputColumnNames() { + return targetColumnNames; + } + + @Override + public <R, C> R accept(PlanVisitor<R, C> visitor, C context) { + return visitor.visitInference(this, context); + } + + @Override + public PlanNode clone() { + return new InferenceNode(getPlanNodeId(), child, modelInferenceDescriptor, targetColumnNames); + } + + @Override + public List<String> getOutputColumnNames() { + return modelInferenceDescriptor.getOutputColumnNames(); + } + + @Override + protected void serializeAttributes(ByteBuffer byteBuffer) { + PlanNodeType.INFERENCE.serialize(byteBuffer); + modelInferenceDescriptor.serialize(byteBuffer); + ReadWriteIOUtils.writeStringList(targetColumnNames, byteBuffer); + } + + @Override + protected void serializeAttributes(DataOutputStream stream) throws IOException { + PlanNodeType.INFERENCE.serialize(stream); + modelInferenceDescriptor.serialize(stream); + ReadWriteIOUtils.writeStringList(targetColumnNames, stream); + } + + public static InferenceNode deserialize(ByteBuffer buffer) { + ModelInferenceDescriptor modelInferenceDescriptor = + ModelInferenceDescriptor.deserialize(buffer); + List<String> inputColumnNames = ReadWriteIOUtils.readStringList(buffer); + PlanNodeId planNodeId = PlanNodeId.deserialize(buffer); + return new InferenceNode(planNodeId, modelInferenceDescriptor, inputColumnNames); + } + + @Override + public String toString() { + return "InferenceNode-" + this.getPlanNodeId(); + } + + @Override + public boolean equals(Object o) { + if
(this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + InferenceNode that = (InferenceNode) o; + return modelInferenceDescriptor.equals(that.modelInferenceDescriptor) + && targetColumnNames.equals(that.targetColumnNames); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), modelInferenceDescriptor, targetColumnNames); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java new file mode 100644 index 000000000000..bf5f391d9e47 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model; + +import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.model.ModelInformation; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowParameter; + +import org.apache.tsfile.utils.ReadWriteIOUtils; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class ModelInferenceDescriptor { + + private final TEndPoint targetAINode; + private final ModelInformation modelInformation; + private List<String> outputColumnNames; + private InferenceWindowParameter inferenceWindowParameter; + private Map<String, String> inferenceAttributes; + + public ModelInferenceDescriptor(TEndPoint targetAINode, ModelInformation modelInformation) { + this.targetAINode = targetAINode; + this.modelInformation = modelInformation; + } + + private ModelInferenceDescriptor(ByteBuffer buffer) { + this.targetAINode = + new TEndPoint(ReadWriteIOUtils.readString(buffer), ReadWriteIOUtils.readInt(buffer)); + this.modelInformation = ModelInformation.deserialize(buffer); + int outputColumnNamesSize = ReadWriteIOUtils.readInt(buffer); + if (outputColumnNamesSize == 0) { + this.outputColumnNames = null; + } else { + this.outputColumnNames = new ArrayList<>(); + for (int i = 0; i < outputColumnNamesSize; i++) { + this.outputColumnNames.add(ReadWriteIOUtils.readString(buffer)); + } + } + boolean hasInferenceWindowParameter = ReadWriteIOUtils.readBool(buffer); + if (hasInferenceWindowParameter) { + this.inferenceWindowParameter = InferenceWindowParameter.deserialize(buffer); + } else { +
this.inferenceWindowParameter = null; + } + int inferenceAttributesSize = ReadWriteIOUtils.readInt(buffer); + if (inferenceAttributesSize == 0) { + this.inferenceAttributes = null; + } else { + this.inferenceAttributes = new HashMap<>(); + for (int i = 0; i < inferenceAttributesSize; i++) { + this.inferenceAttributes.put( + ReadWriteIOUtils.readString(buffer), ReadWriteIOUtils.readString(buffer)); + } + } + } + + public void setInferenceAttributes(Map<String, String> inferenceAttributes) { + this.inferenceAttributes = inferenceAttributes; + } + + public Map<String, String> getInferenceAttributes() { + return inferenceAttributes; + } + + public void setInferenceWindowParameter(InferenceWindowParameter inferenceWindowParameter) { + this.inferenceWindowParameter = inferenceWindowParameter; + } + + public InferenceWindowParameter getInferenceWindowParameter() { + return inferenceWindowParameter; + } + + public ModelInformation getModelInformation() { + return modelInformation; + } + + public TEndPoint getTargetAINode() { + return targetAINode; + } + + public String getModelName() { + return modelInformation.getModelName(); + } + + public void setOutputColumnNames(List<String> outputColumnNames) { + this.outputColumnNames = outputColumnNames; + } + + public List<String> getOutputColumnNames() { + return outputColumnNames; + } + + public void serialize(ByteBuffer byteBuffer) { + ReadWriteIOUtils.write(targetAINode.ip, byteBuffer); + ReadWriteIOUtils.write(targetAINode.port, byteBuffer); + modelInformation.serialize(byteBuffer); + if (outputColumnNames == null) { + ReadWriteIOUtils.write(0, byteBuffer); + } else { + ReadWriteIOUtils.write(outputColumnNames.size(), byteBuffer); + for (String outputColumnName : outputColumnNames) { + ReadWriteIOUtils.write(outputColumnName, byteBuffer); + } + } + if (inferenceWindowParameter == null) { + ReadWriteIOUtils.write(false, byteBuffer); + } else { + ReadWriteIOUtils.write(true, byteBuffer); + inferenceWindowParameter.serialize(byteBuffer); + } + if (inferenceAttributes == null) { + ReadWriteIOUtils.write(0, byteBuffer); + } else { + ReadWriteIOUtils.write(inferenceAttributes.size(), byteBuffer); + for (Map.Entry<String, String> entry : inferenceAttributes.entrySet()) { + ReadWriteIOUtils.write(entry.getKey(), byteBuffer); + ReadWriteIOUtils.write(entry.getValue(), byteBuffer); + } + } + } + + public void serialize(DataOutputStream stream) throws IOException { + ReadWriteIOUtils.write(targetAINode.ip, stream); + ReadWriteIOUtils.write(targetAINode.port, stream); + modelInformation.serialize(stream); + if (outputColumnNames == null) { + ReadWriteIOUtils.write(0, stream); + } else { + ReadWriteIOUtils.write(outputColumnNames.size(), stream); + for (String outputColumnName : outputColumnNames) { + ReadWriteIOUtils.write(outputColumnName, stream); + } + } + if (inferenceWindowParameter == null) { + ReadWriteIOUtils.write(false, stream); + } else { + ReadWriteIOUtils.write(true, stream); + inferenceWindowParameter.serialize(stream); + } + if (inferenceAttributes == null) { + ReadWriteIOUtils.write(0, stream); + } else { + ReadWriteIOUtils.write(inferenceAttributes.size(), stream); + for (Map.Entry<String, String> entry : inferenceAttributes.entrySet()) { + ReadWriteIOUtils.write(entry.getKey(), stream); + ReadWriteIOUtils.write(entry.getValue(), stream); + } + } + } + + public static ModelInferenceDescriptor deserialize(ByteBuffer buffer) { + return new ModelInferenceDescriptor(buffer); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return
false; + } + ModelInferenceDescriptor that = (ModelInferenceDescriptor) o; + return targetAINode.equals(that.targetAINode) + && modelInformation.equals(that.modelInformation) + && outputColumnNames.equals(that.outputColumnNames) + && inferenceWindowParameter.equals(that.inferenceWindowParameter) + && inferenceAttributes.equals(that.inferenceAttributes); + } + + @Override + public int hashCode() { + return Objects.hash( + targetAINode, + modelInformation, + outputColumnNames, + inferenceWindowParameter, + inferenceAttributes); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/StatementVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/StatementVisitor.java index e7b7e3ffa663..d15677de60b7 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/StatementVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/StatementVisitor.java @@ -74,6 +74,10 @@ import org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowTriggersStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowVariablesStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.UnSetTTLStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.CreateModelStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.DropModelStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowAINodesStatement; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.model.ShowModelsStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.AlterPipeStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.CreatePipePluginStatement; import org.apache.iotdb.db.queryengine.plan.statement.metadata.pipe.CreatePipeStatement; @@ -283,6 +287,19 @@ public R visitAlterLogicalView(AlterLogicalViewStatement alterLogicalViewStateme return visitStatement(alterLogicalViewStatement, context); } + // AI Model + public R visitCreateModel(CreateModelStatement createModelStatement, C context) { + return visitStatement(createModelStatement, context); + } + + public R visitDropModel(DropModelStatement dropModelStatement, C context) { + return visitStatement(dropModelStatement, context); + } + + public R visitShowModels(ShowModelsStatement showModelsModelStatement, C context) { + return visitStatement(showModelsModelStatement, context); + } + /** Data Manipulation Language (DML) */ // Select Statement @@ -445,6 +462,10 @@ public R visitShowConfigNodes(ShowConfigNodesStatement showConfigNodesStatement, return visitStatement(showConfigNodesStatement, context); } + public R visitShowAINodes(ShowAINodesStatement showAINodesStatement, C context) { + return visitStatement(showAINodesStatement, context); + } + public R visitShowVersion(ShowVersionStatement showVersionStatement, C context) { return visitStatement(showVersionStatement, context); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/QueryStatement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/QueryStatement.java index e09bdc004150..2ff3e8149c89 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/QueryStatement.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/QueryStatement.java @@ -26,6 +26,7 
@@ import org.apache.iotdb.db.auth.AuthorityChecker; import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.queryengine.execution.operator.window.WindowType; +import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindow; import org.apache.iotdb.db.queryengine.plan.analyze.ExpressionAnalyzer; import org.apache.iotdb.db.queryengine.plan.expression.Expression; import org.apache.iotdb.db.queryengine.plan.expression.leaf.TimeSeriesOperand; @@ -53,8 +54,10 @@ import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import static org.apache.iotdb.db.utils.constant.SqlConstant.COUNT_TIME; @@ -133,6 +136,57 @@ public class QueryStatement extends AuthorityInformationStatement { // we can skip the query private boolean isResultSetEmpty = false; + // [IoTDB-AI] used for model inference, which will be removed in the future + private String modelName; + private boolean hasModelInference = false; + private InferenceWindow inferenceWindow = null; + private Map inferenceAttribute = null; + + public void setModelName(String modelName) { + this.modelName = modelName; + } + + public String getModelName() { + return modelName; + } + + public void setHasModelInference(boolean hasModelInference) { + this.hasModelInference = hasModelInference; + } + + public boolean hasModelInference() { + return hasModelInference; + } + + public void setInferenceWindow(InferenceWindow inferenceWindow) { + this.inferenceWindow = inferenceWindow; + } + + public boolean isSetInferenceWindow() { + return this.inferenceWindow != null; + } + + public InferenceWindow getInferenceWindow() { + return inferenceWindow; + } + + public void addInferenceAttribute(String key, String value) { + if (inferenceAttribute == null) { + inferenceAttribute = new HashMap<>(); + } + inferenceAttribute.put(key, value); + } + + public Map getInferenceAttributes() { + return inferenceAttribute; + } + + public boolean hasInferenceAttributes() { + return inferenceAttribute != null; + } + + // [IoTDB-AI] END + public QueryStatement() { this.statementType = StatementType.QUERY; } @@ -546,6 +600,16 @@ public void setLastLevelUseWildcard(boolean lastLevelUseWildcard) { @SuppressWarnings({"squid:S3776", "squid:S6541"}) // Suppress high Cognitive Complexity warning public void semanticCheck() { + + if (hasModelInference) { + if (isAlignByDevice()) { + throw new SemanticException("Model inference does not support align by device now."); + } + if (isSelectInto()) { + throw new SemanticException("Model inference does not support select into now."); + } + } + if (isAggregationQuery()) { if (groupByComponent != null && isGroupByLevel()) { throw new SemanticException("GROUP BY CLAUSES doesn't support GROUP BY LEVEL now."); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/CreateModelStatement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/CreateModelStatement.java new file mode 100644 index 000000000000..2f43a1cfe762 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/CreateModelStatement.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.statement.metadata.model; + +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.auth.entity.PrivilegeType; +import org.apache.iotdb.commons.path.PartialPath; +import org.apache.iotdb.db.auth.AuthorityChecker; +import org.apache.iotdb.db.queryengine.plan.analyze.QueryType; +import org.apache.iotdb.db.queryengine.plan.statement.IConfigStatement; +import org.apache.iotdb.db.queryengine.plan.statement.Statement; +import org.apache.iotdb.db.queryengine.plan.statement.StatementVisitor; +import org.apache.iotdb.rpc.TSStatusCode; + +import java.util.Collections; +import java.util.List; + +public class CreateModelStatement extends Statement implements IConfigStatement { + + private String modelName; + + private String uri; + + public CreateModelStatement() { + // do nothing + } + + public String getUri() { + return uri; + } + + public String getModelName() { + return modelName; + } + + public void setUri(String uri) { + this.uri = uri; + } + + public void setModelName(String modelName) { + this.modelName = modelName; + } + + @Override + public List getPaths() { + return Collections.emptyList(); + } + + @Override + public QueryType getQueryType() { + return QueryType.WRITE; + } + + @Override + public TSStatus checkPermissionBeforeProcess(String userName) { + if (AuthorityChecker.SUPER_USER.equals(userName)) { + return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); + } + return AuthorityChecker.getTSStatus( + AuthorityChecker.checkSystemPermission(userName, PrivilegeType.USE_MODEL.ordinal()), + PrivilegeType.USE_MODEL); + } + + @Override + public R accept(StatementVisitor visitor, C context) { + return visitor.visitCreateModel(this, context); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/DropModelStatement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/DropModelStatement.java new file mode 100644 index 000000000000..3e207608241d --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/DropModelStatement.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.statement.metadata.model; + +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.auth.entity.PrivilegeType; +import org.apache.iotdb.commons.path.PartialPath; +import org.apache.iotdb.db.auth.AuthorityChecker; +import org.apache.iotdb.db.queryengine.plan.analyze.QueryType; +import org.apache.iotdb.db.queryengine.plan.statement.IConfigStatement; +import org.apache.iotdb.db.queryengine.plan.statement.Statement; +import org.apache.iotdb.db.queryengine.plan.statement.StatementVisitor; +import org.apache.iotdb.rpc.TSStatusCode; + +import java.util.Collections; +import java.util.List; + +public class DropModelStatement extends Statement implements IConfigStatement { + + private final String modelName; + + public DropModelStatement(String modelName) { + this.modelName = modelName; + } + + public String getModelName() { + return modelName; + } + + @Override + public List getPaths() { + return Collections.emptyList(); + } + + @Override + public QueryType getQueryType() { + return QueryType.WRITE; + } + + @Override + public TSStatus checkPermissionBeforeProcess(String userName) { + if (AuthorityChecker.SUPER_USER.equals(userName)) { + return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); + } + return AuthorityChecker.getTSStatus( + AuthorityChecker.checkSystemPermission(userName, PrivilegeType.USE_MODEL.ordinal()), + PrivilegeType.USE_MODEL); + } + + @Override + public R accept(StatementVisitor visitor, C context) { + return visitor.visitDropModel(this, context); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/ShowAINodesStatement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/ShowAINodesStatement.java new file mode 100644 index 000000000000..602d0e01465b --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/ShowAINodesStatement.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.plan.statement.metadata.model; + +import org.apache.iotdb.db.queryengine.plan.analyze.QueryType; +import org.apache.iotdb.db.queryengine.plan.statement.IConfigStatement; +import org.apache.iotdb.db.queryengine.plan.statement.StatementVisitor; +import org.apache.iotdb.db.queryengine.plan.statement.metadata.ShowStatement; + +public class ShowAINodesStatement extends ShowStatement implements IConfigStatement { + + public ShowAINodesStatement() {} + + @Override + public QueryType getQueryType() { + return QueryType.READ; + } + + @Override + public R accept(StatementVisitor visitor, C context) { + return visitor.visitShowAINodes(this, context); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/ShowModelsStatement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/ShowModelsStatement.java new file mode 100644 index 000000000000..0b810b49ad82 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/metadata/model/ShowModelsStatement.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.plan.statement.metadata.model; + +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.auth.entity.PrivilegeType; +import org.apache.iotdb.commons.path.PartialPath; +import org.apache.iotdb.db.auth.AuthorityChecker; +import org.apache.iotdb.db.queryengine.plan.analyze.QueryType; +import org.apache.iotdb.db.queryengine.plan.statement.IConfigStatement; +import org.apache.iotdb.db.queryengine.plan.statement.Statement; +import org.apache.iotdb.db.queryengine.plan.statement.StatementVisitor; +import org.apache.iotdb.rpc.TSStatusCode; + +import java.util.Collections; +import java.util.List; + +public class ShowModelsStatement extends Statement implements IConfigStatement { + + private String modelName; + + public ShowModelsStatement() { + // do nothing + } + + public void setModelName(String modelName) { + this.modelName = modelName; + } + + public boolean isSetModelName() { + return modelName != null; + } + + public String getModelName() { + return modelName; + } + + @Override + public List getPaths() { + return Collections.emptyList(); + } + + @Override + public QueryType getQueryType() { + return QueryType.READ; + } + + @Override + public TSStatus checkPermissionBeforeProcess(String userName) { + if (AuthorityChecker.SUPER_USER.equals(userName)) { + return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); + } + return AuthorityChecker.getTSStatus( + AuthorityChecker.checkSystemPermission(userName, PrivilegeType.USE_MODEL.ordinal()), + PrivilegeType.USE_MODEL); + } + + @Override + public R accept(StatementVisitor visitor, C context) { + return visitor.visitShowModels(this, context); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/AINodeRPCService.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/AINodeRPCService.java new file mode 100644 index 000000000000..5ec49e756602 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/AINodeRPCService.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.service; + +import org.apache.iotdb.commons.concurrent.ThreadName; +import org.apache.iotdb.commons.exception.runtime.RPCServiceException; +import org.apache.iotdb.commons.service.ServiceType; +import org.apache.iotdb.commons.service.ThriftService; +import org.apache.iotdb.commons.service.ThriftServiceThread; +import org.apache.iotdb.db.conf.IoTDBConfig; +import org.apache.iotdb.db.conf.IoTDBDescriptor; +import org.apache.iotdb.db.protocol.thrift.handler.AINodeRPCServiceThriftHandler; +import org.apache.iotdb.db.protocol.thrift.impl.AINodeRPCServiceImpl; +import org.apache.iotdb.mpp.rpc.thrift.IAINodeInternalRPCService; +import org.apache.iotdb.rpc.DeepCopyRpcTransportFactory; + +public class AINodeRPCService extends ThriftService implements AINodeRPCServiceMBean { + + private AINodeRPCServiceImpl impl; + + private AINodeRPCService() {} + + @Override + public ServiceType getID() { + return ServiceType.AINode_RPC_SERVICE; + } + + @Override + public void initTProcessor() { + impl = new AINodeRPCServiceImpl(); + initSyncedServiceImpl(null); + processor = new IAINodeInternalRPCService.Processor<>(impl); + } + + @Override + public void initThriftServiceThread() + throws IllegalAccessException, InstantiationException, ClassNotFoundException { + try { + IoTDBConfig config = IoTDBDescriptor.getInstance().getConfig(); + thriftServiceThread = + new ThriftServiceThread( + processor, + getID().getName(), + ThreadName.AINODE_RPC_SERVICE.getName(), + getBindIP(), + getBindPort(), + config.getRpcMaxConcurrentClientNum(), + config.getThriftServerAwaitTimeForStopService(), + new AINodeRPCServiceThriftHandler(impl), + config.isRpcThriftCompressionEnable(), + DeepCopyRpcTransportFactory.INSTANCE); + } catch (RPCServiceException e) { + throw new IllegalAccessException(e.getMessage()); + } + thriftServiceThread.setName(ThreadName.AINODE_RPC_SERVICE.getName()); + } + + @Override + public String getBindIP() { + return IoTDBDescriptor.getInstance().getConfig().getRpcAddress(); + } + + @Override + public int getBindPort() { + return IoTDBDescriptor.getInstance().getConfig().getAINodePort(); + } + + private static class AINodeRPCServiceHolder { + private static final AINodeRPCService INSTANCE = new AINodeRPCService(); + + private AINodeRPCServiceHolder() {} + } + + public static AINodeRPCService getInstance() { + return AINodeRPCServiceHolder.INSTANCE; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/AINodeRPCServiceMBean.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/AINodeRPCServiceMBean.java new file mode 100644 index 000000000000..f4f51c0caa27 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/AINodeRPCServiceMBean.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.service; + +public interface AINodeRPCServiceMBean {} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java index d1ea85fe4c71..89bd0df7c286 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java @@ -684,6 +684,10 @@ private void setUp() throws StartupException { private void setUpRPCService() throws StartupException { // Start InternalRPCService to indicate that the current DataNode can accept cluster scheduling registerManager.register(DataNodeInternalRPCService.getInstance()); + // Start AINodeRPCService to indicate that the current DataNode can accept requests from AINode + if (config.isEnableAINodeService()) { + registerManager.register(AINodeRPCService.getInstance()); + } // Notice: During the period between starting the internal RPC service // and starting the client RPC service , some requests may fail because diff --git a/iotdb-core/node-commons/pom.xml b/iotdb-core/node-commons/pom.xml index 56ab07bd99d7..262abd5e2039 100644 --- a/iotdb-core/node-commons/pom.xml +++ b/iotdb-core/node-commons/pom.xml @@ -182,6 +182,12 @@ org.apache.ratis ratis-common + <dependency> + <groupId>org.apache.iotdb</groupId> + <artifactId>iotdb-thrift-ainode</artifactId> + <version>1.3.3-SNAPSHOT</version> + <scope>compile</scope> + </dependency> diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java index ee550d09a74d..c59416d1eaa0 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java @@ -20,6 +20,8 @@ package org.apache.iotdb.commons.client; import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.ainode.AINodeClient; +import org.apache.iotdb.commons.client.ainode.AsyncAINodeServiceClient; import org.apache.iotdb.commons.client.async.AsyncConfigNodeInternalServiceClient; import org.apache.iotdb.commons.client.async.AsyncDataNodeExternalServiceClient; import org.apache.iotdb.commons.client.async.AsyncDataNodeInternalServiceClient; @@ -294,6 +296,56 @@ public KeyedObjectPool createClie } } + public static class AsyncAINodeHeartbeatServiceClientPoolFactory + implements IClientPoolFactory<TEndPoint, AsyncAINodeServiceClient> { + @Override + public KeyedObjectPool<TEndPoint, AsyncAINodeServiceClient> createClientPool( + ClientManager<TEndPoint, AsyncAINodeServiceClient> manager) { + GenericKeyedObjectPool<TEndPoint, AsyncAINodeServiceClient> clientPool = + new GenericKeyedObjectPool<>( + new AsyncAINodeServiceClient.Factory( + manager, + new ThriftClientProperty.Builder() + .setConnectionTimeoutMs(conf.getConnectionTimeoutInMS()) + .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnabled()) + .setSelectorNumOfAsyncClientManager(conf.getSelectorNumOfClientManager()) + .setPrintLogWhenEncounterException(false) + .build(), + ThreadName.ASYNC_DATANODE_HEARTBEAT_CLIENT_POOL.getName()), + new ClientPoolProperty.Builder() + .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) + .build() + .getConfig()); + ClientManagerMetrics.getInstance() + .registerClientManager(this.getClass().getSimpleName(), clientPool); + return clientPool; + } + } + + public static class AINodeClientPoolFactory + implements IClientPoolFactory<TEndPoint, AINodeClient> { + + @Override + public
KeyedObjectPool<TEndPoint, AINodeClient> createClientPool( + ClientManager<TEndPoint, AINodeClient> manager) { + GenericKeyedObjectPool<TEndPoint, AINodeClient> clientPool = + new GenericKeyedObjectPool<>( + new AINodeClient.Factory( + manager, + new ThriftClientProperty.Builder() + .setConnectionTimeoutMs(conf.getConnectionTimeoutInMS()) + .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnabled()) + .build()), + new ClientPoolProperty.Builder() + .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) + .build() + .getConfig()); + ClientManagerMetrics.getInstance() + .registerClientManager(this.getClass().getSimpleName(), clientPool); + return clientPool; + } + } + + public static class SyncPipeConsensusServiceClientPoolFactory implements IClientPoolFactory<TEndPoint, SyncPipeConsensusServiceClient> { diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java new file mode 100644 index 000000000000..cc93a8f32d53 --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.iotdb.commons.client.ainode; + +import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService; +import org.apache.iotdb.ainode.rpc.thrift.TConfigs; +import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq; +import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp; +import org.apache.iotdb.ainode.rpc.thrift.TWindowParams; +import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.client.ClientManager; +import org.apache.iotdb.commons.client.ThriftClient; +import org.apache.iotdb.commons.client.factory.ThriftClientFactory; +import org.apache.iotdb.commons.client.property.ThriftClientProperty; +import org.apache.iotdb.commons.exception.ainode.LoadModelException; +import org.apache.iotdb.commons.model.ModelInformation; +import org.apache.iotdb.rpc.TConfigurationConst; +import org.apache.iotdb.rpc.TSStatusCode; + +import org.apache.commons.pool2.PooledObject; +import org.apache.commons.pool2.impl.DefaultPooledObject; +import org.apache.thrift.TException; +import org.apache.thrift.transport.TSocket; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; +import org.apache.thrift.transport.layered.TFramedTransport; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.TsBlock; +import org.apache.tsfile.read.common.block.column.TsBlockSerde; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class AINodeClient implements AutoCloseable, ThriftClient { + + private static final Logger logger = LoggerFactory.getLogger(AINodeClient.class); + + private final TEndPoint endPoint; + + private TTransport transport; + + private final ThriftClientProperty property; + private IAINodeRPCService.Client client; + + public static final String MSG_CONNECTION_FAIL = + "Fail to connect to AINode. 
Please check status of AINode"; + + private final TsBlockSerde tsBlockSerde = new TsBlockSerde(); + + ClientManager clientManager; + + public AINodeClient( + ThriftClientProperty property, + TEndPoint endPoint, + ClientManager clientManager) + throws TException { + this.property = property; + this.clientManager = clientManager; + this.endPoint = endPoint; + init(); + } + + private void init() throws TException { + try { + transport = + new TFramedTransport.Factory() + .getTransport( + new TSocket( + TConfigurationConst.defaultTConfiguration, + endPoint.getIp(), + endPoint.getPort(), + property.getConnectionTimeoutMs())); + if (!transport.isOpen()) { + transport.open(); + } + } catch (TTransportException e) { + throw new TException(MSG_CONNECTION_FAIL); + } + client = new IAINodeRPCService.Client(property.getProtocolFactory().getProtocol(transport)); + } + + public TTransport getTransport() { + return transport; + } + + public ModelInformation registerModel(String modelName, String uri) throws LoadModelException { + try { + TRegisterModelReq req = new TRegisterModelReq(uri, modelName); + TRegisterModelResp resp = client.registerModel(req); + if (resp.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + throw new LoadModelException(resp.status.message, resp.status.getCode()); + } + return parseModelInformation(modelName, resp.getAttributes(), resp.getConfigs()); + } catch (TException e) { + throw new LoadModelException( + e.getMessage(), TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()); + } + } + + private ModelInformation parseModelInformation( + String modelName, String attributes, TConfigs configs) { + int[] inputShape = configs.getInput_shape().stream().mapToInt(Integer::intValue).toArray(); + int[] outputShape = configs.getOutput_shape().stream().mapToInt(Integer::intValue).toArray(); + + TSDataType[] inputType = new TSDataType[inputShape[1]]; + TSDataType[] outputType = new TSDataType[outputShape[1]]; + for (int i = 0; i < inputShape[1]; i++) { + inputType[i] = TSDataType.values()[configs.getInput_type().get(i)]; + } + for (int i = 0; i < outputShape[1]; i++) { + outputType[i] = TSDataType.values()[configs.getOutput_type().get(i)]; + } + + return new ModelInformation( + modelName, inputShape, outputShape, inputType, outputType, attributes); + } + + public TSStatus deleteModel(String modelId) throws TException { + try { + return client.deleteModel(new TDeleteModelReq(modelId)); + } catch (TException e) { + logger.warn( + "Failed to connect to AINode from ConfigNode when executing {}: {}", + Thread.currentThread().getStackTrace()[1].getMethodName(), + e.getMessage()); + throw new TException(MSG_CONNECTION_FAIL); + } + } + + public TInferenceResp inference( + String modelId, + List inputColumnNames, + List inputTypeList, + Map columnIndexMap, + TsBlock inputTsBlock, + Map inferenceAttributes, + TWindowParams windowParams) + throws TException { + try { + TInferenceReq inferenceReq = + new TInferenceReq( + modelId, + tsBlockSerde.serialize(inputTsBlock), + inputTypeList, + inputColumnNames, + columnIndexMap); + if (windowParams != null) { + inferenceReq.setWindowParams(windowParams); + } + if (inferenceAttributes != null) { + inferenceReq.setInferenceAttributes(inferenceAttributes); + } + return client.inference(inferenceReq); + } catch (IOException e) { + throw new TException("An exception occurred while serializing input tsblock", e); + } catch (TException e) { + logger.warn( + "Failed to connect to AINode from DataNode when executing {}: {}", + 
Thread.currentThread().getStackTrace()[1].getMethodName(), + e.getMessage()); + throw new TException(MSG_CONNECTION_FAIL); + } + } + + @Override + public void close() throws Exception { + Optional.ofNullable(transport).ifPresent(TTransport::close); + } + + @Override + public void invalidate() { + Optional.ofNullable(transport).ifPresent(TTransport::close); + } + + @Override + public void invalidateAll() { + clientManager.clear(endPoint); + } + + @Override + public boolean printLogWhenEncounterException() { + return property.isPrintLogWhenEncounterException(); + } + + public static class Factory extends ThriftClientFactory { + + public Factory( + ClientManager clientClientManager, + ThriftClientProperty thriftClientProperty) { + super(clientClientManager, thriftClientProperty); + } + + @Override + public void destroyObject(TEndPoint tEndPoint, PooledObject pooledObject) + throws Exception { + pooledObject.getObject().close(); + } + + @Override + public PooledObject makeObject(TEndPoint endPoint) throws Exception { + return new DefaultPooledObject<>( + new AINodeClient(thriftClientProperty, endPoint, clientManager)); + } + + @Override + public boolean validateObject(TEndPoint tEndPoint, PooledObject pooledObject) { + return Optional.ofNullable(pooledObject.getObject().getTransport()) + .map(TTransport::isOpen) + .orElse(false); + } + } +} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClientManager.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClientManager.java new file mode 100644 index 000000000000..3a06e478e7b5 --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClientManager.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.commons.client.ainode; + +import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.ClientPoolFactory; +import org.apache.iotdb.commons.client.IClientManager; + +public class AINodeClientManager { + private AINodeClientManager() { + // Empty constructor + } + + private static final class AINodeClientManagerHolder { + private static final IClientManager INSTANCE = + new IClientManager.Factory() + .createClientManager(new ClientPoolFactory.AINodeClientPoolFactory()); + } + + public static IClientManager getInstance() { + return AINodeClientManagerHolder.INSTANCE; + } +} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeInfo.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeInfo.java new file mode 100644 index 000000000000..d6f3a6552795 --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeInfo.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.commons.client.ainode; + +import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.conf.CommonDescriptor; + +public class AINodeInfo { + // currently, we only support one AINode + public static final TEndPoint endPoint = + CommonDescriptor.getInstance().getConfig().getTargetAINodeEndPoint(); +} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AsyncAINodeServiceClient.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AsyncAINodeServiceClient.java new file mode 100644 index 000000000000..3276923deb78 --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AsyncAINodeServiceClient.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.commons.client.ainode; + +import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService; +import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.ClientManager; +import org.apache.iotdb.commons.client.ThriftClient; +import org.apache.iotdb.commons.client.factory.AsyncThriftClientFactory; +import org.apache.iotdb.commons.client.property.ThriftClientProperty; +import org.apache.iotdb.rpc.TNonblockingSocketWrapper; + +import org.apache.commons.pool2.PooledObject; +import org.apache.commons.pool2.impl.DefaultPooledObject; +import org.apache.thrift.async.TAsyncClientManager; + +import java.io.IOException; + +public class AsyncAINodeServiceClient extends IAINodeRPCService.AsyncClient + implements ThriftClient { + + private final boolean printLogWhenEncounterException; + private final TEndPoint endPoint; + private final ClientManager clientManager; + + public AsyncAINodeServiceClient( + ThriftClientProperty property, + TEndPoint endPoint, + TAsyncClientManager tClientManager, + ClientManager clientManager) + throws IOException { + super( + property.getProtocolFactory(), + tClientManager, + TNonblockingSocketWrapper.wrap( + endPoint.getIp(), endPoint.getPort(), property.getConnectionTimeoutMs())); + setTimeout(property.getConnectionTimeoutMs()); + this.printLogWhenEncounterException = property.isPrintLogWhenEncounterException(); + this.endPoint = endPoint; + this.clientManager = clientManager; + } + + @Override + public void onComplete() { + super.onComplete(); + returnSelf(); + } + + @Override + public void onError(Exception e) { + super.onError(e); + ThriftClient.resolveException(e, this); + returnSelf(); + } + + @Override + public void invalidate() { + if (!hasError()) { + super.onError(new Exception("This client has been invalidated")); + } + } + + @Override + public void invalidateAll() { + clientManager.clear(endPoint); + } + + @Override + public boolean printLogWhenEncounterException() { + return printLogWhenEncounterException; + } + + private void returnSelf() { + clientManager.returnClient(endPoint, this); + } + + private void close() { + ___transport.close(); + ___currentMethod = null; + } + + public boolean isReady() { + try { + checkReady(); + return true; + } catch (Exception e) { + if (printLogWhenEncounterException) { + logger.error("Unexpected exception occurs in {} : {}", this, e.getMessage()); + } + return false; + } + } + + public static class Factory + extends AsyncThriftClientFactory { + + public Factory( + ClientManager clientManager, + ThriftClientProperty thriftClientProperty, + String threadName) { + super(clientManager, thriftClientProperty, threadName); + } + + @Override + public void destroyObject( + TEndPoint endPoint, PooledObject pooledObject) { + pooledObject.getObject().close(); + } + + @Override + public PooledObject makeObject(TEndPoint endPoint) throws Exception { + return new DefaultPooledObject<>( + new AsyncAINodeServiceClient( + thriftClientProperty, + endPoint, + tManagers[clientCnt.incrementAndGet() % tManagers.length], + clientManager)); + } + + @Override + public boolean validateObject( + TEndPoint endPoint, PooledObject pooledObject) { + return pooledObject.getObject().isReady(); + } + } +} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/cluster/NodeType.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/cluster/NodeType.java index ff3fbfdada5b..65db12fb47cf 100644 --- 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/cluster/NodeType.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/cluster/NodeType.java @@ -21,7 +21,8 @@ public enum NodeType { ConfigNode("ConfigNode"), - DataNode("DataNode"); + DataNode("DataNode"), + AINode("AINode"); private final String nodeType; diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java index e43ae6e667ef..e731f50628c5 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java @@ -177,6 +177,7 @@ public enum ThreadName { INFLUXDB_RPC_SERVICE("InfluxdbRPC-Service"), INFLUXDB_RPC_PROCESSOR("InfluxdbRPC-Processor"), STORAGE_ENGINE_CACHED_POOL("StorageEngine"), + AINODE_RPC_SERVICE("AINodeRpc-Service"), IOTDB_SHUTDOWN_HOOK("IoTDB-Shutdown-Hook"), UPGRADE_TASK("UpgradeThread"), REGION_MIGRATE("Region-Migrate-Pool"), @@ -363,6 +364,7 @@ public enum ThreadName { INFLUXDB_RPC_SERVICE, INFLUXDB_RPC_PROCESSOR, STORAGE_ENGINE_CACHED_POOL, + AINODE_RPC_SERVICE, IOTDB_SHUTDOWN_HOOK, UPGRADE_TASK, REGION_MIGRATE, diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java index e4c9aca13a07..d3dba643e26a 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java @@ -19,6 +19,7 @@ package org.apache.iotdb.commons.conf; +import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.commons.client.property.ClientPoolProperty.DefaultProperty; import org.apache.iotdb.commons.cluster.NodeStatus; import org.apache.iotdb.commons.enums.HandleSystemErrorStrategy; @@ -151,6 +152,9 @@ public class CommonConfig { /** Status of current system. */ private volatile NodeStatus status = NodeStatus.Running; + private NodeStatus lastStatus = NodeStatus.Unknown; + private String lastStatusReason = ""; + private volatile boolean isStopping = false; private volatile String statusReason = null; @@ -160,6 +164,9 @@ public class CommonConfig { /** Disk Monitor. */ private double diskSpaceWarningThreshold = 0.05; + /** Ip and port of target AI node. */ + private TEndPoint targetAINodeEndPoint = new TEndPoint("127.0.0.1", 10810); + /** Time partition origin in milliseconds. */ private long timePartitionOrigin = 0; @@ -171,6 +178,9 @@ public class CommonConfig { private boolean timestampPrecisionCheckEnabled = true; + /** The number of threads in the thread pool that execute model inference tasks. */ + private int modelInferenceExecutionThreadCount = 5; + /** * The name of the directory that stores the tsfiles temporarily hold or generated by the pipe * module. The directory is located in the data directory of IoTDB. 
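For reference, a hedged sketch of how a properties loader could populate the new targetAINodeEndPoint field added above; the property key name target_ainode_endpoint is an illustrative assumption, not taken from this patch.

```java
import java.util.Properties;

import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.commons.conf.CommonDescriptor;

public class AINodeEndPointLoaderSketch {

  // Parses an "ip:port" value and stores it in the shared CommonConfig.
  public static void loadTargetAINodeEndPoint(Properties properties) {
    String value = properties.getProperty("target_ainode_endpoint", "127.0.0.1:10810").trim();
    int separator = value.lastIndexOf(':');
    TEndPoint endPoint =
        new TEndPoint(value.substring(0, separator), Integer.parseInt(value.substring(separator + 1)));
    CommonDescriptor.getInstance().getConfig().setTargetAINodeEndPoint(endPoint);
  }
}
```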
@@ -553,6 +563,14 @@ public void setStatus(NodeStatus status) { this.status = status; } + public TEndPoint getTargetAINodeEndPoint() { + return targetAINodeEndPoint; + } + + public void setTargetAINodeEndPoint(TEndPoint targetAINodeEndPoint) { + this.targetAINodeEndPoint = targetAINodeEndPoint; + } + public int getTTimePartitionSlotTransmitLimit() { return TTimePartitionSlotTransmitLimit; } @@ -1297,6 +1315,14 @@ public void setDatabaseLimitThreshold(int databaseLimitThreshold) { this.databaseLimitThreshold = databaseLimitThreshold; } + public int getModelInferenceExecutionThreadCount() { + return modelInferenceExecutionThreadCount; + } + + public void setModelInferenceExecutionThreadCount(int modelInferenceExecutionThreadCount) { + this.modelInferenceExecutionThreadCount = modelInferenceExecutionThreadCount; + } + public long getDatanodeTokenTimeoutMS() { return datanodeTokenTimeoutMS; } diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/exception/ainode/LoadModelException.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/exception/ainode/LoadModelException.java new file mode 100644 index 000000000000..dd80901630d1 --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/exception/ainode/LoadModelException.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.commons.exception.ainode; + +import org.apache.iotdb.commons.exception.IoTDBException; + +public class LoadModelException extends IoTDBException { + public LoadModelException(String message, int errorCode) { + super(message, errorCode); + } +} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelHyperparameter.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelHyperparameter.java new file mode 100644 index 000000000000..5f06f2dcabd3 --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelHyperparameter.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
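The modelInferenceExecutionThreadCount knob added to CommonConfig above is presumably read once at start-up to size the DataNode-side inference executor. A minimal sketch, using a plain JDK executor; the pool and thread names are placeholders rather than the names used by the real DataNode code.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.iotdb.commons.conf.CommonDescriptor;

public class ModelInferenceExecutorSketch {

  public static ExecutorService create() {
    int threadCount =
        CommonDescriptor.getInstance().getConfig().getModelInferenceExecutionThreadCount();
    AtomicInteger counter = new AtomicInteger();
    return Executors.newFixedThreadPool(
        threadCount,
        runnable -> {
          Thread thread =
              new Thread(runnable, "ModelInferenceExecutor-" + counter.getAndIncrement());
          thread.setDaemon(true);
          return thread;
        });
  }
}
```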
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.commons.model; + +import org.apache.tsfile.utils.ReadWriteIOUtils; + +import java.io.DataOutputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class ModelHyperparameter { + + private final Map<String, String> keyValueMap; + + public ModelHyperparameter(Map<String, String> keyValueMap) { + this.keyValueMap = keyValueMap; + } + + public void update(Map<String, String> modelInfo) { + this.keyValueMap.putAll(modelInfo); + } + + @Override + public String toString() { + StringBuilder stringBuilder = new StringBuilder(); + for (Map.Entry<String, String> keyValuePair : keyValueMap.entrySet()) { + stringBuilder + .append(keyValuePair.getKey()) + .append('=') + .append(keyValuePair.getValue()) + .append('\n'); + } + return stringBuilder.toString(); + } + + public List<String> toStringList() { + List<String> resultList = new ArrayList<>(); + for (Map.Entry<String, String> keyValuePair : keyValueMap.entrySet()) { + resultList.add(keyValuePair.getKey() + "=" + keyValuePair.getValue()); + } + return resultList; + } + + public void serialize(DataOutputStream stream) throws IOException { + ReadWriteIOUtils.write(keyValueMap, stream); + } + + public void serialize(FileOutputStream stream) throws IOException { + ReadWriteIOUtils.write(keyValueMap, stream); + } + + public static ModelHyperparameter deserialize(ByteBuffer buffer) { + return new ModelHyperparameter(ReadWriteIOUtils.readMap(buffer)); + } + + public static ModelHyperparameter deserialize(InputStream stream) throws IOException { + return new ModelHyperparameter(ReadWriteIOUtils.readMap(stream)); + } +} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java new file mode 100644 index 000000000000..0582989e00c2 --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
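A usage sketch (not part of this patch) for the hyper-parameter container above: build it from a key/value map, write it through a DataOutputStream, and read it back from the resulting bytes; serialize and deserialize share the ReadWriteIOUtils map layout, so the round trip is symmetric.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;

import org.apache.iotdb.commons.model.ModelHyperparameter;

public class ModelHyperparameterRoundTripSketch {

  public static void main(String[] args) throws IOException {
    Map<String, String> params = new HashMap<>();
    params.put("learning_rate", "0.001");
    params.put("epochs", "50");
    ModelHyperparameter hyperparameter = new ModelHyperparameter(params);

    // Serialize into an in-memory byte array.
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    hyperparameter.serialize(new DataOutputStream(bytes));

    // Deserialize from a ByteBuffer view of the same bytes.
    ModelHyperparameter restored =
        ModelHyperparameter.deserialize(ByteBuffer.wrap(bytes.toByteArray()));
    System.out.print(restored); // prints one key=value pair per line
  }
}
```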
+ */ + +package org.apache.iotdb.commons.model; + +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.utils.PublicBAOS; +import org.apache.tsfile.utils.ReadWriteIOUtils; + +import java.io.DataOutputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Arrays; + +public class ModelInformation { + + ModelType modelType; + + private final String modelName; + + private final int[] inputShape; + + private final int[] outputShape; + + private TSDataType[] inputDataType; + + private TSDataType[] outputDataType; + + private ModelStatus status = ModelStatus.INACTIVE; + + String attribute = ""; + + public ModelInformation( + ModelType modelType, + String modelName, + int[] inputShape, + int[] outputShape, + TSDataType[] inputDataType, + TSDataType[] outputDataType, + String attribute, + ModelStatus status) { + this.modelType = modelType; + this.modelName = modelName; + this.inputShape = inputShape; + this.outputShape = outputShape; + this.inputDataType = inputDataType; + this.outputDataType = outputDataType; + this.attribute = attribute; + this.status = status; + } + + public ModelInformation( + String modelName, + int[] inputShape, + int[] outputShape, + TSDataType[] inputDataType, + TSDataType[] outputDataType, + String attribute) { + this.modelType = ModelType.USER_DEFINED; + this.modelName = modelName; + this.inputShape = inputShape; + this.outputShape = outputShape; + this.inputDataType = inputDataType; + this.outputDataType = outputDataType; + this.attribute = attribute; + } + + public ModelInformation(String modelName, ModelStatus status) { + this.modelType = ModelType.USER_DEFINED; + this.modelName = modelName; + this.inputShape = new int[0]; + this.outputShape = new int[0]; + this.outputDataType = new TSDataType[0]; + this.inputDataType = new TSDataType[0]; + this.status = status; + } + + // init built-in modelInformation + public ModelInformation(ModelType modelType, String modelName) { + this.modelType = modelType; + this.modelName = modelName; + this.inputShape = new int[2]; + this.outputShape = new int[2]; + this.inputDataType = new TSDataType[0]; + this.outputDataType = new TSDataType[0]; + this.status = ModelStatus.ACTIVE; + } + + public boolean isBuiltIn() { + return modelType != ModelType.USER_DEFINED; + } + + public boolean available() { + return status == ModelStatus.ACTIVE; + } + + public void updateStatus(ModelStatus status) { + this.status = status; + } + + public String getModelName() { + return modelName; + } + + // calculation modelType and outputColumn metadata for different built-in models + public void setInputColumnSize(int size) { + inputShape[1] = size; + if (modelType == ModelType.BUILT_IN_FORECAST) { + outputShape[1] = size; + } else if (modelType == ModelType.BUILT_IN_ANOMALY_DETECTION) { + outputShape[1] = 1; + } + if (modelType == ModelType.BUILT_IN_FORECAST) { + buildOutputDataTypeForBuiltInModel(TSDataType.DOUBLE, outputShape[1]); + } else if (modelType == ModelType.BUILT_IN_ANOMALY_DETECTION) { + buildOutputDataTypeForBuiltInModel(TSDataType.INT32, outputShape[1]); + } + } + + public void setInputDataType(TSDataType[] inputDataType) { + this.inputDataType = inputDataType; + } + + private void buildOutputDataTypeForBuiltInModel(TSDataType tsDataType, int num) { + outputDataType = new TSDataType[num]; + for (int i = 0; i < num; i++) { + outputDataType[i] = tsDataType; + } + } + + public int[] getInputShape() { + return inputShape; + } + + public int[] 
getOutputShape() { + return outputShape; + } + + public TSDataType[] getInputDataType() { + return inputDataType; + } + + public TSDataType[] getOutputDataType() { + return outputDataType; + } + + public ModelStatus getStatus() { + return status; + } + + public String getAttribute() { + return attribute; + } + + public void setAttribute(String attribute) { + this.attribute = attribute; + } + + public void serialize(DataOutputStream stream) throws IOException { + ReadWriteIOUtils.write(modelType.ordinal(), stream); + ReadWriteIOUtils.write(status.ordinal(), stream); + ReadWriteIOUtils.write(modelName, stream); + if (status == ModelStatus.UNAVAILABLE) { + return; + } + + for (int shape : inputShape) { + ReadWriteIOUtils.write(shape, stream); + } + for (int shape : outputShape) { + ReadWriteIOUtils.write(shape, stream); + } + + for (TSDataType type : inputDataType) { + ReadWriteIOUtils.write(type.serialize(), stream); + } + for (TSDataType type : outputDataType) { + ReadWriteIOUtils.write(type.serialize(), stream); + } + + ReadWriteIOUtils.write(attribute, stream); + } + + public void serialize(FileOutputStream stream) throws IOException { + ReadWriteIOUtils.write(modelType.ordinal(), stream); + ReadWriteIOUtils.write(status.ordinal(), stream); + ReadWriteIOUtils.write(modelName, stream); + if (status == ModelStatus.UNAVAILABLE) { + return; + } + + for (int shape : inputShape) { + ReadWriteIOUtils.write(shape, stream); + } + for (int shape : outputShape) { + ReadWriteIOUtils.write(shape, stream); + } + + for (TSDataType type : inputDataType) { + ReadWriteIOUtils.write(type.serialize(), stream); + } + for (TSDataType type : outputDataType) { + ReadWriteIOUtils.write(type.serialize(), stream); + } + + ReadWriteIOUtils.write(attribute, stream); + } + + public void serialize(ByteBuffer byteBuffer) { + ReadWriteIOUtils.write(modelType.ordinal(), byteBuffer); + ReadWriteIOUtils.write(status.ordinal(), byteBuffer); + ReadWriteIOUtils.write(modelName, byteBuffer); + if (status == ModelStatus.UNAVAILABLE) { + return; + } + + for (int shape : inputShape) { + ReadWriteIOUtils.write(shape, byteBuffer); + } + for (int shape : outputShape) { + ReadWriteIOUtils.write(shape, byteBuffer); + } + + for (TSDataType type : inputDataType) { + ReadWriteIOUtils.write(type.serialize(), byteBuffer); + } + for (TSDataType type : outputDataType) { + ReadWriteIOUtils.write(type.serialize(), byteBuffer); + } + + ReadWriteIOUtils.write(attribute, byteBuffer); + } + + public static ModelInformation deserialize(ByteBuffer buffer) { + ModelType modelType = ModelType.values()[ReadWriteIOUtils.readInt(buffer)]; + ModelStatus status = ModelStatus.values()[ReadWriteIOUtils.readInt(buffer)]; + String modelName = ReadWriteIOUtils.readString(buffer); + if (status == ModelStatus.UNAVAILABLE) { + return new ModelInformation(modelName, status); + } + + int[] inputShape = new int[2]; + for (int i = 0; i < inputShape.length; i++) { + inputShape[i] = ReadWriteIOUtils.readInt(buffer); + } + + int[] outputShape = new int[2]; + for (int i = 0; i < outputShape.length; i++) { + outputShape[i] = ReadWriteIOUtils.readInt(buffer); + } + + TSDataType[] inputDataType = new TSDataType[inputShape[1]]; + for (int i = 0; i < inputDataType.length; i++) { + inputDataType[i] = TSDataType.deserializeFrom(buffer); + } + + TSDataType[] outputDataType = new TSDataType[outputShape[1]]; + for (int i = 0; i < outputDataType.length; i++) { + outputDataType[i] = TSDataType.deserializeFrom(buffer); + } + + String attribute = ReadWriteIOUtils.readString(buffer); + + 
return new ModelInformation( + modelType, + modelName, + inputShape, + outputShape, + inputDataType, + outputDataType, + attribute, + status); + } + + public static ModelInformation deserialize(InputStream stream) throws IOException { + ModelType modelType = ModelType.values()[ReadWriteIOUtils.readInt(stream)]; + ModelStatus status = ModelStatus.values()[ReadWriteIOUtils.readInt(stream)]; + String modelName = ReadWriteIOUtils.readString(stream); + if (status == ModelStatus.UNAVAILABLE) { + return new ModelInformation(modelName, status); + } + + int[] inputShape = new int[2]; + for (int i = 0; i < inputShape.length; i++) { + inputShape[i] = ReadWriteIOUtils.readInt(stream); + } + + int[] outputShape = new int[2]; + for (int i = 0; i < outputShape.length; i++) { + outputShape[i] = ReadWriteIOUtils.readInt(stream); + } + + TSDataType[] inputDataType = new TSDataType[inputShape[1]]; + for (int i = 0; i < inputDataType.length; i++) { + inputDataType[i] = TSDataType.deserializeFrom(stream); + } + + TSDataType[] outputDataType = new TSDataType[outputShape[1]]; + for (int i = 0; i < outputDataType.length; i++) { + outputDataType[i] = TSDataType.deserializeFrom(stream); + } + + String attribute = ReadWriteIOUtils.readString(stream); + return new ModelInformation( + modelType, + modelName, + inputShape, + outputShape, + inputDataType, + outputDataType, + attribute, + status); + } + + public ByteBuffer serializeShowModelResult() throws IOException { + PublicBAOS buffer = new PublicBAOS(); + DataOutputStream stream = new DataOutputStream(buffer); + ReadWriteIOUtils.write(modelName, stream); + ReadWriteIOUtils.write(modelType.toString(), stream); + ReadWriteIOUtils.write(status.toString(), stream); + ReadWriteIOUtils.write(Arrays.toString(inputShape), stream); + ReadWriteIOUtils.write(Arrays.toString(outputShape), stream); + ReadWriteIOUtils.write(Arrays.toString(inputDataType), stream); + ReadWriteIOUtils.write(Arrays.toString(outputDataType), stream); + ReadWriteIOUtils.write(attribute, stream); + // add extra blank line to make the result more readable in cli + ReadWriteIOUtils.write(" ", stream); + return ByteBuffer.wrap(buffer.getBuf(), 0, buffer.size()); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof ModelInformation) { + ModelInformation other = (ModelInformation) obj; + return modelName.equals(other.modelName) + && modelType.equals(other.modelType) + && Arrays.equals(inputShape, other.inputShape) + && Arrays.equals(outputShape, other.outputShape) + && Arrays.equals(inputDataType, other.inputDataType) + && Arrays.equals(outputDataType, other.outputDataType) + && status.equals(other.status) + && attribute.equals(other.attribute); + } + return false; + } +} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelStatus.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelStatus.java new file mode 100644 index 000000000000..7aac33dac23a --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelStatus.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
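A round-trip sketch (not part of this patch) for the ModelInformation class above. Note that deserialization sizes the data-type arrays from the second shape dimension, so inputDataType and outputDataType must contain inputShape[1] and outputShape[1] elements respectively; the shapes and attribute string below are illustrative only.

```java
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;

import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.utils.PublicBAOS;

public class ModelInformationRoundTripSketch {

  public static void main(String[] args) throws IOException {
    TSDataType[] columnTypes = new TSDataType[] {TSDataType.DOUBLE, TSDataType.DOUBLE};
    ModelInformation original =
        new ModelInformation(
            "demo_forecaster",
            new int[] {96, 2}, // 96 input rows, 2 input columns
            new int[] {96, 2}, // 96 output rows, 2 output columns
            columnTypes,
            columnTypes,
            "window_interval=96,window_step=96");

    PublicBAOS buffer = new PublicBAOS();
    original.serialize(new DataOutputStream(buffer));

    ModelInformation restored =
        ModelInformation.deserialize(ByteBuffer.wrap(buffer.getBuf(), 0, buffer.size()));
    System.out.println(original.equals(restored)); // true
    System.out.println(Arrays.toString(restored.getInputShape())); // [96, 2]
  }
}
```

Since the class overrides equals without a matching hashCode in the code shown here, instances compare fine directly but are best not used as hash-map keys.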
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.commons.model; + +public enum ModelStatus { + INACTIVE, + LOADING, + ACTIVE, + DROPPING, + UNAVAILABLE +} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java new file mode 100644 index 000000000000..64aff12f284e --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.commons.model; + +import org.apache.tsfile.utils.ReadWriteIOUtils; + +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class ModelTable { + + private final Map modelInfoMap; + + public ModelTable() { + this.modelInfoMap = new ConcurrentHashMap<>(); + } + + public boolean containsModel(String modelId) { + return modelInfoMap.containsKey(modelId); + } + + public void addModel(ModelInformation modelInformation) { + modelInfoMap.put(modelInformation.getModelName(), modelInformation); + } + + public void removeModel(String modelId) { + modelInfoMap.remove(modelId); + } + + public List getAllModelInformation() { + return new ArrayList<>(modelInfoMap.values()); + } + + public ModelInformation getModelInformationById(String modelId) { + if (modelInfoMap.containsKey(modelId)) { + return modelInfoMap.get(modelId); + } + return null; + } + + public void clearFailedModel() { + for (ModelInformation modelInformation : modelInfoMap.values()) { + if (modelInformation.getStatus() == ModelStatus.UNAVAILABLE) { + modelInfoMap.remove(modelInformation.getModelName()); + } + } + } + + public void updateModel(String modelName, ModelInformation modelInfo) { + modelInfoMap.replace(modelName, modelInfo); + } + + public void clear() { + modelInfoMap.clear(); + } + + public void serialize(FileOutputStream stream) throws IOException { + ReadWriteIOUtils.write(modelInfoMap.size(), stream); + for (ModelInformation entry : modelInfoMap.values()) { + entry.serialize(stream); + } + } + + public static ModelTable deserialize(InputStream stream) throws IOException { + ModelTable modelTable = new ModelTable(); + 
int size = ReadWriteIOUtils.readInt(stream); + for (int i = 0; i < size; i++) { + modelTable.addModel(ModelInformation.deserialize(stream)); + } + return modelTable; + } +} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelType.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelType.java new file mode 100644 index 000000000000..cd154748f237 --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelType.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.commons.model; + +public enum ModelType { + BUILT_IN_FORECAST, + BUILT_IN_ANOMALY_DETECTION, + USER_DEFINED +} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/exception/ModelManagementException.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/exception/ModelManagementException.java new file mode 100644 index 000000000000..6f7ded0ed422 --- /dev/null +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/exception/ModelManagementException.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
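A brief usage sketch (not part of this patch) of the ModelTable registry above; the model names are illustrative only.

```java
import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.commons.model.ModelStatus;
import org.apache.iotdb.commons.model.ModelTable;
import org.apache.iotdb.commons.model.ModelType;

public class ModelTableUsageSketch {

  public static void main(String[] args) {
    ModelTable table = new ModelTable();

    // The built-in constructor registers the model as ACTIVE.
    table.addModel(new ModelInformation(ModelType.BUILT_IN_FORECAST, "holtwinters"));

    // A user-defined model whose registration failed ends up UNAVAILABLE.
    table.addModel(new ModelInformation("broken_model", ModelStatus.UNAVAILABLE));

    System.out.println(table.containsModel("holtwinters")); // true

    // clearFailedModel() drops every UNAVAILABLE entry.
    table.clearFailedModel();
    System.out.println(table.containsModel("broken_model")); // false
  }
}
```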
+ */ + +package org.apache.iotdb.commons.model.exception; + +public class ModelManagementException extends RuntimeException { + public ModelManagementException(String message) { + super(message); + } +} diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/service/ServiceType.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/service/ServiceType.java index 24a1a915325e..2898e14d987c 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/service/ServiceType.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/service/ServiceType.java @@ -79,6 +79,7 @@ public enum ServiceType { PIPE_CONSENSUS_SERVICE("PipeConsensus Service", "PipeConsensusRPCService"), PIPE_PLUGIN_CLASSLOADER_MANAGER_SERVICE( "Pipe Plugin Classloader Manager Service", "PipePluginClassLoader"), + AINode_RPC_SERVICE("Rpc Service for AINode", "AINodeRPCService"), PIPE_RUNTIME_DATA_NODE_AGENT("Pipe Runtime Data Node Agent", "PipeRuntimeDataNodeAgent"), PIPE_RUNTIME_CONFIG_NODE_AGENT("Pipe Runtime Config Node Agent", "PipeRuntimeConfigNodeAgent"), SUBSCRIPTION_RUNTIME_AGENT("Subscription Runtime Agent", "SubscriptionRuntimeAgent"), diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/utils/ThriftCommonsSerDeUtils.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/utils/ThriftCommonsSerDeUtils.java index 56f98364d3aa..5d336febbde6 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/utils/ThriftCommonsSerDeUtils.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/utils/ThriftCommonsSerDeUtils.java @@ -19,6 +19,8 @@ package org.apache.iotdb.commons.utils; +import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; +import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId; import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration; import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; @@ -286,4 +288,61 @@ private static ConfigurableTByteBuffer generateTByteBuffer(ByteBuffer buffer) throws TTransportException { return new ConfigurableTByteBuffer(buffer, defaultTConfiguration); } + + public static void serializeTAINodeInfo( + TAINodeConfiguration aiNodeInfo, DataOutputStream stream) { + try { + aiNodeInfo.write(generateWriteProtocol(stream)); + } catch (TException e) { + throw new ThriftSerDeException("Write TAINodeInfo failed: ", e); + } + } + + public static TAINodeConfiguration deserializeTAINodeInfo(ByteBuffer buffer) { + TAINodeConfiguration aiNodeInfo = new TAINodeConfiguration(); + try { + aiNodeInfo.read(generateReadProtocol(buffer)); + } catch (TException e) { + throw new ThriftSerDeException("Read TAINodeInfo failed: ", e); + } + return aiNodeInfo; + } + + public static void serializeTAINodeConfiguration( + TAINodeConfiguration aiNodeConfiguration, DataOutputStream stream) { + try { + aiNodeConfiguration.write(generateWriteProtocol(stream)); + } catch (TException e) { + throw new ThriftSerDeException("Write TDataNodeConfiguration failed: ", e); + } + } + + public static TAINodeConfiguration deserializeTAINodeConfiguration(ByteBuffer buffer) { + TAINodeConfiguration aiNodeConfiguration = new TAINodeConfiguration(); + try { + aiNodeConfiguration.read(generateReadProtocol(buffer)); + } catch (TException e) { + throw new ThriftSerDeException("Read TAINodeConfiguration failed: ", e); + } + return aiNodeConfiguration; + } + + public static void serializeTAINodeLocation( + 
TAINodeLocation aiNodeLocation, DataOutputStream stream) { + try { + aiNodeLocation.write(generateWriteProtocol(stream)); + } catch (TException e) { + throw new ThriftSerDeException("Write TAINodeLocation failed: ", e); + } + } + + public static TAINodeLocation deserializeTAINodeLocation(ByteBuffer buffer) { + TAINodeLocation aiNodeLocation = new TAINodeLocation(); + try { + aiNodeLocation.read(generateReadProtocol(buffer)); + } catch (TException e) { + throw new ThriftSerDeException("Read TDataNodeLocation failed: ", e); + } + return aiNodeLocation; + } } diff --git a/iotdb-core/pom.xml b/iotdb-core/pom.xml index ec1880559258..518be40f82fe 100644 --- a/iotdb-core/pom.xml +++ b/iotdb-core/pom.xml @@ -37,4 +37,12 @@ metrics node-commons + + + with-ainode + + ainode + + + diff --git a/iotdb-protocol/pom.xml b/iotdb-protocol/pom.xml index 286a2a26e805..fec72bac2923 100644 --- a/iotdb-protocol/pom.xml +++ b/iotdb-protocol/pom.xml @@ -35,5 +35,6 @@ thrift-commons thrift-confignode thrift-consensus + thrift-ainode diff --git a/iotdb-protocol/thrift-ainode/pom.xml b/iotdb-protocol/thrift-ainode/pom.xml new file mode 100644 index 000000000000..4d01348674a8 --- /dev/null +++ b/iotdb-protocol/thrift-ainode/pom.xml @@ -0,0 +1,69 @@ + + + + 4.0.0 + + org.apache.iotdb + iotdb-protocol + 1.3.3-SNAPSHOT + + iotdb-thrift-ainode + IoTDB: Protocol: Thrift AI Node + RPC (Thrift) framework among AINodes. + + + org.slf4j + slf4j-api + + + org.apache.thrift + libthrift + + + org.apache.iotdb + iotdb-thrift-commons + 1.3.3-SNAPSHOT + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-source + + add-source + + generate-sources + + + ${project.build.directory}/generated-sources/thrift + + + + + + + + diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift new file mode 100644 index 000000000000..b3e8a67b8cc9 --- /dev/null +++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
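The new TAINodeLocation/TAINodeConfiguration helpers in ThriftCommonsSerDeUtils can be exercised with a simple round trip. The sketch below is illustrative and assumes the Thrift-generated all-required-fields constructor and equals for TAINodeLocation.

```java
import java.io.DataOutputStream;
import java.nio.ByteBuffer;

import org.apache.iotdb.common.rpc.thrift.TAINodeLocation;
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.commons.utils.ThriftCommonsSerDeUtils;
import org.apache.tsfile.utils.PublicBAOS;

public class AINodeLocationSerDeSketch {

  public static void main(String[] args) {
    TAINodeLocation location = new TAINodeLocation(1, new TEndPoint("127.0.0.1", 10810));

    PublicBAOS buffer = new PublicBAOS();
    ThriftCommonsSerDeUtils.serializeTAINodeLocation(location, new DataOutputStream(buffer));

    TAINodeLocation restored =
        ThriftCommonsSerDeUtils.deserializeTAINodeLocation(
            ByteBuffer.wrap(buffer.getBuf(), 0, buffer.size()));
    System.out.println(location.equals(restored)); // true
  }
}
```

Note that, per the pom changes above, the ainode module only enters the iotdb-core build when the with-ainode Maven profile is activated (e.g. mvn clean package -P with-ainode), while thrift-ainode is always built as part of iotdb-protocol.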
+ */ + +include "common.thrift" +namespace java org.apache.iotdb.ainode.rpc.thrift +namespace py iotdb.thrift.ainode + +struct TDeleteModelReq { + 1: required string modelId +} + +struct TAIHeartbeatReq{ + 1: required i64 heartbeatTimestamp + 2: required bool needSamplingLoad +} + +struct TAIHeartbeatResp{ + 1: required i64 heartbeatTimestamp + 2: required string status + 3: optional string statusReason + 4: optional common.TLoadSample loadSample +} + +struct TRegisterModelReq { + 1: required string uri + 2: required string modelId +} + +struct TConfigs { + 1: required list input_shape + 2: required list output_shape + 3: required list input_type + 4: required list output_type +} + +struct TRegisterModelResp { + 1: required common.TSStatus status + 2: optional TConfigs configs + 3: optional string attributes +} + +struct TInferenceReq { + 1: required string modelId + 2: required binary dataset + 3: required list typeList + 4: required list columnNameList + 5: required map columnNameIndexMap + 6: optional TWindowParams windowParams + 7: optional map inferenceAttributes +} + +struct TWindowParams { + 1: required i32 windowInterval + 2: required i32 windowStep +} + +struct TInferenceResp { + 1: required common.TSStatus status + 2: required list inferenceResult +} + +service IAINodeRPCService { + + // -------------- For Config Node -------------- + + common.TSStatus deleteModel(TDeleteModelReq req) + + TRegisterModelResp registerModel(TRegisterModelReq req) + + TAIHeartbeatResp getAIHeartbeat(TAIHeartbeatReq req) + + // -------------- For Data Node -------------- + + TInferenceResp inference(TInferenceReq req) +} \ No newline at end of file diff --git a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift index 6df74bc73e9d..2b4b67cf5b91 100644 --- a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift +++ b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift @@ -85,11 +85,21 @@ struct TDataNodeLocation { 6: required TEndPoint schemaRegionConsensusEndPoint } +struct TAINodeLocation{ + 1: required i32 aiNodeId + 2: required TEndPoint internalEndPoint +} + struct TDataNodeConfiguration { 1: required TDataNodeLocation location 2: required TNodeResource resource } +struct TAINodeConfiguration{ + 1: required TAINodeLocation location + 2: required TNodeResource resource +} + enum TRegionMigrateFailedType { AddPeerFailed, RemovePeerFailed, @@ -196,6 +206,18 @@ struct TLicense { 9: required i16 mlNodeNumLimit } +struct TLoadSample { + // Percentage of occupied cpu in Node + 1: required double cpuUsageRate + // Percentage of occupied memory space in Node + 2: required double memoryUsageRate + // Percentage of occupied disk space in Node + 3: required double diskUsageRate + // The size of free disk space + // Unit: Byte + 4: required double freeDiskSpace +} + enum TServiceType { ConfigNodeInternalService, DataNodeInternalService, @@ -266,3 +288,11 @@ struct TShowConfigurationResp { 2: required string content } +// for AINode +enum TrainingState { + PENDING, + RUNNING, + FINISHED, + FAILED, + DROPPING +} diff --git a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift index 7044bd0b2213..65a859602468 100644 --- a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift +++ b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift @@ -513,8 +513,9 @@ struct TShowClusterResp { 1: required common.TSStatus status 2: 
required list configNodeList 3: required list dataNodeList - 4: required map nodeStatus - 5: required map nodeVersionInfo + 4: required list aiNodeList + 5: required map nodeStatus + 6: required map nodeVersionInfo } struct TGetClusterIdResp { @@ -547,11 +548,23 @@ struct TDataNodeInfo { 7: optional i32 cpuCoreNum } +struct TAINodeInfo{ + 1: required i32 aiNodeId + 2: required string status + 3: required string internalAddress + 4: required i32 internalPort +} + struct TShowDataNodesResp { 1: required common.TSStatus status 2: optional list dataNodesInfoList } +struct TShowAINodesResp { + 1: required common.TSStatus status + 2: optional list aiNodesInfoList +} + // Show confignodes struct TConfigNodeInfo { 1: required i32 configNodeId @@ -907,6 +920,34 @@ struct TUnsetSchemaTemplateReq { 4: optional bool isGeneratedByPipe } +struct TCreateModelReq { + 1: required string modelName + 2: required string uri +} + +struct TDropModelReq { + 1: required string modelId +} + +struct TShowModelReq { + 1: optional string modelId +} + +struct TShowModelResp { + 1: required common.TSStatus status + 2: required list modelInfoList +} + +struct TGetModelInfoReq { + 1: required string modelId +} + +struct TGetModelInfoResp { + 1: required common.TSStatus status + 2: optional binary modelInfo + 3: optional common.TEndPoint aiNodeAddress +} + // ==================================================== // Quota // ==================================================== @@ -937,6 +978,43 @@ enum TActivationControl { ALL_LICENSE_FILE_DELETED } +// ==================================================== +// AINode +// ==================================================== + +struct TAINodeConfigurationResp { + 1: required common.TSStatus status + 2: optional map aiNodeConfigurationMap +} + + +struct TAINodeRegisterReq{ + 1: required string clusterName + 2: required common.TAINodeConfiguration aiNodeConfiguration + 3: optional TNodeVersionInfo versionInfo +} + +struct TAINodeRegisterResp{ + 1: required common.TSStatus status + 2: required list configNodeList + 3: optional i32 aiNodeId +} + +struct TAINodeRestartReq{ + 1: required string clusterName + 2: required common.TAINodeConfiguration aiNodeConfiguration + 3: optional TNodeVersionInfo versionInfo +} + +struct TAINodeRestartResp{ + 1: required common.TSStatus status + 2: required list configNodeList +} + +struct TAINodeRemoveReq{ + 1: required common.TAINodeLocation aiNodeLocation +} + // ==================================================== // Test only // ==================================================== @@ -978,6 +1056,24 @@ service IConfigNodeRPCService { */ TDataNodeRestartResp restartDataNode(TDataNodeRestartReq req) + + // ====================================================== + // AINode + // ====================================================== + + /** + * node management for ainode, it's similar to datanode above + */ + TAINodeRegisterResp registerAINode(TAINodeRegisterReq req) + + TAINodeRestartResp restartAINode(TAINodeRestartReq req) + + common.TSStatus removeAINode(TAINodeRemoveReq req) + + TShowAINodesResp showAINodes() + + TAINodeConfigurationResp getAINodeConfiguration(i32 aiNodeId) + /** * Get system configurations. i.e. 
configurations that is not associated with the DataNodeId */ @@ -1594,6 +1690,34 @@ service IConfigNodeRPCService { */ TShowCQResp showCQ() + // ==================================================== + // AI Model + // ==================================================== + + /** + * Create a model + * + * @return SUCCESS_STATUS if the model was created successfully + */ + common.TSStatus createModel(TCreateModelReq req) + + /** + * Drop a model + * + * @return SUCCESS_STATUS if the model was removed successfully + */ + common.TSStatus dropModel(TDropModelReq req) + + /** + * Return the model table + */ + TShowModelResp showModel(TShowModelReq req) + + /** + * Return the model info by model_id + */ + TGetModelInfoResp getModelInfo(TGetModelInfoReq req) + // ====================================================== // Quota // ====================================================== diff --git a/iotdb-protocol/thrift-datanode/src/main/thrift/datanode.thrift b/iotdb-protocol/thrift-datanode/src/main/thrift/datanode.thrift index 90ee7818435a..76545f530a70 100644 --- a/iotdb-protocol/thrift-datanode/src/main/thrift/datanode.thrift +++ b/iotdb-protocol/thrift-datanode/src/main/thrift/datanode.thrift @@ -284,7 +284,7 @@ struct TDataNodeHeartbeatResp { 2: required string status 3: optional string statusReason 4: optional map judgedLeaders - 5: optional TLoadSample loadSample + 5: optional common.TLoadSample loadSample 6: optional map regionSeriesUsageMap 7: optional map regionDeviceUsageMap 8: optional map regionDisk @@ -315,18 +315,6 @@ enum TSchemaLimitLevel{ TIMESERIES } -struct TLoadSample { - // Percentage of occupied cpu in DataNode - 1: required double cpuUsageRate - // Percentage of occupied memory space in DataNode - 2: required double memoryUsageRate - // Percentage of occupied disk space in DataNode - 3: required double diskUsageRate - // The size of free disk space - // Unit: Byte - 4: required double freeDiskSpace -} - struct TRegionRouteReq { 1: required i64 timestamp 2: required map regionRouteMap @@ -543,6 +531,37 @@ struct TExecuteCQ { 7: required string username } +// ==================================================== +// AI Node +// ==================================================== + +struct TFetchMoreDataReq{ + 1: required i64 queryId + 2: optional i64 timeout + 3: optional i32 fetchSize +} + +struct TFetchMoreDataResp{ + 1: required common.TSStatus status + 2: optional list tsDataset + 3: optional bool hasMoreData +} + +struct TFetchTimeseriesReq { + 1: required string queryBody + 2: optional i32 fetchSize + 3: optional i64 timeout +} + +struct TFetchTimeseriesResp { + 1: required common.TSStatus status + 2: optional i64 queryId + 3: optional list columnNameList + 4: optional list columnTypeList + 5: optional map columnNameIndexMap + 6: optional list tsDataset + 7: optional bool hasMoreData +} /** * BEGIN: Used for EXPLAIN ANALYZE **/ @@ -1029,3 +1048,16 @@ service MPPDataExchangeService { /** Empty rpc, only for connection test */ common.TSStatus testConnectionEmptyRPC() } + +service IAINodeInternalRPCService{ + /** + * Fecth the data of the specified time series + */ + TFetchTimeseriesResp fetchTimeseries(TFetchTimeseriesReq req) + + /** + * Fetch rest data for a specified fetchTimeseries + */ + TFetchMoreDataResp fetchMoreData(TFetchMoreDataReq req) + +} diff --git a/pom.xml b/pom.xml index 52c706b98d7b..8e0df32992cc 100644 --- a/pom.xml +++ b/pom.xml @@ -1629,7 +1629,7 @@ generate-sources py - **/common.thrift,**/client.thrift,**/datanode.thrift,**/confignode.thrift + 
**/common.thrift,**/client.thrift,**/datanode.thrift,**/confignode.thrift,**/ainode.thrift ${project.build.directory}/generated-sources-python/
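Finally, a sketch of how an AINode-side (or any Thrift) client might page through the new fetchTimeseries/fetchMoreData RPCs on the DataNode using the hasMoreData flag. Everything here is an assumption for illustration: the generated Java package (org.apache.iotdb.mpp.rpc.thrift), the binary element type of tsDataset (the container element types were elided in this diff), and the address/port of the serving endpoint.

```java
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

// Assumed generated package for datanode.thrift.
import org.apache.iotdb.mpp.rpc.thrift.IAINodeInternalRPCService;
import org.apache.iotdb.mpp.rpc.thrift.TFetchMoreDataReq;
import org.apache.iotdb.mpp.rpc.thrift.TFetchMoreDataResp;
import org.apache.iotdb.mpp.rpc.thrift.TFetchTimeseriesReq;
import org.apache.iotdb.mpp.rpc.thrift.TFetchTimeseriesResp;

import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;

public class FetchTimeseriesPagingSketch {

  public static void main(String[] args) throws Exception {
    // Placeholder address/port of the DataNode endpoint serving IAINodeInternalRPCService.
    TTransport transport = new TSocket("127.0.0.1", 10730);
    transport.open();
    IAINodeInternalRPCService.Client client =
        new IAINodeInternalRPCService.Client(new TBinaryProtocol(transport));

    TFetchTimeseriesReq request = new TFetchTimeseriesReq("select s1 from root.sg.d1");
    request.setFetchSize(1024);
    TFetchTimeseriesResp first = client.fetchTimeseries(request);

    List<ByteBuffer> pages = new ArrayList<>(first.getTsDataset());
    boolean hasMore = first.isSetHasMoreData() && first.isHasMoreData();
    while (hasMore) {
      // Keep fetching with the queryId returned by the first call until the server is drained.
      TFetchMoreDataReq next = new TFetchMoreDataReq(first.getQueryId());
      next.setFetchSize(1024);
      TFetchMoreDataResp page = client.fetchMoreData(next);
      pages.addAll(page.getTsDataset());
      hasMore = page.isSetHasMoreData() && page.isHasMoreData();
    }
    transport.close();
    System.out.println("fetched " + pages.size() + " binary page(s)");
  }
}
```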