Skip to content

Commit

Permalink
branch-3.0: [fix](DECIMAL) error DECIMAL cat to BOOLEAN #44326 (#46276)
Browse files Browse the repository at this point in the history
Cherry-picked from #44326

Co-authored-by: Mryange <[email protected]>
  • Loading branch information
github-actions[bot] and Mryange authored Jan 14, 2025
1 parent 90c1e58 commit 0075a83
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 28 deletions.
19 changes: 12 additions & 7 deletions be/src/vec/data_types/data_type_decimal.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,15 +461,20 @@ void convert_from_decimals(RealTo* dst, const RealFrom* src, UInt32 precicion_fr
MaxFieldType multiplier = DataTypeDecimal<MaxFieldType>::get_scale_multiplier(scale_from);
FromDataType from_data_type(precicion_from, scale_from);
for (size_t i = 0; i < size; i++) {
auto tmp = static_cast<MaxFieldType>(src[i]).value / multiplier.value;
if constexpr (narrow_integral) {
if (tmp < min_result.value || tmp > max_result.value) {
THROW_DECIMAL_CONVERT_OVERFLOW_EXCEPTION(from_data_type.to_string(src[i]),
from_data_type.get_name(),
OrigToDataType {}.get_name());
// uint8_t now use as boolean in doris
if constexpr (std::is_same_v<RealTo, UInt8>) {
dst[i] = static_cast<MaxFieldType>(src[i]).value != 0;
} else {
auto tmp = static_cast<MaxFieldType>(src[i]).value / multiplier.value;
if constexpr (narrow_integral) {
if (tmp < min_result.value || tmp > max_result.value) {
THROW_DECIMAL_CONVERT_OVERFLOW_EXCEPTION(from_data_type.to_string(src[i]),
from_data_type.get_name(),
OrigToDataType {}.get_name());
}
}
dst[i] = tmp;
}
dst[i] = tmp;
}
}

Expand Down
43 changes: 28 additions & 15 deletions be/src/vec/functions/function_cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,21 @@ struct ConvertImpl {
using FromFieldType = typename FromDataType::FieldType;
using ToFieldType = typename ToDataType::FieldType;

// Stores `from` into `to`, converting it to the destination field type.
// Doris keeps boolean values in `uint8` storage, and a plain `static_cast`
// to `uint8` may leave a value other than 0/1 in that storage.
// When the destination is `uint8_t`, the source is therefore converted with
// `static_cast<bool>` first, so the stored value is always exactly 0 or 1.
// Every other destination type is handled by an ordinary `static_cast`.
static void static_cast_set(ToFieldType& to, const FromFieldType& from) {
    if constexpr (!std::is_same_v<uint8_t, ToFieldType>) {
        // Non-boolean destination: regular value conversion.
        to = static_cast<ToFieldType>(from);
    } else {
        // uint8_t is used as the boolean representation in Doris.
        to = static_cast<bool>(from);
    }
}

template <typename Additions = void*>
static Status execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count,
Expand Down Expand Up @@ -375,8 +390,9 @@ struct ConvertImpl {
} else if constexpr (IsDateTimeV2Type<ToDataType>) {
DataTypeDateTimeV2::cast_from_date(vec_from[i], vec_to[i]);
} else {
vec_to[i] =
reinterpret_cast<const VecDateTimeValue&>(vec_from[i]).to_int64();
static_cast_set(
vec_to[i],
reinterpret_cast<const VecDateTimeValue&>(vec_from[i]).to_int64());
}
}
} else if constexpr (IsTimeV2Type<FromDataType>) {
Expand Down Expand Up @@ -407,13 +423,16 @@ struct ConvertImpl {
}
} else {
if constexpr (IsDateTimeV2Type<FromDataType>) {
vec_to[i] = reinterpret_cast<const DateV2Value<DateTimeV2ValueType>&>(
vec_from[i])
.to_int64();
static_cast_set(
vec_to[i],
reinterpret_cast<const DateV2Value<DateTimeV2ValueType>&>(
vec_from[i])
.to_int64());
} else {
vec_to[i] = reinterpret_cast<const DateV2Value<DateV2ValueType>&>(
vec_from[i])
.to_int64();
static_cast_set(vec_to[i],
reinterpret_cast<const DateV2Value<DateV2ValueType>&>(
vec_from[i])
.to_int64());
}
}
}
Expand All @@ -440,16 +459,10 @@ struct ConvertImpl {
}
} else {
for (size_t i = 0; i < size; ++i) {
vec_to[i] = static_cast<ToFieldType>(vec_from[i]);
static_cast_set(vec_to[i], vec_from[i]);
}
}
}
// TODO: support boolean cast more reasonable
if constexpr (std::is_same_v<uint8_t, ToFieldType>) {
for (int i = 0; i < size; ++i) {
vec_to[i] = static_cast<bool>(vec_to[i]);
}
}

block.replace_by_position(result, std::move(col_to));
} else {
Expand Down
17 changes: 17 additions & 0 deletions regression-test/data/correctness/test_cast_decimalv3_as_bool.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select1 --
0.000 13131.213132100 0E-16
0.000 2131231.231000000 2.3323000E-9
3.141 0E-9 123123.2131231231322130

-- !select2 --
false true false
false true true
true false true

-- !select3 --
true 1 true false

-- !select3 --
true 1 true false

Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Regression suite: casts decimalv3 columns of three different precisions/scales,
// as well as decimal and double literals, to BOOLEAN and records the results.
suite("test_cast_decimalv3_as_bool") {
// Start from a clean table so earlier runs cannot influence the results.
sql """ DROP TABLE IF EXISTS cast_decimalv3_as_bool """
// One decimalv3 column per precision/scale combination under test.
sql """
CREATE TABLE IF NOT EXISTS cast_decimalv3_as_bool (
`id` int(11) ,
`k1` decimalv3(9,3) ,
`k2` decimalv3(18,9) ,
`k3` decimalv3(38,16) ,
)
UNIQUE KEY(`id`)
DISTRIBUTED BY HASH(`id`) BUCKETS 10
PROPERTIES (
"enable_unique_key_merge_on_write" = "true",
"replication_num" = "1"
);
"""
// Use the nereids planner and turn off constant folding by BE for the queries below.
sql """
set enable_nereids_planner=true,enable_fold_constant_by_be = false
"""
// Each row mixes zero values with values that have a non-zero fractional part.
// Some literals carry more fractional digits than the column scale
// (e.g. 0.00001 into decimalv3(9,3)); the expected-output file shows that value as 0.000.
sql """
INSERT INTO cast_decimalv3_as_bool VALUES
(1,0.00001,13131.2131321,0.000000000000000000),
(2,0.00000,2131231.231,0.0000000023323),
(3,3.141414,0.0000000000,123123.213123123132213);
"""
// Raw stored values, for reference next to the cast results.
qt_select1 """
select k1,k2,k3 from cast_decimalv3_as_bool order by id
"""
// Column casts to BOOLEAN, one per decimalv3 column.
qt_select2 """
select cast(k1 as boolean), cast(k2 as boolean) , cast(k3 as boolean) from cast_decimalv3_as_bool order by id
"""
// Casts of decimal literals to BOOLEAN, including a boolean -> int round trip.
// NOTE(review): the tag `select3` is reused by the next query as well; the expected-output
// file likewise lists `select3` twice. Renaming one would require regenerating that file.
qt_select3"""
select cast(3.00001 as boolean), cast(cast(3.00001 as boolean) as int),cast(0.001 as boolean),cast(0.000 as boolean);
"""
// Same literals as above, but cast to double first before casting to BOOLEAN.
qt_select3"""
select cast(cast(3.00001 as double)as boolean), cast(cast(cast(3.00001 as double) as boolean) as int),cast(cast(0.001 as double) as boolean),cast(cast(0.000 as double) as boolean);
"""
}
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,11 @@ suite("test_case_function_null", "query,p0,arrow_flight_sql") {
c2,
c1;
"""

// Behavior change: CAST(0.4716 AS BOOLEAN) used to evaluate to 0, but now evaluates to 1.
// The case is therefore updated to cast 0.0 instead.
qt_sql_case1 """
SELECT SUM(
CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.4716 AS BOOLEAN)))
CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.0 AS BOOLEAN)))
WHEN ((CAST('-1530390546' AS VARCHAR)) LIKE ('-1678299490'))
THEN (- (+ case_null2.c0))
WHEN CASE (NULL IN (NULL))
Expand All @@ -197,9 +198,10 @@ suite("test_case_function_null", "query,p0,arrow_flight_sql") {
END)
FROM case_null2;
"""

// Behavior change: CAST(0.4716 AS BOOLEAN) used to evaluate to 0, but now evaluates to 1.
// The case is therefore updated to cast 0.0 instead.
qt_sql_case2 """
SELECT SUM(CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.4716 AS BOOLEAN)))
SELECT SUM(CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.0 AS BOOLEAN)))
WHEN ((CAST('-1530390546' AS VARCHAR)) LIKE ('-1678299490'))
THEN (- (+ case_null2.c0))
END)
Expand All @@ -209,9 +211,11 @@ suite("test_case_function_null", "query,p0,arrow_flight_sql") {
sql "SET experimental_enable_nereids_planner=true"
sql "SET enable_fallback_to_original_planner=false"

// Behavior change: CAST(0.4716 AS BOOLEAN) used to evaluate to 0, but now evaluates to 1.
// The case is therefore updated to cast 0.0 instead.
qt_sql_case1 """
SELECT SUM(
CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.4716 AS BOOLEAN)))
CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.0 AS BOOLEAN)))
WHEN ((CAST('-1530390546' AS VARCHAR)) LIKE ('-1678299490'))
THEN (- (+ case_null2.c0))
WHEN CASE (NULL IN (NULL))
Expand All @@ -221,8 +225,10 @@ suite("test_case_function_null", "query,p0,arrow_flight_sql") {
FROM case_null2;
"""

// Behavior change: CAST(0.4716 AS BOOLEAN) used to evaluate to 0, but now evaluates to 1.
// The case is therefore updated to cast 0.0 instead.
qt_sql_case2 """
SELECT SUM(CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.4716 AS BOOLEAN)))
SELECT SUM(CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.0 AS BOOLEAN)))
WHEN ((CAST('-1530390546' AS VARCHAR)) LIKE ('-1678299490'))
THEN (- (+ case_null2.c0))
END)
Expand Down

0 comments on commit 0075a83

Please sign in to comment.