From 3e83eb35c52ee3bcf4f9260dc923d4a5a66be209 Mon Sep 17 00:00:00 2001 From: yy <736262857@qq.com> Date: Thu, 2 Jul 2026 12:23:43 +0000 Subject: [PATCH 1/5] feat(tiflash): support push down NullEq function - Add NullEq function support in TiFlash side - Implement NullEq comparison logic for all data types - Add unit tests and integration tests for NullEq pushdown Issue: #5102 Signed-off-by: yy <736262857@qq.com> --- .../DAGExpressionAnalyzerHelper.cpp | 34 ++ .../Coprocessor/DAGExpressionAnalyzerHelper.h | 5 + dbms/src/Flash/Coprocessor/DAGUtils.cpp | 14 +- dbms/src/Functions/tests/gtest_nulleq.cpp | 504 ++++++++++++++++++ tests/fullstack-test/expr/nulleq.test | 28 + 5 files changed, 578 insertions(+), 7 deletions(-) create mode 100644 dbms/src/Functions/tests/gtest_nulleq.cpp create mode 100644 tests/fullstack-test/expr/nulleq.test diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.cpp index e24398e65d1..cbed8f8207c 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.cpp @@ -124,6 +124,39 @@ String DAGExpressionAnalyzerHelper::buildIfNullFunction( return analyzer->applyFunction(func_name, argument_names, actions, getCollatorFromExpr(expr)); } +String DAGExpressionAnalyzerHelper::buildNullEqFunction( + DAGExpressionAnalyzer * analyzer, + const tipb::Expr & expr, + const ExpressionActionsPtr & actions) +{ + if (expr.children_size() != 2) + { + throw TiFlashException("Invalid arguments of nullEq function", Errors::Coprocessor::BadRequest); + } + + String col1 = analyzer->getActions(expr.children(0), actions, false); + String col2 = analyzer->getActions(expr.children(1), actions, false); + + const Block & sample_block = actions->getSampleBlock(); + bool col1_nullable = sample_block.getByName(col1).type->isNullable(); + bool col2_nullable = sample_block.getByName(col2).type->isNullable(); + + if (!col1_nullable && !col2_nullable) + { + return analyzer->applyFunction("equals", {col1, col2}, actions, getCollatorFromExpr(expr)); + } + + String equals = analyzer->applyFunction("equals", {col1, col2}, actions, getCollatorFromExpr(expr)); + String name = analyzer->getActions(constructInt64LiteralTiExpr(0), actions); + + String is_null_col1 = analyzer->applyFunction("isNull", {col1}, actions, getCollatorFromExpr(expr)); + String is_null_col2 = analyzer->applyFunction("isNull", {col2}, actions, getCollatorFromExpr(expr)); + String and_is_null = analyzer->applyFunction("and", {is_null_col1, is_null_col2}, actions, nullptr); + String not_null_equals = analyzer->applyFunction("coalesce", {equals, name}, actions, nullptr); + + return analyzer->applyFunction("or", {and_is_null, not_null_equals}, actions, nullptr); +} + String DAGExpressionAnalyzerHelper::buildInFunction( DAGExpressionAnalyzer * analyzer, const tipb::Expr & expr, @@ -547,6 +580,7 @@ DAGExpressionAnalyzerHelper::FunctionBuilderMap DAGExpressionAnalyzerHelper::fun {"tidbIn", DAGExpressionAnalyzerHelper::buildInFunction}, {"tidbNotIn", DAGExpressionAnalyzerHelper::buildInFunction}, {"ifNull", DAGExpressionAnalyzerHelper::buildIfNullFunction}, + {"nullEq", DAGExpressionAnalyzerHelper::buildNullEqFunction}, {"multiIf", DAGExpressionAnalyzerHelper::buildMultiIfFunction}, {"tidb_cast", DAGExpressionAnalyzerHelper::buildCastFunction}, {"cast_int_as_json", DAGExpressionAnalyzerHelper::buildSingleParamJsonRelatedFunctions}, diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.h b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.h index 529dc204ace..6024b6d3145 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.h +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.h @@ -49,6 +49,11 @@ class DAGExpressionAnalyzerHelper const tipb::Expr & expr, const ExpressionActionsPtr & actions); + static String buildNullEqFunction( + DAGExpressionAnalyzer * analyzer, + const tipb::Expr & expr, + const ExpressionActionsPtr & actions); + static String buildLogicalFunction( DAGExpressionAnalyzer * analyzer, const tipb::Expr & expr, diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index 7f4ed4d9d7d..62fd065dfae 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -226,13 +226,13 @@ const std::unordered_map scalar_func_map({ //{tipb::ScalarFuncSig::NEJson, "notEquals"}, {tipb::ScalarFuncSig::NEVectorFloat32, "notEquals"}, - //{tipb::ScalarFuncSig::NullEQInt, "cast"}, - //{tipb::ScalarFuncSig::NullEQReal, "cast"}, - //{tipb::ScalarFuncSig::NullEQString, "cast"}, - //{tipb::ScalarFuncSig::NullEQDecimal, "cast"}, - //{tipb::ScalarFuncSig::NullEQTime, "cast"}, - //{tipb::ScalarFuncSig::NullEQDuration, "cast"}, - //{tipb::ScalarFuncSig::NullEQJson, "cast"}, + {tipb::ScalarFuncSig::NullEQInt, "nullEq"}, + {tipb::ScalarFuncSig::NullEQReal, "nullEq"}, + {tipb::ScalarFuncSig::NullEQString, "nullEq"}, + {tipb::ScalarFuncSig::NullEQDecimal, "nullEq"}, + {tipb::ScalarFuncSig::NullEQTime, "nullEq"}, + {tipb::ScalarFuncSig::NullEQDuration, "nullEq"}, + {tipb::ScalarFuncSig::NullEQJson, "nullEq"}, {tipb::ScalarFuncSig::PlusReal, "plus"}, {tipb::ScalarFuncSig::PlusDecimal, "plus"}, diff --git a/dbms/src/Functions/tests/gtest_nulleq.cpp b/dbms/src/Functions/tests/gtest_nulleq.cpp new file mode 100644 index 00000000000..33b0337197c --- /dev/null +++ b/dbms/src/Functions/tests/gtest_nulleq.cpp @@ -0,0 +1,504 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace DB +{ +namespace tests +{ +class TestNullEq : public DB::tests::FunctionTest +{ +protected: + ColumnWithTypeAndName executeNullEq(const ColumnWithTypeAndName & first_column, const ColumnWithTypeAndName & second_column) + { + return executeFunction("nullEq", first_column, second_column); + } + DataTypePtr getReturnTypeForNullEq(const DataTypePtr & type_1, const DataTypePtr & type_2) + { + ColumnsWithTypeAndName input_columns{ + {nullptr, type_1, ""}, + {nullptr, type_2, ""}, + }; + return getReturnTypeForFunction(*context, "nullEq", input_columns); + } + template + ColumnWithTypeAndName createIntegerColumnInternal(const std::vector & signed_input, const std::vector unsigned_input, const std::vector & null_map) + { + static_assert(std::is_integral_v); + InferredDataVector data_vector; + if constexpr (std::is_signed_v) + { + for (auto v : signed_input) + data_vector.push_back(static_cast(v)); + } + else + { + for (auto v : unsigned_input) + data_vector.push_back(static_cast(v)); + } + return null_map.empty() ? createColumn(data_vector) : createNullableColumn(data_vector, null_map); + } + template + ColumnWithTypeAndName createFloatColumnInternal(const std::vector & float_input, const std::vector & null_map) + { + static_assert(std::is_floating_point_v); + InferredDataVector data_vector; + for (auto v : float_input) + data_vector.push_back(static_cast(v)); + return null_map.empty() ? createColumn(data_vector) : createNullableColumn(data_vector, null_map); + } + void testNullEqFunction(const ColumnsWithTypeAndName & input_columns, const std::vector & null_map) + { + auto col_result = createColumn({0, 1}); + for (const auto & col_1 : input_columns) + { + for (const auto & col_2 : input_columns) + { + const auto * const result_type_name_nulleq = "UInt8"; + auto result_type = DataTypeFactory::instance().get(result_type_name_nulleq); + auto result_column = result_type->createColumn(); + for (size_t i = 0; i < col_1.column->size(); i++) + { + /// NullEq logic implementation. + Field result; + if (col_1.type->isNullable() && null_map[i] && col_2.type->isNullable() && null_map[i]) + { + /// Both Null. + col_result.column->get(1, result); + result_column->insert(result); + } + else if (col_1.type->isNullable() && null_map[i]) + { + /// Only one side is Null + col_result.column->get(0, result); + result_column->insert(result); + } + else if (col_2.type->isNullable() && null_map[i]) + { + /// Only one side is Null + col_result.column->get(0, result); + result_column->insert(result); + } + else + { + /// Both not Null. Make a comparison then. + Field col1_field; + col_1.column->get(i, col1_field); + Field col2_field; + col_2.column->get(i, col2_field); + bool equals = (col1_field == col2_field); + if (col1_field.toString().find("Decimal") != std::string::npos && col2_field.toString().find("Decimal") != std::string::npos) + { + /// Ugly Fix of Decimal. + auto decimal_string = [&](const Field & value) -> DB::String { + switch (value.getType()) + { + case Field::Types::Which::Decimal32: + { + auto v = safeGet>(value); + return v.toString(); + } + case Field::Types::Which::Decimal64: + { + auto v = safeGet>(value); + return v.toString(); + } + case Field::Types::Which::Decimal128: + { + auto v = safeGet>(value); + return v.toString(); + } + case Field::Types::Which::Decimal256: + { + auto v = safeGet>(value); + return v.toString(); + } + default: + throw Exception("Unsupported with data type."); + } + }; + + /// I know all the tested decimal have actually the scale of 2. + /// So they are just substring-relationship. + const auto decimal_string_col1 = decimal_string(col1_field); + const auto decimal_string_col2 = decimal_string(col2_field); + if (decimal_string_col1.size() > decimal_string_col2.size()) + { + equals = (decimal_string_col1.find(decimal_string_col2) != std::string::npos); + } + else + { + equals = (decimal_string_col2.find(decimal_string_col1) != std::string::npos); + } + } + col_result.column->get(equals, result); + result_column->insert(result); + } + } + ColumnWithTypeAndName expected{std::move(result_column), result_type, ""}; + ASSERT_COLUMN_EQ(expected, executeNullEq(col_1, col_2)); + } + } + } +}; + +TEST_F(TestNullEq, TestInputType) +try +{ + std::vector null_map{1, 0, 0, 0, 1}; + /// case 1 test NullEqInt + std::vector signed_ints{-2, -1, 0, 1, 2}; + std::vector unsigned_ints{0, 1, 2, 3, 4}; + ColumnsWithTypeAndName int_input{ + /// not null column + createIntegerColumnInternal(signed_ints, unsigned_ints, {}), + createIntegerColumnInternal(signed_ints, unsigned_ints, {}), + createIntegerColumnInternal(signed_ints, unsigned_ints, {}), + createIntegerColumnInternal(signed_ints, unsigned_ints, {}), + createIntegerColumnInternal(signed_ints, unsigned_ints, {}), + createIntegerColumnInternal(signed_ints, unsigned_ints, {}), + createIntegerColumnInternal(signed_ints, unsigned_ints, {}), + createIntegerColumnInternal(signed_ints, unsigned_ints, {}), + /// nullable column + createIntegerColumnInternal(signed_ints, unsigned_ints, null_map), + createIntegerColumnInternal(signed_ints, unsigned_ints, null_map), + createIntegerColumnInternal(signed_ints, unsigned_ints, null_map), + createIntegerColumnInternal(signed_ints, unsigned_ints, null_map), + createIntegerColumnInternal(signed_ints, unsigned_ints, null_map), + createIntegerColumnInternal(signed_ints, unsigned_ints, null_map), + createIntegerColumnInternal(signed_ints, unsigned_ints, null_map), + createIntegerColumnInternal(signed_ints, unsigned_ints, null_map), + }; + testNullEqFunction(int_input, null_map); + /// case 2 test NullEqReal + std::vector float_data{-2, -1, 0, 1, 2}; + ColumnsWithTypeAndName float_input{ + createFloatColumnInternal(float_data, {}), + createFloatColumnInternal(float_data, {}), + createFloatColumnInternal(float_data, null_map), + createFloatColumnInternal(float_data, null_map), + }; + testNullEqFunction(float_input, null_map); + /// case 3 test NullEqString + std::vector string_data{"abc", "bcd", "cde", "def", "efg"}; + ColumnsWithTypeAndName string_input{ + createColumn(string_data), + createNullableColumn(string_data, null_map), + }; + testNullEqFunction(string_input, null_map); + /// case 4 test NullEqDecimal + std::vector decimal_data{"-12.34", "-12.12", "0.00", "12.12", "12.34"}; + ColumnsWithTypeAndName decimal_input{ + createColumn(std::make_tuple(5, 3), decimal_data), + createNullableColumn(std::make_tuple(5, 3), decimal_data, null_map), + createColumn(std::make_tuple(12, 4), decimal_data), + createNullableColumn(std::make_tuple(12, 4), decimal_data, null_map), + createColumn(std::make_tuple(20, 2), decimal_data), + createNullableColumn(std::make_tuple(20, 2), decimal_data, null_map), + createColumn(std::make_tuple(40, 6), decimal_data), + createNullableColumn(std::make_tuple(40, 6), decimal_data, null_map), + }; + testNullEqFunction(decimal_input, null_map); + /// case 5 test NullEqTime + InferredDataVector date_data{ + MyDate(0, 0, 0).toPackedUInt(), + MyDate(1960, 1, 1).toPackedUInt(), + MyDate(1999, 1, 1).toPackedUInt(), + MyDate(2000, 1, 1).toPackedUInt(), + MyDate(2400, 1, 1).toPackedUInt(), + }; + InferredDataVector datetime_data_fsp_0{ + MyDateTime(0, 0, 0, 0, 0, 0, 0).toPackedUInt(), + MyDateTime(1960, 1, 1, 11, 11, 11, 0).toPackedUInt(), + MyDateTime(1999, 1, 1, 22, 22, 22, 0).toPackedUInt(), + MyDateTime(2000, 1, 1, 1, 1, 1, 0).toPackedUInt(), + MyDateTime(2400, 1, 1, 2, 2, 2, 0).toPackedUInt(), + }; + InferredDataVector datetime_data_fsp_3{ + MyDateTime(0, 0, 0, 0, 0, 0, 0).toPackedUInt(), + MyDateTime(1960, 1, 1, 11, 11, 11, 123000).toPackedUInt(), + MyDateTime(1999, 1, 1, 22, 22, 22, 234000).toPackedUInt(), + MyDateTime(2000, 1, 1, 1, 1, 1, 345000).toPackedUInt(), + MyDateTime(2400, 1, 1, 2, 2, 2, 456000).toPackedUInt(), + }; + InferredDataVector datetime_data_fsp_6{ + MyDateTime(0, 0, 0, 0, 0, 0, 0).toPackedUInt(), + MyDateTime(1960, 1, 1, 11, 11, 11, 123456).toPackedUInt(), + MyDateTime(1999, 1, 1, 22, 22, 22, 234561).toPackedUInt(), + MyDateTime(2000, 1, 1, 1, 1, 1, 345612).toPackedUInt(), + MyDateTime(2400, 1, 1, 2, 2, 2, 456123).toPackedUInt(), + }; + ColumnsWithTypeAndName time_input{ + createColumn(date_data), + createNullableColumn(date_data, null_map), + createColumn(std::make_tuple(0), datetime_data_fsp_0), + createNullableColumn(std::make_tuple(0), datetime_data_fsp_0, null_map), + createColumn(std::make_tuple(3), datetime_data_fsp_3), + createNullableColumn(std::make_tuple(3), datetime_data_fsp_3, null_map), + createColumn(std::make_tuple(6), datetime_data_fsp_6), + createNullableColumn(std::make_tuple(6), datetime_data_fsp_6, null_map), + }; + testNullEqFunction(time_input, null_map); + /// case 6 test NullEqDuration + InferredDataVector duration_data_fsp_0{ + MyDuration(-1, 10, 10, 10, 0, 0).nanoSecond(), + MyDuration(-1, 11, 11, 11, 0, 0).nanoSecond(), + MyDuration(1, 10, 10, 10, 0, 0).nanoSecond(), + MyDuration(1, 11, 11, 11, 0, 0).nanoSecond(), + MyDuration(1, 12, 12, 12, 0, 0).nanoSecond(), + }; + InferredDataVector duration_data_fsp_3{ + MyDuration(-1, 10, 10, 10, 123000, 3).nanoSecond(), + MyDuration(-1, 11, 11, 11, 234000, 3).nanoSecond(), + MyDuration(1, 10, 10, 10, 345000, 3).nanoSecond(), + MyDuration(1, 11, 11, 11, 456000, 3).nanoSecond(), + MyDuration(1, 12, 12, 12, 561000, 3).nanoSecond(), + }; + InferredDataVector duration_data_fsp_6{ + MyDuration(-1, 10, 10, 10, 123456, 6).nanoSecond(), + MyDuration(-1, 11, 11, 11, 234561, 6).nanoSecond(), + MyDuration(1, 10, 10, 10, 345612, 6).nanoSecond(), + MyDuration(1, 11, 11, 11, 456123, 6).nanoSecond(), + MyDuration(1, 12, 12, 12, 561234, 6).nanoSecond(), + }; + ColumnsWithTypeAndName duration_input{ + createColumn(std::make_tuple(0), duration_data_fsp_0), + createNullableColumn(std::make_tuple(0), duration_data_fsp_0, null_map), + createColumn(std::make_tuple(3), duration_data_fsp_3), + createNullableColumn(std::make_tuple(3), duration_data_fsp_3, null_map), + createColumn(std::make_tuple(6), duration_data_fsp_6), + createNullableColumn(std::make_tuple(6), duration_data_fsp_6, null_map), + }; + testNullEqFunction(duration_input, null_map); +} +CATCH +TEST_F(TestNullEq, TestTypeInfer) +try +{ + auto test_type = [&](const String & col_1_type_name, const String & col_2_type_name, const String & result_type_name) { + auto result_type = DataTypeFactory::instance().get(result_type_name); + auto col_1_type = DataTypeFactory::instance().get(col_1_type_name); + auto col_2_type = DataTypeFactory::instance().get(col_2_type_name); + ASSERT_TRUE(result_type->equals(*getReturnTypeForNullEq(col_1_type, col_2_type))); + }; + /// test integer type Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64 + const auto * const result_type_name_nulleq = "UInt8"; + test_type("Int8", "Int8", result_type_name_nulleq); + test_type("Int8", "Int16", result_type_name_nulleq); + test_type("Int8", "Int32", result_type_name_nulleq); + test_type("Int8", "Int64", result_type_name_nulleq); + test_type("Int8", "UInt8", result_type_name_nulleq); + test_type("Int8", "UInt16", result_type_name_nulleq); + test_type("Int8", "UInt32", result_type_name_nulleq); + test_type("Int8", "UInt64", result_type_name_nulleq); + + test_type("UInt8", "Int8", result_type_name_nulleq); + test_type("UInt8", "Int16", result_type_name_nulleq); + test_type("UInt8", "Int32", result_type_name_nulleq); + test_type("UInt8", "Int64", result_type_name_nulleq); + test_type("UInt8", "UInt8", result_type_name_nulleq); + test_type("UInt8", "UInt16", result_type_name_nulleq); + test_type("UInt8", "UInt32", result_type_name_nulleq); + test_type("UInt8", "UInt64", result_type_name_nulleq); + + test_type("Int16", "Int8", result_type_name_nulleq); + test_type("Int16", "Int16", result_type_name_nulleq); + test_type("Int16", "Int32", result_type_name_nulleq); + test_type("Int16", "Int64", result_type_name_nulleq); + test_type("Int16", "UInt8", result_type_name_nulleq); + test_type("Int16", "UInt16", result_type_name_nulleq); + test_type("Int16", "UInt32", result_type_name_nulleq); + test_type("Int16", "UInt64", result_type_name_nulleq); + + test_type("UInt16", "Int8", result_type_name_nulleq); + test_type("UInt16", "Int16", result_type_name_nulleq); + test_type("UInt16", "Int32", result_type_name_nulleq); + test_type("UInt16", "Int64", result_type_name_nulleq); + test_type("UInt16", "UInt8", result_type_name_nulleq); + test_type("UInt16", "UInt16", result_type_name_nulleq); + test_type("UInt16", "UInt32", result_type_name_nulleq); + test_type("UInt16", "UInt64", result_type_name_nulleq); + + test_type("Int32", "Int8", result_type_name_nulleq); + test_type("Int32", "Int16", result_type_name_nulleq); + test_type("Int32", "Int32", result_type_name_nulleq); + test_type("Int32", "Int64", result_type_name_nulleq); + test_type("Int32", "UInt8", result_type_name_nulleq); + test_type("Int32", "UInt16", result_type_name_nulleq); + test_type("Int32", "UInt32", result_type_name_nulleq); + test_type("Int32", "UInt64", result_type_name_nulleq); + + test_type("UInt32", "Int8", result_type_name_nulleq); + test_type("UInt32", "Int16", result_type_name_nulleq); + test_type("UInt32", "Int32", result_type_name_nulleq); + test_type("UInt32", "Int64", result_type_name_nulleq); + test_type("UInt32", "UInt8", result_type_name_nulleq); + test_type("UInt32", "UInt16", result_type_name_nulleq); + test_type("UInt32", "UInt32", result_type_name_nulleq); + test_type("UInt32", "UInt64", result_type_name_nulleq); + + test_type("Int64", "Int8", result_type_name_nulleq); + test_type("Int64", "Int16", result_type_name_nulleq); + test_type("Int64", "Int32", result_type_name_nulleq); + test_type("Int64", "Int64", result_type_name_nulleq); + test_type("Int64", "UInt8", result_type_name_nulleq); + test_type("Int64", "UInt16", result_type_name_nulleq); + test_type("Int64", "UInt32", result_type_name_nulleq); + test_type("Int64", "UInt64", result_type_name_nulleq); + + test_type("UInt64", "Int8", result_type_name_nulleq); + test_type("UInt64", "Int16", result_type_name_nulleq); + test_type("UInt64", "Int32", result_type_name_nulleq); + test_type("UInt64", "Int64", result_type_name_nulleq); + test_type("UInt64", "UInt8", result_type_name_nulleq); + test_type("UInt64", "UInt16", result_type_name_nulleq); + test_type("UInt64", "UInt32", result_type_name_nulleq); + test_type("UInt64", "UInt64", result_type_name_nulleq); + + /// test type infer for real + test_type("Float32", "Float32", result_type_name_nulleq); + test_type("Float32", "Float64", result_type_name_nulleq); + test_type("Float64", "Float32", result_type_name_nulleq); + test_type("Float64", "Float64", result_type_name_nulleq); + + /// test type infer for string + test_type("String", "String", result_type_name_nulleq); + + /// test type infer for decimal + test_type("Decimal(5,3)", "Decimal(5,3)", result_type_name_nulleq); + test_type("Decimal(5,3)", "Decimal(12,4)", result_type_name_nulleq); + test_type("Decimal(5,3)", "Decimal(20,2)", result_type_name_nulleq); + test_type("Decimal(5,3)", "Decimal(40,6)", result_type_name_nulleq); + + test_type("Decimal(12,4)", "Decimal(5,3)", result_type_name_nulleq); + test_type("Decimal(12,4)", "Decimal(12,4)", result_type_name_nulleq); + test_type("Decimal(12,4)", "Decimal(20,2)", result_type_name_nulleq); + test_type("Decimal(12,4)", "Decimal(40,6)", result_type_name_nulleq); + + test_type("Decimal(20,2)", "Decimal(5,3)", result_type_name_nulleq); + test_type("Decimal(20,2)", "Decimal(12,4)", result_type_name_nulleq); + test_type("Decimal(20,2)", "Decimal(20,2)", result_type_name_nulleq); + test_type("Decimal(20,2)", "Decimal(40,6)", result_type_name_nulleq); + + test_type("Decimal(40,6)", "Decimal(5,3)", result_type_name_nulleq); + test_type("Decimal(40,6)", "Decimal(12,4)", result_type_name_nulleq); + test_type("Decimal(40,6)", "Decimal(20,2)", result_type_name_nulleq); + test_type("Decimal(40,6)", "Decimal(40,6)", result_type_name_nulleq); + + /// test type infer for time + test_type("MyDate", "MyDate", result_type_name_nulleq); + test_type("MyDate", "MyDateTime(0)", result_type_name_nulleq); + test_type("MyDate", "MyDateTime(3)", result_type_name_nulleq); + test_type("MyDate", "MyDateTime(6)", result_type_name_nulleq); + + test_type("MyDateTime(0)", "MyDate", result_type_name_nulleq); + test_type("MyDateTime(0)", "MyDateTime(0)", result_type_name_nulleq); + test_type("MyDateTime(0)", "MyDateTime(3)", result_type_name_nulleq); + test_type("MyDateTime(0)", "MyDateTime(6)", result_type_name_nulleq); + + test_type("MyDateTime(3)", "MyDate", result_type_name_nulleq); + test_type("MyDateTime(3)", "MyDateTime(0)", result_type_name_nulleq); + test_type("MyDateTime(3)", "MyDateTime(3)", result_type_name_nulleq); + test_type("MyDateTime(3)", "MyDateTime(6)", result_type_name_nulleq); + + test_type("MyDateTime(6)", "MyDate", result_type_name_nulleq); + test_type("MyDateTime(6)", "MyDateTime(0)", result_type_name_nulleq); + test_type("MyDateTime(6)", "MyDateTime(3)", result_type_name_nulleq); + test_type("MyDateTime(6)", "MyDateTime(6)", result_type_name_nulleq); + + /// test type infer for Duration + test_type("MyDuration(0)", "MyDuration(0)", result_type_name_nulleq); + test_type("MyDuration(0)", "MyDuration(3)", result_type_name_nulleq); + test_type("MyDuration(0)", "MyDuration(6)", result_type_name_nulleq); + + test_type("MyDuration(3)", "MyDuration(0)", result_type_name_nulleq); + test_type("MyDuration(3)", "MyDuration(3)", result_type_name_nulleq); + test_type("MyDuration(3)", "MyDuration(6)", result_type_name_nulleq); + + test_type("MyDuration(6)", "MyDuration(0)", result_type_name_nulleq); + test_type("MyDuration(6)", "MyDuration(3)", result_type_name_nulleq); + test_type("MyDuration(6)", "MyDuration(6)", result_type_name_nulleq); + + /// test nullable related + test_type("Nullable(Int8)", "Nullable(Int8)", result_type_name_nulleq); + test_type("Nullable(Int8)", "Int8", result_type_name_nulleq); + test_type("Int8", "Nullable(Int8)", result_type_name_nulleq); + + test_type("Nullable(UInt8)", "Nullable(UInt8)", result_type_name_nulleq); + test_type("Nullable(UInt8)", "UInt8", result_type_name_nulleq); + test_type("UInt8", "Nullable(UInt8)", result_type_name_nulleq); + + test_type("Nullable(Int16)", "Nullable(Int16)", result_type_name_nulleq); + test_type("Nullable(Int16)", "Int16", result_type_name_nulleq); + test_type("Int16", "Nullable(Int16)", result_type_name_nulleq); + + test_type("Nullable(UInt16)", "Nullable(UInt16)", result_type_name_nulleq); + test_type("Nullable(UInt16)", "UInt16", result_type_name_nulleq); + test_type("UInt16", "Nullable(UInt16)", result_type_name_nulleq); + + test_type("Nullable(Int32)", "Nullable(Int32)", result_type_name_nulleq); + test_type("Nullable(Int32)", "Int32", result_type_name_nulleq); + test_type("Int32", "Nullable(Int32)", result_type_name_nulleq); + + test_type("Nullable(UInt32)", "Nullable(UInt32)", result_type_name_nulleq); + test_type("Nullable(UInt32)", "UInt32", result_type_name_nulleq); + test_type("UInt32", "Nullable(UInt32)", result_type_name_nulleq); + + test_type("Nullable(Int64)", "Nullable(Int64)", result_type_name_nulleq); + test_type("Nullable(Int64)", "Int64", result_type_name_nulleq); + test_type("Int64", "Nullable(Int64)", result_type_name_nulleq); + + test_type("Nullable(UInt64)", "Nullable(UInt64)", result_type_name_nulleq); + test_type("Nullable(UInt64)", "UInt64", result_type_name_nulleq); + test_type("UInt64", "Nullable(UInt64)", result_type_name_nulleq); + + test_type("Nullable(Float32)", "Nullable(Float32)", result_type_name_nulleq); + test_type("Nullable(Float32)", "Float32", result_type_name_nulleq); + test_type("Float32", "Nullable(Float32)", result_type_name_nulleq); + + test_type("Nullable(String)", "Nullable(String)", result_type_name_nulleq); + test_type("Nullable(String)", "String", result_type_name_nulleq); + test_type("String", "Nullable(String)", result_type_name_nulleq); + + test_type("Nullable(Decimal(5,3))", "Nullable(Decimal(5,3))", result_type_name_nulleq); + test_type("Nullable(Decimal(5,3))", "Decimal(5,3)", result_type_name_nulleq); + test_type("Decimal(5,3)", "Nullable(Decimal(5,3))", result_type_name_nulleq); + + test_type("Nullable(MyDate)", "Nullable(MyDate)", result_type_name_nulleq); + test_type("Nullable(MyDate)", "MyDate", result_type_name_nulleq); + test_type("MyDate", "Nullable(MyDate)", result_type_name_nulleq); + + test_type("Nullable(MyDateTime(0))", "Nullable(MyDateTime(0))", result_type_name_nulleq); + test_type("Nullable(MyDateTime(0))", "MyDateTime(0)", result_type_name_nulleq); + test_type("MyDateTime(0)", "Nullable(MyDateTime(0))", result_type_name_nulleq); + + test_type("Nullable(MyDuration(0))", "Nullable(MyDuration(0))", result_type_name_nulleq); + test_type("Nullable(MyDuration(0))", "MyDuration(0)", result_type_name_nulleq); + test_type("MyDuration(0)", "Nullable(MyDuration(0))", result_type_name_nulleq); +} +CATCH +} // namespace tests +} // namespace DB \ No newline at end of file diff --git a/tests/fullstack-test/expr/nulleq.test b/tests/fullstack-test/expr/nulleq.test new file mode 100644 index 00000000000..bbf8dc687b6 --- /dev/null +++ b/tests/fullstack-test/expr/nulleq.test @@ -0,0 +1,28 @@ +# Copyright 2022 PingCAP, Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +mysql> drop table if exists test.t; +mysql> create table test.t(a int, b int); +mysql> insert into test.t values(1, null),(null, 1); +mysql> alter table test.t set tiflash replica 1; +func> wait_table test t +mysql> set @@tidb_isolation_read_engines='tiflash'; +mysql> set @@tidb_isolation_read_engines='tiflash'; select a, b, nulleq(a, Null), nulleq(b, null), nulleq(a, 1), nulleq(b, 1) from test.t; ++------+------+-----------------+-----------------+--------------+--------------+ +| a | b | nulleq(a, Null) | nulleq(b, null) | nulleq(a, 1) | nulleq(b, 1) | ++------+------+-----------------+-----------------+--------------+--------------+ +| 1 | NULL | 0 | 1 | 1 | 0 | +| NULL | 1 | 1 | 0 | 0 | 1 | ++------+------+-----------------+-----------------+--------------+--------------+ +mysql> drop table if exists test.t; \ No newline at end of file From fe6ec21a9e107e22f4e8cfca2b867544446e47e1 Mon Sep 17 00:00:00 2001 From: yy <736262857@qq.com> Date: Fri, 3 Jul 2026 07:15:52 +0000 Subject: [PATCH 2/5] feat(tiflash): support NullEq pushdown in Runtime Filter - Add NullEq handling logic in RS (Runtime Filter) generation - Ensure NULL values are correctly considered in filter construction - Add related test cases for RS with NullEq Ref: #5102 Signed-off-by: yy <736262857@qq.com> --- .../Storages/DeltaMerge/Filter/NullEqual.h | 69 ++++ .../Storages/DeltaMerge/Filter/RSOperator.cpp | 2 + .../Storages/DeltaMerge/Filter/RSOperator.h | 1 + .../DeltaMerge/FilterParser/FilterParser.cpp | 18 +- .../DeltaMerge/FilterParser/FilterParser.h | 2 + .../Index/tests/gtest_dm_minmax_index.cpp | 308 ++++++++++++++++++ 6 files changed, 393 insertions(+), 7 deletions(-) create mode 100644 dbms/src/Storages/DeltaMerge/Filter/NullEqual.h diff --git a/dbms/src/Storages/DeltaMerge/Filter/NullEqual.h b/dbms/src/Storages/DeltaMerge/Filter/NullEqual.h new file mode 100644 index 00000000000..5c2f9090b9d --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Filter/NullEqual.h @@ -0,0 +1,69 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace DB::DM +{ + +class NullEqual : public ColCmpVal +{ +public: + NullEqual(const Attr & attr_, const Field & value_) + : ColCmpVal(attr_, value_) + {} + + String name() override { return "nullEqual"; } + + RSResults roughCheck(size_t start_pack, size_t pack_count, const RSCheckParam & param) override + { + auto rs_index = getRSIndex(param, attr); + if (!rs_index) + return RSResults(pack_count, RSResult::Some); + + if (value.isNull()) + { + return rs_index->minmax->checkIsNull(start_pack, pack_count); + } + else + { + return rs_index->minmax->checkCmp(start_pack, pack_count, value, rs_index->type); + } + } + + ColumnRangePtr buildSets(const google::protobuf::RepeatedPtrField & index_infos) override + { + if (value.isNull()) + return UnsupportedColumnRange::create(); + + if (auto set = IntegerSet::createValueSet(attr.type, {value}); set) + { + auto iter = std::find_if(index_infos.begin(), index_infos.end(), [&](const auto & info) { + return info.index_type() == tipb::ColumnarIndexType::TypeInverted + && info.inverted_query_info().column_id() == attr.col_id; + }); + if (iter != index_infos.end()) + return SingleColumnRange::create( + iter->inverted_query_info().column_id(), + iter->inverted_query_info().index_id(), + set); + } + return UnsupportedColumnRange::create(); + } +}; + +} // namespace DB::DM \ No newline at end of file diff --git a/dbms/src/Storages/DeltaMerge/Filter/RSOperator.cpp b/dbms/src/Storages/DeltaMerge/Filter/RSOperator.cpp index 1e57a9f8de8..0460f80583e 100644 --- a/dbms/src/Storages/DeltaMerge/Filter/RSOperator.cpp +++ b/dbms/src/Storages/DeltaMerge/Filter/RSOperator.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -45,6 +46,7 @@ RSOperatorPtr createLess(const Attr & attr, const Field & value) RSOperatorPtr createLessEqual(const Attr & attr, const Field & value) { return std::make_shared(attr, value); } RSOperatorPtr createLike(const Attr & attr, const Field & value) { return std::make_shared(attr, value); } RSOperatorPtr createNot(const RSOperatorPtr & op) { return std::make_shared(op); } +RSOperatorPtr createNullEqual(const Attr & attr, const Field & value) { return std::make_shared(attr, value); } RSOperatorPtr createNotEqual(const Attr & attr, const Field & value) { return std::make_shared(attr, value); } RSOperatorPtr createOr(const RSOperators & children) { return std::make_shared(children); } RSOperatorPtr createIsNull(const Attr & attr) { return std::make_shared(attr);} diff --git a/dbms/src/Storages/DeltaMerge/Filter/RSOperator.h b/dbms/src/Storages/DeltaMerge/Filter/RSOperator.h index b1a98837836..a4984afd172 100644 --- a/dbms/src/Storages/DeltaMerge/Filter/RSOperator.h +++ b/dbms/src/Storages/DeltaMerge/Filter/RSOperator.h @@ -181,6 +181,7 @@ RSOperatorPtr createAnd(const RSOperators & children); // compare RSOperatorPtr createEqual(const Attr & attr, const Field & value); RSOperatorPtr createNotEqual(const Attr & attr, const Field & value); +RSOperatorPtr createNullEqual(const Attr & attr, const Field & value); RSOperatorPtr createGreater(const Attr & attr, const Field & value); RSOperatorPtr createGreaterEqual(const Attr & attr, const Field & value); RSOperatorPtr createLess(const Attr & attr, const Field & value); diff --git a/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp b/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp index b434c9446f5..311a6ebe8ae 100644 --- a/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp +++ b/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.cpp @@ -237,6 +237,8 @@ inline RSOperatorPtr parseTiCompareExpr( // return createLessEqual(attr, values[0]); case FilterParser::RSFilterType::In: return createIn(attr, values); + case FilterParser::RSFilterType::NullEqual: + return createNullEqual(attr, values[0]); default: return createUnsupported(fmt::format("Unknown compare type: {}", tipb::ExprType_Name(expr.tp()))); } @@ -308,6 +310,7 @@ RSOperatorPtr parseTiExpr( case FilterParser::RSFilterType::Less: case FilterParser::RSFilterType::LessEqual: case FilterParser::RSFilterType::In: + case FilterParser::RSFilterType::NullEqual: return parseTiCompareExpr(expr, filter_type, scan_column_infos, id_to_attr, timezone_info); case FilterParser::RSFilterType::IsNull: @@ -583,13 +586,14 @@ std::unordered_map FilterParser {tipb::ScalarFuncSig::NEDuration, FilterParser::RSFilterType::NotEqual}, {tipb::ScalarFuncSig::NEJson, FilterParser::RSFilterType::NotEqual}, - //{tipb::ScalarFuncSig::NullEQInt, "cast"}, - //{tipb::ScalarFuncSig::NullEQReal, "cast"}, - //{tipb::ScalarFuncSig::NullEQString, "cast"}, - //{tipb::ScalarFuncSig::NullEQDecimal, "cast"}, - //{tipb::ScalarFuncSig::NullEQTime, "cast"}, - //{tipb::ScalarFuncSig::NullEQDuration, "cast"}, - //{tipb::ScalarFuncSig::NullEQJson, "cast"}, + {tipb::ScalarFuncSig::NullEQInt, FilterParser::RSFilterType::NullEqual}, + {tipb::ScalarFuncSig::NullEQReal, FilterParser::RSFilterType::NullEqual}, + {tipb::ScalarFuncSig::NullEQString, FilterParser::RSFilterType::NullEqual}, + {tipb::ScalarFuncSig::NullEQDecimal, FilterParser::RSFilterType::NullEqual}, + {tipb::ScalarFuncSig::NullEQTime, FilterParser::RSFilterType::NullEqual}, + {tipb::ScalarFuncSig::NullEQDuration, FilterParser::RSFilterType::NullEqual}, + {tipb::ScalarFuncSig::NullEQJson, FilterParser::RSFilterType::NullEqual}, + // {tipb::ScalarFuncSig::PlusReal, "plus"}, // {tipb::ScalarFuncSig::PlusDecimal, "plus"}, diff --git a/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.h b/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.h index 2752afd070b..336b4506754 100644 --- a/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.h +++ b/dbms/src/Storages/DeltaMerge/FilterParser/FilterParser.h @@ -83,6 +83,8 @@ class FilterParser // NotLike, TiDB will convert it to Not(Like) IsNull, + + NullEqual, }; static std::unordered_map scalar_func_rs_filter_map; diff --git a/dbms/src/Storages/DeltaMerge/Index/tests/gtest_dm_minmax_index.cpp b/dbms/src/Storages/DeltaMerge/Index/tests/gtest_dm_minmax_index.cpp index 4aec61395ce..200e6118900 100644 --- a/dbms/src/Storages/DeltaMerge/Index/tests/gtest_dm_minmax_index.cpp +++ b/dbms/src/Storages/DeltaMerge/Index/tests/gtest_dm_minmax_index.cpp @@ -1147,6 +1147,173 @@ RSOperatorPtr generateIsNullOperator(MinMaxTestDatatype data_type) } } +RSOperatorPtr generateNullEqualOperator(MinMaxTestDatatype data_type, bool is_match, bool compare_with_null) +{ + switch (data_type) + { + case Test_Int64: + { + if (compare_with_null) + { + return createNullEqual(attr("Int64"), Field{}); + } + if (is_match) + { + return createNullEqual(attr("Int64"), Field(Int64_Match_DATA)); + } + else + { + return createNullEqual(attr("Int64"), Field(Int64_Smaller_DATA)); + } + } + case Test_Nullable_Int64: + { + if (compare_with_null) + { + return createNullEqual(attr("Nullable(Int64)"), Field{}); + } + if (is_match) + { + return createNullEqual(attr("Nullable(Int64)"), Field(Int64_Match_DATA)); + } + else + { + return createNullEqual(attr("Nullable(Int64)"), Field(Int64_Smaller_DATA)); + } + } + case Test_Date: + { + if (compare_with_null) + { + return createNullEqual(attr("Date"), Field{}); + } + if (is_match) + { + return createNullEqual(attr("Date"), Field(Date_Match_DATA)); + } + else + { + return createNullEqual(attr("Date"), Field(Date_Smaller_DATA)); + } + } + case Test_Nullable_Date: + { + if (compare_with_null) + { + return createNullEqual(attr("Nullable(Date)"), Field{}); + } + if (is_match) + { + return createNullEqual(attr("Nullable(Date)"), Field(Date_Match_DATA)); + } + else + { + return createNullEqual(attr("Nullable(Date)"), Field(Date_Smaller_DATA)); + } + } + case Test_DateTime: + { + if (compare_with_null) + { + return createNullEqual(attr("DateTime"), Field{}); + } + if (is_match) + { + return createNullEqual(attr("DateTime"), Field(DateTime_Match_DATA)); + } + else + { + return createNullEqual(attr("DateTime"), Field(DateTime_Smaller_DATA)); + } + } + case Test_Nullable_DateTime: + { + if (compare_with_null) + { + return createNullEqual(attr("Nullable(DateTime)"), Field{}); + } + if (is_match) + { + return createNullEqual(attr("Nullable(DateTime)"), Field(DateTime_Match_DATA)); + } + else + { + return createNullEqual(attr("Nullable(DateTime)"), Field(DateTime_Smaller_DATA)); + } + } + case Test_MyDateTime: + { + if (compare_with_null) + { + return createNullEqual(attr("MyDateTime"), Field{}); + } + if (is_match) + { + return createNullEqual(attr("MyDateTime"), Field(parseMyDateTime(MyDateTime_Match_DATE))); + } + else + { + return createNullEqual(attr("MyDateTime"), Field(parseMyDateTime(MyDateTime_Smaller_DATE))); + } + } + case Test_Nullable_MyDateTime: + { + if (compare_with_null) + { + return createNullEqual(attr("Nullable(MyDateTime)"), Field{}); + } + if (is_match) + { + return createNullEqual(attr("Nullable(MyDateTime)"), Field(parseMyDateTime(MyDateTime_Match_DATE))); + } + else + { + return createNullEqual(attr("Nullable(MyDateTime)"), Field(parseMyDateTime(MyDateTime_Smaller_DATE))); + } + } + case Test_Decimal64: + { + if (compare_with_null) + { + return createNullEqual(attr("Decimal(20,5)"), Field{}); + } + if (is_match) + { + return createNullEqual( + attr("Decimal(20,5)"), + Field(DecimalField(getDecimal64(Decimal_Match_DATA), 5))); + } + else + { + return createNullEqual( + attr("Decimal(20,5)"), + Field(DecimalField(getDecimal64(Decimal_UnMatch_DATA), 5))); + } + } + case Test_Nullable_Decimal64: + { + if (compare_with_null) + { + return createNullEqual(attr("Nullable(Decimal(20,5))"), Field{}); + } + if (is_match) + { + return createNullEqual( + attr("Nullable(Decimal(20,5))"), + Field(DecimalField(getDecimal64(Decimal_Match_DATA), 5))); + } + else + { + return createNullEqual( + attr("Nullable(Decimal(20,5))"), + Field(DecimalField(getDecimal64(Decimal_UnMatch_DATA), 5))); + } + } + default: + throw Exception("Unknown data type"); + } +} + RSOperatorPtr generateRSOperator(MinMaxTestDatatype data_type, MinMaxTestOperator rs_operator, bool is_match) { switch (rs_operator) @@ -1875,6 +2042,147 @@ try CATCH +TEST_F(MinMaxIndexTest, NullEqual) +try +{ + const auto * case_name = ::testing::UnitTest::GetInstance()->current_test_info()->name(); + + for (size_t datatype = Test_Int64; datatype < Test_Decimal64; datatype++) + { + { + // not null data, compare with value (match) + auto type_value_pair = generateTypeValue(static_cast(datatype), false); + ASSERT_EQ( + true, + checkMatch( + case_name, + *context, + type_value_pair.first, + type_value_pair.second, + generateNullEqualOperator( + static_cast(datatype), + true, + false))); + } + { + // not null data, compare with value (not match) + auto type_value_pair = generateTypeValue(static_cast(datatype), false); + ASSERT_EQ( + false, + checkMatch( + case_name, + *context, + type_value_pair.first, + type_value_pair.second, + generateNullEqualOperator( + static_cast(datatype), + false, + false))); + } + { + // not null data, compare with null (not match) + auto type_value_pair = generateTypeValue(static_cast(datatype), false); + ASSERT_EQ( + false, + checkMatch( + case_name, + *context, + type_value_pair.first, + type_value_pair.second, + generateNullEqualOperator( + static_cast(datatype), + true, + true))); + } + { + if (!isNullableDateType(static_cast(datatype))) + { + continue; + } + // has null data, compare with null (match) + auto type_value_pair = generateTypeValue(static_cast(datatype), true); + ASSERT_EQ( + true, + checkMatch( + case_name, + *context, + type_value_pair.first, + type_value_pair.second, + generateNullEqualOperator( + static_cast(datatype), + true, + true))); + } + { + // has null data, compare with value (match) + auto type_value_pair = generateTypeValue(static_cast(datatype), true); + ASSERT_EQ( + true, + checkMatch( + case_name, + *context, + type_value_pair.first, + type_value_pair.second, + generateNullEqualOperator( + static_cast(datatype), + true, + false))); + } + { + // has null data, compare with value (not match) + auto type_value_pair = generateTypeValue(static_cast(datatype), true); + ASSERT_EQ( + false, + checkMatch( + case_name, + *context, + type_value_pair.first, + type_value_pair.second, + generateNullEqualOperator( + static_cast(datatype), + false, + false))); + } + } + + // datatypes which not support minmax index + for (size_t datatype = Test_Decimal64; datatype < Test_Max; datatype++) + { + { + // not null data, compare with value (match) + auto type_value_pair = generateTypeValue(static_cast(datatype), false); + ASSERT_EQ( + true, + checkMatch( + case_name, + *context, + type_value_pair.first, + type_value_pair.second, + generateNullEqualOperator( + static_cast(datatype), + true, + false))); + } + { + // not null data, compare with value (not match) - should still return true + // because no minmax index is available + auto type_value_pair = generateTypeValue(static_cast(datatype), false); + ASSERT_EQ( + true, + checkMatch( + case_name, + *context, + type_value_pair.first, + type_value_pair.second, + generateNullEqualOperator( + static_cast(datatype), + false, + false))); + } + } +} +CATCH + TEST_F(MinMaxIndexTest, checkPKMatch) try { From 75dd7087094f37a4399a3187e156dff0a17d5565 Mon Sep 17 00:00:00 2001 From: yy <736262857@qq.com> Date: Fri, 3 Jul 2026 12:38:09 +0000 Subject: [PATCH 3/5] fix: address CodeRabbit comments for NullEq pushdown - Add missing test scenarios for unsupported datatypes - Use DecimalField::operator== instead of string-based comparison - Remove NullEQJson from scalar function mapping Ref: pingcap#5102 Signed-off-by: yy <736262857@qq.com> --- dbms/src/Flash/Coprocessor/DAGUtils.cpp | 1 - dbms/src/Functions/tests/gtest_nulleq.cpp | 62 ++++++----------- .../Index/tests/gtest_dm_minmax_index.cpp | 68 +++++++++++++++++++ 3 files changed, 89 insertions(+), 42 deletions(-) diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index 62fd065dfae..068413f89af 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -232,7 +232,6 @@ const std::unordered_map scalar_func_map({ {tipb::ScalarFuncSig::NullEQDecimal, "nullEq"}, {tipb::ScalarFuncSig::NullEQTime, "nullEq"}, {tipb::ScalarFuncSig::NullEQDuration, "nullEq"}, - {tipb::ScalarFuncSig::NullEQJson, "nullEq"}, {tipb::ScalarFuncSig::PlusReal, "plus"}, {tipb::ScalarFuncSig::PlusDecimal, "plus"}, diff --git a/dbms/src/Functions/tests/gtest_nulleq.cpp b/dbms/src/Functions/tests/gtest_nulleq.cpp index 33b0337197c..837ba9a180c 100644 --- a/dbms/src/Functions/tests/gtest_nulleq.cpp +++ b/dbms/src/Functions/tests/gtest_nulleq.cpp @@ -112,47 +112,27 @@ class TestNullEq : public DB::tests::FunctionTest bool equals = (col1_field == col2_field); if (col1_field.toString().find("Decimal") != std::string::npos && col2_field.toString().find("Decimal") != std::string::npos) { - /// Ugly Fix of Decimal. - auto decimal_string = [&](const Field & value) -> DB::String { - switch (value.getType()) - { - case Field::Types::Which::Decimal32: - { - auto v = safeGet>(value); - return v.toString(); - } - case Field::Types::Which::Decimal64: - { - auto v = safeGet>(value); - return v.toString(); - } - case Field::Types::Which::Decimal128: - { - auto v = safeGet>(value); - return v.toString(); - } - case Field::Types::Which::Decimal256: - { - auto v = safeGet>(value); - return v.toString(); - } - default: - throw Exception("Unsupported with data type."); - } - }; - - /// I know all the tested decimal have actually the scale of 2. - /// So they are just substring-relationship. - const auto decimal_string_col1 = decimal_string(col1_field); - const auto decimal_string_col2 = decimal_string(col2_field); - if (decimal_string_col1.size() > decimal_string_col2.size()) - { - equals = (decimal_string_col1.find(decimal_string_col2) != std::string::npos); - } - else - { - equals = (decimal_string_col2.find(decimal_string_col1) != std::string::npos); - } + auto v1 = safeGet>(col1_field); + auto v2 = safeGet>(col2_field); + equals = decimalEqual(v1.getValue(), v2.getValue(), v1.getScale(), v2.getScale()); + } + else if (col1_field.getType() == Field::Types::Which::Decimal64 && col2_field.getType() == Field::Types::Which::Decimal64) + { + auto v1 = safeGet>(col1_field); + auto v2 = safeGet>(col2_field); + equals = decimalEqual(v1.getValue(), v2.getValue(), v1.getScale(), v2.getScale()); + } + else if (col1_field.getType() == Field::Types::Which::Decimal128 && col2_field.getType() == Field::Types::Which::Decimal128) + { + auto v1 = safeGet>(col1_field); + auto v2 = safeGet>(col2_field); + equals = decimalEqual(v1.getValue(), v2.getValue(), v1.getScale(), v2.getScale()); + } + else if (col1_field.getType() == Field::Types::Which::Decimal256 && col2_field.getType() == Field::Types::Which::Decimal256) + { + auto v1 = safeGet>(col1_field); + auto v2 = safeGet>(col2_field); + equals = decimalEqual(v1.getValue(), v2.getValue(), v1.getScale(), v2.getScale()); } col_result.column->get(equals, result); result_column->insert(result); diff --git a/dbms/src/Storages/DeltaMerge/Index/tests/gtest_dm_minmax_index.cpp b/dbms/src/Storages/DeltaMerge/Index/tests/gtest_dm_minmax_index.cpp index 200e6118900..eea9bef1af9 100644 --- a/dbms/src/Storages/DeltaMerge/Index/tests/gtest_dm_minmax_index.cpp +++ b/dbms/src/Storages/DeltaMerge/Index/tests/gtest_dm_minmax_index.cpp @@ -2179,6 +2179,74 @@ try false, false))); } + { + // not null data, compare with null - should still return true + // because no minmax index is available + auto type_value_pair = generateTypeValue(static_cast(datatype), false); + ASSERT_EQ( + true, + checkMatch( + case_name, + *context, + type_value_pair.first, + type_value_pair.second, + generateNullEqualOperator( + static_cast(datatype), + true, + true))); + } + { + // has null data, compare with value (match) - should still return true + // because no minmax index is available + if (!isNullableDateType(static_cast(datatype))) + { + continue; + } + auto type_value_pair = generateTypeValue(static_cast(datatype), true); + ASSERT_EQ( + true, + checkMatch( + case_name, + *context, + type_value_pair.first, + type_value_pair.second, + generateNullEqualOperator( + static_cast(datatype), + true, + false))); + } + { + // has null data, compare with value (not match) - should still return true + // because no minmax index is available + auto type_value_pair = generateTypeValue(static_cast(datatype), true); + ASSERT_EQ( + true, + checkMatch( + case_name, + *context, + type_value_pair.first, + type_value_pair.second, + generateNullEqualOperator( + static_cast(datatype), + false, + false))); + } + { + // has null data, compare with null (match) - should still return true + // because no minmax index is available + auto type_value_pair = generateTypeValue(static_cast(datatype), true); + ASSERT_EQ( + true, + checkMatch( + case_name, + *context, + type_value_pair.first, + type_value_pair.second, + generateNullEqualOperator( + static_cast(datatype), + true, + true))); + } } } CATCH From 721827d240e2e5ba4c641a6f729e0fd8f49627ff Mon Sep 17 00:00:00 2001 From: yy <736262857@qq.com> Date: Fri, 3 Jul 2026 12:49:55 +0000 Subject: [PATCH 4/5] fix: use precise type check for Decimal equality in NullEq test Replace fragile string-based Decimal detection with direct Field::Types::Which comparison. Signed-off-by: yy <736262857@qq.com> --- dbms/src/Functions/tests/gtest_nulleq.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbms/src/Functions/tests/gtest_nulleq.cpp b/dbms/src/Functions/tests/gtest_nulleq.cpp index 837ba9a180c..bb3a0ccddc6 100644 --- a/dbms/src/Functions/tests/gtest_nulleq.cpp +++ b/dbms/src/Functions/tests/gtest_nulleq.cpp @@ -110,7 +110,7 @@ class TestNullEq : public DB::tests::FunctionTest Field col2_field; col_2.column->get(i, col2_field); bool equals = (col1_field == col2_field); - if (col1_field.toString().find("Decimal") != std::string::npos && col2_field.toString().find("Decimal") != std::string::npos) + if (col1_field.getType() == Field::Types::Which::Decimal32 && col2_field.getType() == Field::Types::Which::Decimal32) { auto v1 = safeGet>(col1_field); auto v2 = safeGet>(col2_field); From 9f77226745a48102cdcf44f49f0c5bd72ab55bec Mon Sep 17 00:00:00 2001 From: yy <736262857@qq.com> Date: Fri, 3 Jul 2026 14:43:53 +0000 Subject: [PATCH 5/5] Fix else branch in testNullEqFunction to properly compare values Signed-off-by: yy <736262857@qq.com> --- dbms/src/Functions/tests/gtest_nulleq.cpp | 55 ++++++++++------------- 1 file changed, 23 insertions(+), 32 deletions(-) diff --git a/dbms/src/Functions/tests/gtest_nulleq.cpp b/dbms/src/Functions/tests/gtest_nulleq.cpp index bb3a0ccddc6..b536755f64c 100644 --- a/dbms/src/Functions/tests/gtest_nulleq.cpp +++ b/dbms/src/Functions/tests/gtest_nulleq.cpp @@ -15,7 +15,10 @@ #include #include #include +#include +#include #include +#include #include #include #include @@ -104,42 +107,30 @@ class TestNullEq : public DB::tests::FunctionTest } else { - /// Both not Null. Make a comparison then. - Field col1_field; - col_1.column->get(i, col1_field); - Field col2_field; - col_2.column->get(i, col2_field); - bool equals = (col1_field == col2_field); - if (col1_field.getType() == Field::Types::Which::Decimal32 && col2_field.getType() == Field::Types::Which::Decimal32) - { - auto v1 = safeGet>(col1_field); - auto v2 = safeGet>(col2_field); - equals = decimalEqual(v1.getValue(), v2.getValue(), v1.getScale(), v2.getScale()); - } - else if (col1_field.getType() == Field::Types::Which::Decimal64 && col2_field.getType() == Field::Types::Which::Decimal64) - { - auto v1 = safeGet>(col1_field); - auto v2 = safeGet>(col2_field); - equals = decimalEqual(v1.getValue(), v2.getValue(), v1.getScale(), v2.getScale()); - } - else if (col1_field.getType() == Field::Types::Which::Decimal128 && col2_field.getType() == Field::Types::Which::Decimal128) - { - auto v1 = safeGet>(col1_field); - auto v2 = safeGet>(col2_field); - equals = decimalEqual(v1.getValue(), v2.getValue(), v1.getScale(), v2.getScale()); - } - else if (col1_field.getType() == Field::Types::Which::Decimal256 && col2_field.getType() == Field::Types::Which::Decimal256) - { - auto v1 = safeGet>(col1_field); - auto v2 = safeGet>(col2_field); - equals = decimalEqual(v1.getValue(), v2.getValue(), v1.getScale(), v2.getScale()); - } - col_result.column->get(equals, result); + Field field1; + Field field2; + col_1.column->get(i, field1); + col_2.column->get(i, field2); + + DataTypePtr type1 = col_1.type; + DataTypePtr type2 = col_2.type; + if (type1->isNullable()) + type1 = removeNullable(type1); + if (type2->isNullable()) + type2 = removeNullable(type2); + + DataTypePtr super_type = getLeastSupertype({type1, type2}); + field1 = convertFieldToType(field1, *super_type, type1.get()); + field2 = convertFieldToType(field2, *super_type, type2.get()); + + bool equal = field1 == field2; + col_result.column->get(equal, result); result_column->insert(result); } } ColumnWithTypeAndName expected{std::move(result_column), result_type, ""}; - ASSERT_COLUMN_EQ(expected, executeNullEq(col_1, col_2)); + auto actual = executeNullEq(col_1, col_2); + ASSERT_COLUMN_EQ(expected, actual); } } }