diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index fc8671c5aa8..2023a6eea38 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -226,13 +226,14 @@ const std::unordered_map scalar_func_map({ //{tipb::ScalarFuncSig::NEJson, "notEquals"}, {tipb::ScalarFuncSig::NEVectorFloat32, "notEquals"}, - //{tipb::ScalarFuncSig::NullEQInt, "cast"}, - //{tipb::ScalarFuncSig::NullEQReal, "cast"}, - //{tipb::ScalarFuncSig::NullEQString, "cast"}, - //{tipb::ScalarFuncSig::NullEQDecimal, "cast"}, - //{tipb::ScalarFuncSig::NullEQTime, "cast"}, - //{tipb::ScalarFuncSig::NullEQDuration, "cast"}, - //{tipb::ScalarFuncSig::NullEQJson, "cast"}, + {tipb::ScalarFuncSig::NullEQInt, "tidbNullEQ"}, + {tipb::ScalarFuncSig::NullEQReal, "tidbNullEQ"}, + {tipb::ScalarFuncSig::NullEQString, "tidbNullEQ"}, + {tipb::ScalarFuncSig::NullEQDecimal, "tidbNullEQ"}, + {tipb::ScalarFuncSig::NullEQTime, "tidbNullEQ"}, + {tipb::ScalarFuncSig::NullEQDuration, "tidbNullEQ"}, + //{tipb::ScalarFuncSig::NullEQJson, "tidbNullEQ"}, + {tipb::ScalarFuncSig::NullEQVectorFloat32, "tidbNullEQ"}, {tipb::ScalarFuncSig::PlusReal, "plus"}, {tipb::ScalarFuncSig::PlusDecimal, "plus"}, diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_tidb_null_eq_func.cpp b/dbms/src/Flash/Coprocessor/tests/gtest_tidb_null_eq_func.cpp new file mode 100644 index 00000000000..594d7a08115 --- /dev/null +++ b/dbms/src/Flash/Coprocessor/tests/gtest_tidb_null_eq_func.cpp @@ -0,0 +1,56 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +namespace DB::tests +{ +TEST(TiDBNullEQFuncTest, DagUtilsMappedToTidbNullEQ) +{ + { + tipb::Expr expr; + expr.set_tp(tipb::ExprType::ScalarFunc); + expr.set_sig(tipb::ScalarFuncSig::NullEQInt); + + ASSERT_TRUE(isScalarFunctionExpr(expr)); + ASSERT_EQ(getFunctionName(expr), "tidbNullEQ"); + } + { + tipb::Expr expr; + expr.set_tp(tipb::ExprType::ScalarFunc); + expr.set_sig(tipb::ScalarFuncSig::NullEQString); + + ASSERT_TRUE(isScalarFunctionExpr(expr)); + ASSERT_EQ(getFunctionName(expr), "tidbNullEQ"); + } + { + tipb::Expr expr; + expr.set_tp(tipb::ExprType::ScalarFunc); + expr.set_sig(tipb::ScalarFuncSig::NullEQDecimal); + + ASSERT_TRUE(isScalarFunctionExpr(expr)); + ASSERT_EQ(getFunctionName(expr), "tidbNullEQ"); + } + { + tipb::Expr expr; + expr.set_tp(tipb::ExprType::ScalarFunc); + expr.set_sig(tipb::ScalarFuncSig::NullEQVectorFloat32); + + ASSERT_TRUE(isScalarFunctionExpr(expr)); + ASSERT_EQ(getFunctionName(expr), "tidbNullEQ"); + } +} + +} // namespace DB::tests diff --git a/dbms/src/Functions/FunctionsComparison.cpp b/dbms/src/Functions/FunctionsComparison.cpp index f29a1b2c9c3..7772c440e7a 100644 --- a/dbms/src/Functions/FunctionsComparison.cpp +++ b/dbms/src/Functions/FunctionsComparison.cpp @@ -14,12 +14,214 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include #include #include namespace DB { +namespace ErrorCodes +{ +extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +extern const int ILLEGAL_COLUMN; +extern const int LOGICAL_ERROR; +} // namespace ErrorCodes + +class FunctionTiDBNullEQ : public IFunction +{ +public: + static constexpr auto name = "tidbNullEQ"; + + static FunctionPtr create(const Context &) { return std::make_shared(); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 2; } + + bool useDefaultImplementationForNulls() const override { return false; } + bool useDefaultImplementationForConstants() const override { return true; } + + void setCollator(const TiDB::TiDBCollatorPtr & collator_) override + { + collator = collator_; + equals_function->setCollator(collator_); + } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if (arguments.size() != 2) + throw Exception( + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Number of arguments for function {} doesn't match: passed {}, should be 2.", + getName(), + arguments.size()); + + /// `NULL <=> x` is always true/false (never NULL), even if `NULL` is represented as `Nothing`. + if (arguments[0]->onlyNull() || arguments[1]->onlyNull()) + return std::make_shared(); + + /// Use equals to validate that the input types are comparable. + /// Always return non-nullable UInt8 because `NULL <=> x` is always true/false (not NULL). + FunctionEquals().getReturnTypeImpl({removeNullable(arguments[0]), removeNullable(arguments[1])}); + return std::make_shared(); + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) const override + { + const auto & left = block.getByPosition(arguments[0]); + const auto & right = block.getByPosition(arguments[1]); + + ColumnPtr left_col = left.column; + ColumnPtr right_col = right.column; + + const size_t rows = left_col->size(); + if (unlikely(right_col->size() != rows)) + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, + "Columns sizes are different in function {}: left {}, right {}.", + getName(), + rows, + right_col->size()); + + /// Fast path for always-NULL columns (Nullable(Nothing)). + /// `NULL <=> x` equals to `isNull(x)`; `NULL <=> NULL` is always 1. + if (left_col->onlyNull() || right_col->onlyNull()) + { + if (left_col->onlyNull() && right_col->onlyNull()) + { + block.getByPosition(result).column = ColumnUInt8::create(rows, 1); + return; + } + + const ColumnPtr & other_col = left_col->onlyNull() ? right_col : left_col; + if (other_col->isColumnNullable()) + { + const auto & other_nullmap = assert_cast(*other_col).getNullMapData(); + auto res_col = ColumnUInt8::create(); + auto & res_data = res_col->getData(); + res_data.assign(other_nullmap.begin(), other_nullmap.end()); + block.getByPosition(result).column = std::move(res_col); + } + else + { + block.getByPosition(result).column = ColumnUInt8::create(rows, 0); + } + return; + } + + auto unwrap_nullable_column = [rows](const ColumnPtr & col, ColumnPtr & nested_col, const NullMap *& nullmap) { + nested_col = col; + nullmap = nullptr; + + if (const auto * const_col = typeid_cast(col.get())) + { + const auto & data_col = const_col->getDataColumn(); + if (data_col.isColumnNullable()) + { + /// `ColumnConst(ColumnNullable(NULL))` is handled by the `onlyNull()` fast path above. + /// If we reach here, the nullable constant must be non-NULL, so there is no nullmap to apply. + const auto & nullable_col = assert_cast(data_col); + nested_col = ColumnConst::create(nullable_col.getNestedColumnPtr(), rows); + } + return; + } + + if (col->isColumnNullable()) + { + const auto & nullable_col = assert_cast(*col); + nested_col = nullable_col.getNestedColumnPtr(); + nullmap = &nullable_col.getNullMapData(); + } + }; + + ColumnPtr left_nested_col = left_col; + const NullMap * left_nullmap = nullptr; + unwrap_nullable_column(left_col, left_nested_col, left_nullmap); + + ColumnPtr right_nested_col = right_col; + const NullMap * right_nullmap = nullptr; + unwrap_nullable_column(right_col, right_nested_col, right_nullmap); + + /// Execute `equals` on nested columns. + Block temp_block; + temp_block.insert({left_nested_col, removeNullable(left.type), "a"}); + temp_block.insert({right_nested_col, removeNullable(right.type), "b"}); + temp_block.insert({nullptr, std::make_shared(), "res"}); + DefaultExecutable(equals_function).execute(temp_block, {0, 1}, 2); + + ColumnPtr eq_col = temp_block.getByPosition(2).column; + if (left_nullmap == nullptr && right_nullmap == nullptr) + { + block.getByPosition(result).column = std::move(eq_col); + return; + } + + if (ColumnPtr converted = eq_col->convertToFullColumnIfConst()) + eq_col = converted; + + /// Adjust for NULL values: + /// - both NULL => 1 + /// - one NULL => 0 + /// - no NULL => equals result + auto eq_mutable = (*std::move(eq_col)).mutate(); + auto * eq_vec_col = typeid_cast(eq_mutable.get()); + if (unlikely(eq_vec_col == nullptr)) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Unexpected result column type {} for equals inside {}.", + eq_mutable->getName(), + getName()); + + auto & res_data = eq_vec_col->getData(); + if (left_nullmap != nullptr && right_nullmap != nullptr) + { + const auto & left_data = *left_nullmap; + const auto & right_data = *right_nullmap; + for (size_t i = 0; i < rows; ++i) + { + const UInt8 left_is_null = left_data[i] != 0; + const UInt8 right_is_null = right_data[i] != 0; + + const UInt8 any_null = left_is_null | right_is_null; + const UInt8 both_null = left_is_null & right_is_null; + + /// Keep equals result when `any_null == 0`, otherwise override it to 0. + /// Finally, override to 1 when `both_null == 1`. + const auto eq = static_cast(res_data[i] != 0); + res_data[i] = (eq & static_cast(!any_null)) | both_null; + } + } + else if (left_nullmap != nullptr) + { + const auto & left_data = *left_nullmap; + for (size_t i = 0; i < rows; ++i) + { + const UInt8 left_is_null = left_data[i] != 0; + const auto eq = static_cast(res_data[i] != 0); + res_data[i] = eq & static_cast(!left_is_null); + } + } + else if (right_nullmap != nullptr) + { + const auto & right_data = *right_nullmap; + for (size_t i = 0; i < rows; ++i) + { + const UInt8 right_is_null = right_data[i] != 0; + const auto eq = static_cast(res_data[i] != 0); + res_data[i] = eq & static_cast(!right_is_null); + } + } + + block.getByPosition(result).column = std::move(eq_mutable); + } + +private: + TiDB::TiDBCollatorPtr collator = nullptr; + std::shared_ptr equals_function = std::make_shared(); +}; + void registerFunctionsComparison(FunctionFactory & factory) { factory.registerFunction(); @@ -33,6 +235,7 @@ void registerFunctionsComparison(FunctionFactory & factory) factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); + factory.registerFunction(); } template <> diff --git a/dbms/src/Functions/tests/gtest_tidb_null_eq.cpp b/dbms/src/Functions/tests/gtest_tidb_null_eq.cpp new file mode 100644 index 00000000000..2520f67791f --- /dev/null +++ b/dbms/src/Functions/tests/gtest_tidb_null_eq.cpp @@ -0,0 +1,124 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +namespace DB::tests +{ +class TestTiDBNullEQ : public DB::tests::FunctionTest +{ +}; + +TEST_F(TestTiDBNullEQ, Basic) +try +{ + auto a = createColumn({1, 2, 2}); + auto b = createColumn({1, 3, 2}); + auto res = executeFunction("tidbNullEQ", a, b); + ASSERT_EQ(res.type->getName(), "UInt8"); + ASSERT_COLUMN_EQ(createColumn({1, 0, 1}), res); +} +CATCH + +TEST_F(TestTiDBNullEQ, NullableInputs) +try +{ + auto a = createColumn>({1, std::nullopt, std::nullopt, 2}); + auto b = createColumn>({1, std::nullopt, 3, std::nullopt}); + auto res = executeFunction("tidbNullEQ", a, b); + ASSERT_EQ(res.type->getName(), "UInt8"); + ASSERT_COLUMN_EQ(createColumn({1, 1, 0, 0}), res); +} +CATCH + +TEST_F(TestTiDBNullEQ, OnlyNullColumns) +try +{ + auto a = createOnlyNullColumn(5); + auto b = createOnlyNullColumn(5); + auto res = executeFunction("tidbNullEQ", a, b); + ASSERT_EQ(res.type->getName(), "UInt8"); + ASSERT_COLUMN_EQ(createColumn({1, 1, 1, 1, 1}), res); +} +CATCH + +TEST_F(TestTiDBNullEQ, OneSideOnlyNull) +try +{ + auto a = createOnlyNullColumn(3); + auto b = createColumn>({1, std::nullopt, 3}); + auto res = executeFunction("tidbNullEQ", a, b); + ASSERT_EQ(res.type->getName(), "UInt8"); + ASSERT_COLUMN_EQ(createColumn({0, 1, 0}), res); +} +CATCH + +TEST_F(TestTiDBNullEQ, ConstOnlyNull) +try +{ + auto a = createOnlyNullColumnConst(4); + auto b = createConstColumn>(4, 1); + auto res = executeFunction("tidbNullEQ", a, b); + ASSERT_EQ(res.type->getName(), "UInt8"); + ASSERT_COLUMN_EQ(createConstColumn(4, 0), res); +} +CATCH + +TEST_F(TestTiDBNullEQ, ConstNullableNonNull) +try +{ + auto a = createConstColumn>(4, 1); + auto b = createColumn>({1, std::nullopt, 2, 1}); + auto res = executeFunction("tidbNullEQ", a, b); + ASSERT_EQ(res.type->getName(), "UInt8"); + ASSERT_COLUMN_EQ(createColumn({1, 0, 0, 1}), res); + + auto res2 = executeFunction("tidbNullEQ", b, a); + ASSERT_EQ(res2.type->getName(), "UInt8"); + ASSERT_COLUMN_EQ(createColumn({1, 0, 0, 1}), res2); +} +CATCH + +TEST_F(TestTiDBNullEQ, ConstNullableNull) +try +{ + auto a = createConstColumn>(4, std::nullopt); + auto b = createColumn>({1, std::nullopt, 2, std::nullopt}); + + auto res = executeFunction("tidbNullEQ", a, b); + ASSERT_EQ(res.type->getName(), "UInt8"); + ASSERT_COLUMN_EQ(createColumn({0, 1, 0, 1}), res); + + auto res2 = executeFunction("tidbNullEQ", b, a); + ASSERT_EQ(res2.type->getName(), "UInt8"); + ASSERT_COLUMN_EQ(createColumn({0, 1, 0, 1}), res2); +} +CATCH + +TEST_F(TestTiDBNullEQ, CollatorIsForwardedToEquals) +try +{ + auto a = createColumn>({"a", "A", std::nullopt}); + auto b = createColumn>({"A", "a", std::nullopt}); + + auto ci_collator = TiDB::ITiDBCollator::getCollator(TiDB::ITiDBCollator::UTF8MB4_GENERAL_CI); + ASSERT_COLUMN_EQ(createColumn({1, 1, 1}), executeFunction("tidbNullEQ", {a, b}, ci_collator)); + + auto bin_collator = TiDB::ITiDBCollator::getCollator(TiDB::ITiDBCollator::BINARY); + ASSERT_COLUMN_EQ(createColumn({0, 0, 1}), executeFunction("tidbNullEQ", {a, b}, bin_collator)); +} +CATCH + +} // namespace DB::tests diff --git a/dbms/src/Storages/DeltaMerge/Filter/NullEqual.h b/dbms/src/Storages/DeltaMerge/Filter/NullEqual.h new file mode 100644 index 00000000000..d7af6898a12 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Filter/NullEqual.h @@ -0,0 +1,58 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace DB::DM +{ + +class NullEqual : public ColCmpVal +{ +public: + NullEqual(const Attr & attr_, const Field & value_) + : ColCmpVal(attr_, value_) + {} + + String name() override { return "null_equal"; } + + RSResults roughCheck(size_t start_pack, size_t pack_count, const RSCheckParam & param) override + { + auto rs_index = getRSIndex(param, attr); + return rs_index ? rs_index->minmax->checkNullEqual(start_pack, pack_count, value, rs_index->type) + : RSResults(pack_count, RSResult::Some); + } + + ColumnRangePtr buildSets(const google::protobuf::RepeatedPtrField & index_infos) override + { + if (value.isNull()) + return UnsupportedColumnRange::create(); + if (auto set = IntegerSet::createValueSet(attr.type, {value}); set) + { + auto iter = std::find_if(index_infos.begin(), index_infos.end(), [&](const auto & info) { + return info.index_type() == tipb::ColumnarIndexType::TypeInverted + && info.inverted_query_info().column_id() == attr.col_id; + }); + if (iter != index_infos.end()) + return SingleColumnRange::create( + iter->inverted_query_info().column_id(), + iter->inverted_query_info().index_id(), + set); + } + return UnsupportedColumnRange::create(); + } +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Filter/RSOperator.cpp b/dbms/src/Storages/DeltaMerge/Filter/RSOperator.cpp index 1e57a9f8de8..adb9eccbdd9 100644 --- a/dbms/src/Storages/DeltaMerge/Filter/RSOperator.cpp +++ b/dbms/src/Storages/DeltaMerge/Filter/RSOperator.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -38,6 +39,7 @@ namespace DB::DM // clang-format off RSOperatorPtr createAnd(const RSOperators & children) { return std::make_shared(children); } RSOperatorPtr createEqual(const Attr & attr, const Field & value) { return std::make_shared(attr, value); } +RSOperatorPtr createNullEqual(const Attr & attr, const Field & value) { return std::make_shared(attr, value); } RSOperatorPtr createGreater(const Attr & attr, const Field & value) { return std::make_shared(attr, value); } RSOperatorPtr createGreaterEqual(const Attr & attr, const Field & value) { return std::make_shared(attr, value); } RSOperatorPtr createIn(const Attr & attr, const Fields & values) { return std::make_shared(attr, values); } diff --git a/dbms/src/Storages/DeltaMerge/Filter/RSOperator.h b/dbms/src/Storages/DeltaMerge/Filter/RSOperator.h index b1a98837836..0fd5c2fcd6a 100644 --- a/dbms/src/Storages/DeltaMerge/Filter/RSOperator.h +++ b/dbms/src/Storages/DeltaMerge/Filter/RSOperator.h @@ -180,6 +180,7 @@ RSOperatorPtr createOr(const RSOperators & children); RSOperatorPtr createAnd(const RSOperators & children); // compare RSOperatorPtr createEqual(const Attr & attr, const Field & value); +RSOperatorPtr createNullEqual(const Attr & attr, const Field & value); RSOperatorPtr createNotEqual(const Attr & attr, const Field & value); RSOperatorPtr createGreater(const Attr & attr, const Field & value); RSOperatorPtr createGreaterEqual(const Attr & attr, const Field & value); diff --git a/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp b/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp index b434c9446f5..f2c797f72f3 100644 --- a/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp +++ b/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp @@ -213,6 +213,10 @@ inline RSOperatorPtr parseTiCompareExpr( // { case FilterParser::RSFilterType::Equal: return createEqual(attr, values[0]); + case FilterParser::RSFilterType::NullEqual: + if (values[0].isNull()) + return createIsNull(attr); + return createNullEqual(attr, values[0]); case FilterParser::RSFilterType::NotEqual: return createNotEqual(attr, values[0]); case FilterParser::RSFilterType::Greater: @@ -302,6 +306,7 @@ RSOperatorPtr parseTiExpr( } case FilterParser::RSFilterType::Equal: + case FilterParser::RSFilterType::NullEqual: case FilterParser::RSFilterType::NotEqual: case FilterParser::RSFilterType::Greater: case FilterParser::RSFilterType::GreaterEqual: @@ -583,13 +588,14 @@ std::unordered_map FilterParser {tipb::ScalarFuncSig::NEDuration, FilterParser::RSFilterType::NotEqual}, {tipb::ScalarFuncSig::NEJson, FilterParser::RSFilterType::NotEqual}, - //{tipb::ScalarFuncSig::NullEQInt, "cast"}, - //{tipb::ScalarFuncSig::NullEQReal, "cast"}, - //{tipb::ScalarFuncSig::NullEQString, "cast"}, - //{tipb::ScalarFuncSig::NullEQDecimal, "cast"}, - //{tipb::ScalarFuncSig::NullEQTime, "cast"}, - //{tipb::ScalarFuncSig::NullEQDuration, "cast"}, - //{tipb::ScalarFuncSig::NullEQJson, "cast"}, + {tipb::ScalarFuncSig::NullEQInt, FilterParser::RSFilterType::NullEqual}, + {tipb::ScalarFuncSig::NullEQReal, FilterParser::RSFilterType::NullEqual}, + {tipb::ScalarFuncSig::NullEQString, FilterParser::RSFilterType::NullEqual}, + {tipb::ScalarFuncSig::NullEQDecimal, FilterParser::RSFilterType::NullEqual}, + {tipb::ScalarFuncSig::NullEQTime, FilterParser::RSFilterType::NullEqual}, + {tipb::ScalarFuncSig::NullEQDuration, FilterParser::RSFilterType::NullEqual}, + {tipb::ScalarFuncSig::NullEQJson, FilterParser::RSFilterType::NullEqual}, + {tipb::ScalarFuncSig::NullEQVectorFloat32, FilterParser::RSFilterType::NullEqual}, // {tipb::ScalarFuncSig::PlusReal, "plus"}, // {tipb::ScalarFuncSig::PlusDecimal, "plus"}, diff --git a/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.h b/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.h index 2752afd070b..a160baf2c6d 100644 --- a/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.h +++ b/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.h @@ -70,6 +70,7 @@ class FilterParser And, // compare Equal, + NullEqual, NotEqual, Greater, GreaterEqual, diff --git a/dbms/src/Storages/DeltaMerge/Index/MinMaxIndex.cpp b/dbms/src/Storages/DeltaMerge/Index/MinMaxIndex.cpp index 473decaaec3..e61a8fff9f7 100644 --- a/dbms/src/Storages/DeltaMerge/Index/MinMaxIndex.cpp +++ b/dbms/src/Storages/DeltaMerge/Index/MinMaxIndex.cpp @@ -292,10 +292,10 @@ RSResults MinMaxIndex::checkNullableIn( size_t pos = i * 2; size_t prev_offset = pos == 0 ? 0 : offsets[pos - 1]; // todo use StringRef instead of String - auto min = String(chars[prev_offset], offsets[pos] - prev_offset - 1); + auto min = String(reinterpret_cast(&chars[prev_offset]), offsets[pos] - prev_offset - 1); pos = i * 2 + 1; prev_offset = offsets[pos - 1]; - auto max = String(chars[prev_offset], offsets[pos] - prev_offset - 1); + auto max = String(reinterpret_cast(&chars[prev_offset]), offsets[pos] - prev_offset - 1); auto value_result = RoughCheck::CheckIn::check(values, type, min, max); results[i - start_pack] = addNullIfHasNull(value_result, i); } @@ -490,6 +490,134 @@ template RSResults MinMaxIndex::checkCmp( const Field & value, const DataTypePtr & type); +template +RSResults MinMaxIndex::checkNullableNullEqualImpl( + const DB::ColumnNullable & column_nullable, + const DB::ColumnUInt8 & null_map, + size_t start_pack, + size_t pack_count, + const Field & value, + const DataTypePtr & type) +{ + RSResults results(pack_count, RSResult::Some); + const auto & minmaxes_data = toColumnVectorData(column_nullable.getNestedColumnPtr()); + for (size_t i = start_pack; i < start_pack + pack_count; ++i) + { + if (details::minIsNull(null_map, i)) + { + if (has_null_marks[i] && !has_value_marks[i]) + results[i - start_pack] = RSResult::None; + continue; + } + + auto min = minmaxes_data[i * 2]; + auto max = minmaxes_data[i * 2 + 1]; + auto value_result = RoughCheck::CheckEqual::check(value, type, min, max); + if (has_null_marks[i] && value_result == RSResult::All) + results[i - start_pack] = RSResult::Some; + else + results[i - start_pack] = value_result; + } + return results; +} + +RSResults MinMaxIndex::checkNullableNullEqual( + size_t start_pack, + size_t pack_count, + const Field & value, + const DataTypePtr & type) +{ + const auto & column_nullable = static_cast(*minmaxes); + const auto & null_map = column_nullable.getNullMapColumn(); + const auto * raw_type = type.get(); + +#define DISPATCH(TYPE) \ + if (typeid_cast(raw_type)) \ + return checkNullableNullEqualImpl(column_nullable, null_map, start_pack, pack_count, value, type); + FOR_NUMERIC_TYPES(DISPATCH) +#undef DISPATCH + if (typeid_cast(raw_type) || typeid_cast(raw_type)) + { + // For DataTypeMyDateTime / DataTypeMyDate, simply compare them as comparing UInt64 is OK. + // Check `struct MyTimeBase` for more details. + return checkNullableNullEqualImpl( + column_nullable, + null_map, + start_pack, + pack_count, + value, + type); + } + if (typeid_cast(raw_type)) + { + const auto * string_column = checkAndGetColumn(column_nullable.getNestedColumnPtr().get()); + const auto & chars = string_column->getChars(); + const auto & offsets = string_column->getOffsets(); + RSResults results(pack_count, RSResult::Some); + for (size_t i = start_pack; i < start_pack + pack_count; ++i) + { + if (details::minIsNull(null_map, i)) + { + if (has_null_marks[i] && !has_value_marks[i]) + results[i - start_pack] = RSResult::None; + continue; + } + + size_t pos = i * 2; + size_t prev_offset = pos == 0 ? 0 : offsets[pos - 1]; + // todo use StringRef instead of String + auto min = String(reinterpret_cast(&chars[prev_offset]), offsets[pos] - prev_offset - 1); + pos = i * 2 + 1; + prev_offset = offsets[pos - 1]; + auto max = String(reinterpret_cast(&chars[prev_offset]), offsets[pos] - prev_offset - 1); + auto value_result = RoughCheck::CheckEqual::check(value, type, min, max); + if (has_null_marks[i] && value_result == RSResult::All) + results[i - start_pack] = RSResult::Some; + else + results[i - start_pack] = value_result; + } + return results; + } + // Should not happen, because TiDB use DataTypeMyDateTime and DataTypeMyDate + if (typeid_cast(raw_type)) + { + return checkNullableNullEqualImpl( + column_nullable, + null_map, + start_pack, + pack_count, + value, + type); + } + if (typeid_cast(raw_type)) + { + return checkNullableNullEqualImpl( + column_nullable, + null_map, + start_pack, + pack_count, + value, + type); + } + return RSResults(pack_count, RSResult::Some); +} + +RSResults MinMaxIndex::checkNullEqual( + size_t start_pack, + size_t pack_count, + const Field & value, + const DataTypePtr & type) +{ + if (value.isNull()) + return checkIsNull(start_pack, pack_count); + + const auto * raw_type = type.get(); + if (typeid_cast(raw_type)) + return checkNullableNullEqual(start_pack, pack_count, value, removeNullable(type)); + + return checkCmp(start_pack, pack_count, value, type); +} + template RSResults MinMaxIndex::checkNullableCmpImpl( const DB::ColumnNullable & column_nullable, @@ -554,10 +682,10 @@ RSResults MinMaxIndex::checkNullableCmp( size_t pos = i * 2; size_t prev_offset = pos == 0 ? 0 : offsets[pos - 1]; // todo use StringRef instead of String - auto min = String(chars[prev_offset], offsets[pos] - prev_offset - 1); + auto min = String(reinterpret_cast(&chars[prev_offset]), offsets[pos] - prev_offset - 1); pos = i * 2 + 1; prev_offset = offsets[pos - 1]; - auto max = String(chars[prev_offset], offsets[pos] - prev_offset - 1); + auto max = String(reinterpret_cast(&chars[prev_offset]), offsets[pos] - prev_offset - 1); auto value_result = Op::template check(value, type, min, max); results[i - start_pack] = addNullIfHasNull(value_result, i); } diff --git a/dbms/src/Storages/DeltaMerge/Index/MinMaxIndex.h b/dbms/src/Storages/DeltaMerge/Index/MinMaxIndex.h index 21f5dcdccbb..6b32501cca3 100644 --- a/dbms/src/Storages/DeltaMerge/Index/MinMaxIndex.h +++ b/dbms/src/Storages/DeltaMerge/Index/MinMaxIndex.h @@ -69,6 +69,7 @@ class MinMaxIndex template RSResults checkCmp(size_t start_pack, size_t pack_count, const Field & value, const DataTypePtr & type); + RSResults checkNullEqual(size_t start_pack, size_t pack_count, const Field & value, const DataTypePtr & type); // TODO: merge with checkCmp RSResults checkIn( @@ -105,6 +106,19 @@ class MinMaxIndex size_t pack_count, const Field & value, const DataTypePtr & type); + template + RSResults checkNullableNullEqualImpl( + const DB::ColumnNullable & column_nullable, + const DB::ColumnUInt8 & null_map, + size_t start_pack, + size_t pack_count, + const Field & value, + const DataTypePtr & type); + RSResults checkNullableNullEqual( + size_t start_pack, + size_t pack_count, + const Field & value, + const DataTypePtr & type); template RSResults checkInImpl( diff --git a/dbms/src/Storages/DeltaMerge/Index/tests/gtest_dm_minmax_index.cpp b/dbms/src/Storages/DeltaMerge/Index/tests/gtest_dm_minmax_index.cpp index 4aec61395ce..7abc645a755 100644 --- a/dbms/src/Storages/DeltaMerge/Index/tests/gtest_dm_minmax_index.cpp +++ b/dbms/src/Storages/DeltaMerge/Index/tests/gtest_dm_minmax_index.cpp @@ -2057,6 +2057,115 @@ try auto filter = createEqual(attr("Nullable(Int64)"), Field(static_cast(1))); ASSERT_EQ(filter->roughCheck(0, 1, param)[0], RSResult::SomeNull); + + // make a null-equal filter, keep the compatibility path conservative. + auto null_eq_filter = createNullEqual(attr("Nullable(Int64)"), Field(static_cast(1))); + + ASSERT_EQ(null_eq_filter->roughCheck(0, 1, param)[0], RSResult::Some); +} +CATCH + +TEST_F(MinMaxIndexTest, NullEQOrNotNullEQWithNullPack) +try +{ + RSCheckParam param; + + auto type = std::make_shared(); + auto data_type = makeNullable(type); + + PaddedPODArray has_null_marks(3); + PaddedPODArray has_value_marks(3); + MutableColumnPtr minmaxes = data_type->createColumn(); + + // pack 0: {1, NULL} + has_null_marks[0] = 1; + has_value_marks[0] = 1; + minmaxes->insert(Field(static_cast(1))); + minmaxes->insert(Field(static_cast(1))); + + // pack 1: {2, NULL} + has_null_marks[1] = 1; + has_value_marks[1] = 1; + minmaxes->insert(Field(static_cast(2))); + minmaxes->insert(Field(static_cast(2))); + + // pack 2: {NULL} + has_null_marks[2] = 1; + has_value_marks[2] = 0; + minmaxes->insertDefault(); + minmaxes->insertDefault(); + + auto minmax + = std::make_shared(std::move(has_null_marks), std::move(has_value_marks), std::move(minmaxes)); + + auto index = RSIndex(data_type, minmax); + param.indexes.emplace(DEFAULT_COL_ID, index); + + auto null_eq = createNullEqual(attr("Nullable(Int64)"), Field(static_cast(1))); + auto not_null_eq = createNot(null_eq); + auto results = null_eq->roughCheck(0, 3, param); + auto not_results = not_null_eq->roughCheck(0, 3, param); + + ASSERT_EQ(results[0], RSResult::Some); + ASSERT_EQ(results[1], RSResult::None); + ASSERT_EQ(results[2], RSResult::None); + + ASSERT_EQ(not_results[0], RSResult::Some); + ASSERT_EQ(not_results[1], RSResult::All); + ASSERT_EQ(not_results[2], RSResult::All); +} +CATCH + +TEST_F(MinMaxIndexTest, NullableStringCmpInAndNullEQ) +try +{ + struct NullableStringTestCase + { + std::vector> column_data; + std::vector del_mark; + }; + + std::vector cases = { + {{String("aa"), std::nullopt}, {0, 0}}, + {{String("bb"), std::nullopt}, {0, 0}}, + {{std::nullopt}, {0}}, + }; + + auto col_type = makeNullable(std::make_shared()); + auto minmax_index = std::make_shared(*col_type); + for (const auto & c : cases) + { + RUNTIME_CHECK(c.column_data.size(), c.del_mark.size()); + auto col_data = createColumn>(c.column_data).column; + auto del_mark_col = createColumn(c.del_mark).column; + minmax_index->addPack(*col_data, static_cast *>(del_mark_col.get())); + } + + auto eq_results = minmax_index->checkCmp(0, cases.size(), Field(String("aa")), col_type); + ASSERT_EQ(eq_results[0], RSResult::AllNull); + ASSERT_EQ(eq_results[1], RSResult::NoneNull); + ASSERT_EQ(eq_results[2], RSResult::SomeNull); + + auto in_results = minmax_index->checkIn(0, cases.size(), {Field(String("aa"))}, col_type); + ASSERT_EQ(in_results[0], RSResult::AllNull); + ASSERT_EQ(in_results[1], RSResult::NoneNull); + ASSERT_EQ(in_results[2], RSResult::SomeNull); + + RSCheckParam param; + param.indexes.emplace(DEFAULT_COL_ID, RSIndex(col_type, minmax_index)); + + auto null_eq = createNullEqual(attr("Nullable(String)"), Field(String("aa"))); + auto not_null_eq = createNot(null_eq); + auto null_eq_results = null_eq->roughCheck(0, cases.size(), param); + auto not_null_eq_results = not_null_eq->roughCheck(0, cases.size(), param); + + ASSERT_EQ(null_eq_results[0], RSResult::Some); + ASSERT_EQ(null_eq_results[1], RSResult::None); + ASSERT_EQ(null_eq_results[2], RSResult::None); + + ASSERT_EQ(not_null_eq_results[0], RSResult::Some); + ASSERT_EQ(not_null_eq_results[1], RSResult::All); + ASSERT_EQ(not_null_eq_results[2], RSResult::All); } CATCH @@ -2270,6 +2379,97 @@ try } CATCH +TEST_F(MinMaxIndexTest, ParseNullEQ) +try +{ + const google::protobuf::RepeatedPtrField pushed_down_filters{}; + google::protobuf::RepeatedPtrField filters; + + auto build_column_ref = [](Int64 column_index) { + tipb::Expr expr; + expr.set_tp(tipb::ExprType::ColumnRef); + WriteBufferFromOwnString ss; + encodeDAGInt64(column_index, ss); + expr.set_val(ss.releaseStr()); + auto * field_type = expr.mutable_field_type(); + field_type->set_tp(tipb::ExprType::Int64); + field_type->set_flag(0); + return expr; + }; + auto build_int_literal = [](Int64 value) { + tipb::Expr expr; + expr.set_tp(tipb::ExprType::Int64); + WriteBufferFromOwnString ss; + encodeDAGInt64(value, ss); + expr.set_val(ss.releaseStr()); + return expr; + }; + auto build_null_literal = [] { + tipb::Expr expr; + expr.set_tp(tipb::ExprType::Null); + return expr; + }; + + { + tipb::Expr expr; + expr.set_sig(tipb::ScalarFuncSig::NullEQInt); + expr.set_tp(tipb::ExprType::ScalarFunc); + expr.add_children()->CopyFrom(build_column_ref(0)); + expr.add_children()->CopyFrom(build_int_literal(1)); + filters.Add()->CopyFrom(expr); + } + { + tipb::Expr expr; + expr.set_sig(tipb::ScalarFuncSig::NullEQInt); + expr.set_tp(tipb::ExprType::ScalarFunc); + expr.add_children()->CopyFrom(build_column_ref(0)); + expr.add_children()->CopyFrom(build_null_literal()); + filters.Add()->CopyFrom(expr); + } + { + tipb::Expr child; + child.set_sig(tipb::ScalarFuncSig::NullEQInt); + child.set_tp(tipb::ExprType::ScalarFunc); + child.add_children()->CopyFrom(build_column_ref(0)); + child.add_children()->CopyFrom(build_int_literal(1)); + + tipb::Expr expr; + expr.set_sig(tipb::ScalarFuncSig::UnaryNotInt); + expr.set_tp(tipb::ExprType::ScalarFunc); + expr.add_children()->CopyFrom(child); + filters.Add()->CopyFrom(expr); + } + + const ColumnDefines columns_to_read = {ColumnDefine{1, "a", std::make_shared()}}; + TiDB::ColumnInfo a; + a.id = 1; + TiDB::ColumnInfos column_infos = {a}; + const auto ann_query_info = tipb::ANNQueryInfo{}; + const auto fts_query_info = tipb::FTSQueryInfo{}; + static const google::protobuf::RepeatedPtrField empty_used_indexes{}; + auto dag_query = std::make_unique( + filters, + ann_query_info, + fts_query_info, + pushed_down_filters, + empty_used_indexes, + column_infos, + std::vector{}, + 0, + context->getTimezoneInfo()); + FilterParser::ColumnIDToAttrMap column_id_to_attr; + for (const auto & cd : columns_to_read) + { + column_id_to_attr[cd.id] = Attr{.col_name = cd.name, .col_id = cd.id, .type = cd.type}; + } + + const auto op = DB::DM::FilterParser::parseDAGQuery(*dag_query, column_infos, column_id_to_attr, Logger::get()); + EXPECT_EQ( + op->toDebugString(), + R"raw({"op":"and","children":[{"op":"null_equal","col":"a","value":"1"},{"op":"isnull","col":"a"},{"op":"not","children":[{"op":"null_equal","col":"a","value":"1"}]}]})raw"); +} +CATCH + namespace { // Only support Int64 for testing. diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_filter_parser_nulleq.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_filter_parser_nulleq.cpp new file mode 100644 index 00000000000..20e1547194e --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_filter_parser_nulleq.cpp @@ -0,0 +1,183 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB::DM::tests +{ + +namespace +{ +tipb::Expr buildColumnRefExpr(Int64 column_index, Int32 field_type) +{ + tipb::Expr col; + col.set_tp(tipb::ExprType::ColumnRef); + { + WriteBufferFromOwnString ss; + encodeDAGInt64(column_index, ss); + col.set_val(ss.releaseStr()); + } + auto * field_type_pb = col.mutable_field_type(); + field_type_pb->set_tp(field_type); + field_type_pb->set_flag(0); + return col; +} + +tipb::Expr buildInt64LiteralExpr(Int64 value) +{ + tipb::Expr lit; + lit.set_tp(tipb::ExprType::Int64); + { + WriteBufferFromOwnString ss; + encodeDAGInt64(value, ss); + lit.set_val(ss.releaseStr()); + } + return lit; +} + +tipb::Expr buildNullLiteralExpr() +{ + tipb::Expr lit; + lit.set_tp(tipb::ExprType::Null); + return lit; +} + +tipb::Expr buildLogicalNotExpr(const tipb::Expr & child) +{ + tipb::Expr expr; + expr.set_sig(tipb::ScalarFuncSig::UnaryNotInt); + expr.set_tp(tipb::ExprType::ScalarFunc); + expr.add_children()->CopyFrom(child); + return expr; +} + +String parseToDebugString(Context & context, const tipb::Expr & filter_expr) +{ + google::protobuf::RepeatedPtrField filters; + filters.Add()->CopyFrom(filter_expr); + + const google::protobuf::RepeatedPtrField pushed_down_filters{}; + + TiDB::ColumnInfo col; + col.id = 1; + TiDB::ColumnInfos column_infos = {col}; + + const ColumnDefines columns_to_read = {ColumnDefine{1, "a", std::make_shared()}}; + FilterParser::ColumnIDToAttrMap column_id_to_attr; + for (const auto & cd : columns_to_read) + { + column_id_to_attr[cd.id] = Attr{.col_name = cd.name, .col_id = cd.id, .type = cd.type}; + } + + const auto ann_query_info = tipb::ANNQueryInfo{}; + const auto fts_query_info = tipb::FTSQueryInfo{}; + static const google::protobuf::RepeatedPtrField empty_used_indexes{}; + auto dag_query = std::make_unique( + filters, + ann_query_info, + fts_query_info, + pushed_down_filters, + empty_used_indexes, + column_infos, + std::vector{}, + 0, + context.getTimezoneInfo()); + + const auto op = DB::DM::FilterParser::parseDAGQuery(*dag_query, column_infos, column_id_to_attr, Logger::get()); + return op->toDebugString(); +} +} // namespace + +TEST(DMFilterParserTest, ParseNullEQ) +try +{ + auto context = DMTestEnv::getContext(); + + { + // a <=> 1 -> null_equal(a, 1) + tipb::Expr expr; + expr.set_sig(tipb::ScalarFuncSig::NullEQInt); + expr.set_tp(tipb::ExprType::ScalarFunc); + expr.add_children()->CopyFrom(buildColumnRefExpr(/*column_index*/ 0, TiDB::TypeLongLong)); + expr.add_children()->CopyFrom(buildInt64LiteralExpr(1)); + EXPECT_EQ(parseToDebugString(*context, expr), R"raw({"op":"null_equal","col":"a","value":"1"})raw"); + } + + { + // a <=> NULL -> isnull(a) + tipb::Expr expr; + expr.set_sig(tipb::ScalarFuncSig::NullEQInt); + expr.set_tp(tipb::ExprType::ScalarFunc); + expr.add_children()->CopyFrom(buildColumnRefExpr(/*column_index*/ 0, TiDB::TypeLongLong)); + expr.add_children()->CopyFrom(buildNullLiteralExpr()); + EXPECT_EQ(parseToDebugString(*context, expr), R"raw({"op":"isnull","col":"a"})raw"); + } + + { + // NULL <=> a -> isnull(a) + tipb::Expr expr; + expr.set_sig(tipb::ScalarFuncSig::NullEQInt); + expr.set_tp(tipb::ExprType::ScalarFunc); + expr.add_children()->CopyFrom(buildNullLiteralExpr()); + expr.add_children()->CopyFrom(buildColumnRefExpr(/*column_index*/ 0, TiDB::TypeLongLong)); + EXPECT_EQ(parseToDebugString(*context, expr), R"raw({"op":"isnull","col":"a"})raw"); + } + + { + // 1 <=> a -> null_equal(a, 1) + tipb::Expr expr; + expr.set_sig(tipb::ScalarFuncSig::NullEQInt); + expr.set_tp(tipb::ExprType::ScalarFunc); + expr.add_children()->CopyFrom(buildInt64LiteralExpr(1)); + expr.add_children()->CopyFrom(buildColumnRefExpr(/*column_index*/ 0, TiDB::TypeLongLong)); + EXPECT_EQ(parseToDebugString(*context, expr), R"raw({"op":"null_equal","col":"a","value":"1"})raw"); + } + + { + // not(a <=> 1) keeps the dedicated null_equal node under logical not. + tipb::Expr expr; + expr.set_sig(tipb::ScalarFuncSig::NullEQInt); + expr.set_tp(tipb::ExprType::ScalarFunc); + expr.add_children()->CopyFrom(buildColumnRefExpr(/*column_index*/ 0, TiDB::TypeLongLong)); + expr.add_children()->CopyFrom(buildInt64LiteralExpr(1)); + EXPECT_EQ( + parseToDebugString(*context, buildLogicalNotExpr(expr)), + R"raw({"op":"not","children":[{"op":"null_equal","col":"a","value":"1"}]})raw"); + } + + { + // White-box regression for the NullEQJson signature. + // DM rough set filter does not currently support JSON ColumnRef directly, + // so use a supported ColumnRef type here and verify the NullEQJson sig still + // lowers `<=> NULL` to `isnull(col)` once it reaches parseTiCompareExpr. + tipb::Expr expr; + expr.set_sig(tipb::ScalarFuncSig::NullEQJson); + expr.set_tp(tipb::ExprType::ScalarFunc); + expr.add_children()->CopyFrom(buildColumnRefExpr(/*column_index*/ 0, TiDB::TypeLongLong)); + expr.add_children()->CopyFrom(buildNullLiteralExpr()); + EXPECT_EQ(parseToDebugString(*context, expr), R"raw({"op":"isnull","col":"a"})raw"); + } +} +CATCH + +} // namespace DB::DM::tests