Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 246 additions & 14 deletions datafusion/functions/src/math/trunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,32 @@
// specific language governing permissions and limitations
// under the License.

use std::ops::{Div, Mul};
use std::sync::Arc;

use crate::utils::make_scalar_function;
use crate::utils::{calculate_binary_decimal_math_cast, make_scalar_function};

use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
use arrow::datatypes::DataType::{Float32, Float64};
use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type};
use arrow::datatypes::DataType::{
Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64,
};
use arrow::datatypes::{
DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType,
Float32Type, Float64Type, Int64Type,
};
use datafusion_common::ScalarValue::Int64;
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_expr::TypeSignature::Exact;
use datafusion_common::types::{
NativeType, logical_float32, logical_float64, logical_int64,
};
use datafusion_common::{Result, ScalarValue, exec_err, plan_err};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
};
use datafusion_expr_common::signature::{Coercion, TypeSignature, TypeSignatureClass};
use datafusion_macros::user_doc;
use num_traits::{One, Zero, pow};

#[user_doc(
doc_section(label = "Math Functions"),
Expand Down Expand Up @@ -68,19 +78,38 @@ impl Default for TruncFunc {

impl TruncFunc {
pub fn new() -> Self {
use DataType::*;
let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
let decimal_places = Coercion::new_implicit(
TypeSignatureClass::Native(logical_int64()),
vec![TypeSignatureClass::Integer],
NativeType::Int64,
);
let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32()));
let float64 = Coercion::new_implicit(
TypeSignatureClass::Native(logical_float64()),
vec![TypeSignatureClass::Numeric],
NativeType::Float64,
);
Self {
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
// return the best approximation for it (in f64).
// We accept f32 because in this case it is clear that the best approximation
// will be as good as the number of digits in the number
// Decimal arguments are accepted to handle large values properly
signature: Signature::one_of(
vec![
Exact(vec![Float32, Int64]),
Exact(vec![Float64, Int64]),
Exact(vec![Float64]),
Exact(vec![Float32]),
TypeSignature::Coercible(vec![
decimal.clone(),
decimal_places.clone(),
]),
TypeSignature::Coercible(vec![decimal]),
TypeSignature::Coercible(vec![
float32.clone(),
decimal_places.clone(),
]),
TypeSignature::Coercible(vec![float32]),
TypeSignature::Coercible(vec![float64.clone(), decimal_places]),
TypeSignature::Coercible(vec![float64]),
],
Volatility::Immutable,
),
Expand All @@ -98,9 +127,16 @@ impl ScalarUDFImpl for TruncFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
match arg_types[0] {
match &arg_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
Float64 => Ok(Float64),
dt if dt.is_decimal() => Ok(dt.clone()),
DataType::Null => Ok(Float64),
_ => plan_err!(
"Unsupported data type {:?} for function {}",
arg_types[0],
self.name()
),
}
}

Expand Down Expand Up @@ -146,6 +182,55 @@ impl ScalarUDFImpl for TruncFunc {
compute_truncate32(*v, p)
}))),
),
(
ColumnarValue::Scalar(ScalarValue::Decimal32(
Some(v),
lprecision,
lscale,
)),
Some(p),
) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal32(
Some(compute_truncate_decimal::<Decimal32Type>(*v, *lscale, p)),
*lprecision,
*lscale,
))),
(
ColumnarValue::Scalar(ScalarValue::Decimal64(
Some(v),
lprecision,
lscale,
)),
Some(p),
) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal64(
Some(compute_truncate_decimal::<Decimal64Type>(*v, *lscale, p)),
*lprecision,
*lscale,
))),
(
ColumnarValue::Scalar(ScalarValue::Decimal128(
Some(v),
lprecision,
lscale,
)),
Some(p),
) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
Some(compute_truncate_decimal::<Decimal128Type>(*v, *lscale, p)),
*lprecision,
*lscale,
))),
(
ColumnarValue::Scalar(ScalarValue::Decimal256(
Some(v),
lprecision,
lscale,
)),
Some(p),
) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
Some(compute_truncate_decimal::<Decimal256Type>(*v, *lscale, p)),
*lprecision,
*lscale,
))),

// Array path for everything else
_ => make_scalar_function(trunc, vec![])(&args.args),
}
Expand Down Expand Up @@ -234,6 +319,58 @@ fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> {
}
_ => exec_err!("trunc function requires a scalar or array for precision"),
},
Decimal32(lprecision, lscale) => Ok(calculate_binary_decimal_math_cast::<
Decimal32Type,
Int64Type,
Decimal32Type,
_,
>(
num.as_ref(),
&precision,
|v, y| Ok(compute_truncate_decimal::<Decimal32Type>(v, *lscale, y)),
*lprecision,
*lscale,
&DataType::Int64,
)? as ArrayRef),
Decimal64(lprecision, lscale) => Ok(calculate_binary_decimal_math_cast::<
Decimal64Type,
Int64Type,
Decimal64Type,
_,
>(
num.as_ref(),
&precision,
|v, y| Ok(compute_truncate_decimal::<Decimal64Type>(v, *lscale, y)),
*lprecision,
*lscale,
&DataType::Int64,
)? as ArrayRef),
Decimal128(lprecision, lscale) => Ok(calculate_binary_decimal_math_cast::<
Decimal128Type,
Int64Type,
Decimal128Type,
_,
>(
num.as_ref(),
&precision,
|v, y| Ok(compute_truncate_decimal::<Decimal128Type>(v, *lscale, y)),
*lprecision,
*lscale,
&DataType::Int64,
)? as ArrayRef),
Decimal256(lprecision, lscale) => Ok(calculate_binary_decimal_math_cast::<
Decimal256Type,
Int64Type,
Decimal256Type,
_,
>(
num.as_ref(),
&precision,
|v, y| Ok(compute_truncate_decimal::<Decimal256Type>(v, *lscale, y)),
*lprecision,
*lscale,
&DataType::Int64,
)? as ArrayRef),
other => exec_err!("Unsupported data type {other:?} for function trunc"),
}
}
Expand All @@ -248,13 +385,52 @@ fn compute_truncate64(x: f64, y: i64) -> f64 {
(x * factor).trunc() / factor
}

/// Zeroes the low-order digits of a raw decimal value so that at most
/// `truncate_precision` fractional digits remain; the scale of the value is unchanged.
///
/// `x` is the unscaled integer representation and `scale` is the number of
/// fractional digits it carries. The number of digits that get cleared is
/// `scale - truncate_precision`:
///
/// * zero or negative: the value already has no more digits than requested, so `x`
///   is returned untouched;
/// * at least `T::MAX_PRECISION`: zero is returned without computing the power of
///   ten at all, which avoids overflowing `pow`;
/// * otherwise: `x` is divided by `10^(scale - truncate_precision)` and multiplied
///   back, which relies on `T::Native` division discarding the remainder.
///
/// A negative `truncate_precision` therefore also clears digits to the left of the
/// decimal point.
///
/// Examples with 12.3456 stored as `123456` at scale 4:
///
/// * `truncate_precision = 1`: 4 - 1 = 3 digits are cleared,
///   `(123456 / 1000) * 1000 = 123000`, i.e. 12.3000.
/// * `truncate_precision = -1`: 4 - (-1) = 5 digits are cleared,
///   `(123456 / 100000) * 100000 = 100000`, i.e. 10.0000.
fn compute_truncate_decimal<T>(
    x: T::Native,
    scale: i8,
    truncate_precision: i64,
) -> T::Native
where
    T: DecimalType,
    T::Native: Copy + From<i32> + One + Zero + Div<Output = T::Native> + Mul,
{
    // Saturating subtraction keeps extreme `truncate_precision` values from wrapping.
    let digits_to_clear = i64::from(scale).saturating_sub(truncate_precision);

    // More digits requested than are stored: nothing to drop.
    if digits_to_clear <= 0 {
        return x;
    }

    // Every storable digit would be dropped: answer is zero, and `pow` is never called.
    if digits_to_clear >= T::MAX_PRECISION as i64 {
        return T::Native::zero();
    }

    // Here 0 < digits_to_clear < MAX_PRECISION, so the cast to usize is lossless.
    let factor = pow::<T::Native>(T::Native::from(10_i32), digits_to_clear as usize);

    // Dividing discards the unwanted digits; multiplying back restores the magnitude.
    let kept = x / factor;
    kept * factor
}

#[cfg(test)]
mod test {
use std::sync::Arc;

use crate::math::trunc::trunc;
use crate::math::trunc::{compute_truncate_decimal, trunc};

use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
use arrow::datatypes::Decimal128Type;
use datafusion_common::cast::{as_float32_array, as_float64_array};

#[test]
Expand Down Expand Up @@ -328,4 +504,60 @@ mod test {
assert_eq!(floats.value(3), 123.0);
assert_eq!(floats.value(4), -321.0);
}

#[test]
fn test_compute_truncate_decimal128() {
    // Shorthand: truncate(raw value, scale, requested places) -> raw value.
    let truncate = |raw: i128, scale: i8, places: i64| -> i128 {
        compute_truncate_decimal::<Decimal128Type>(raw, scale, places)
    };

    // 12.3456 (scale 4) kept to 3 places -> 12.345
    assert_eq!(truncate(123_456, 4, 3), 123_450);
    // 12.3456 (scale 4) kept to 1 place -> 12.3
    assert_eq!(truncate(123_456, 4, 1), 123_000);

    // Asking for more places than the scale provides leaves the value as is.
    assert_eq!(truncate(123_456, 4, 10), 123_456);

    // 12.3456 (scale 4) kept to 0 places -> whole number 12
    assert_eq!(truncate(123_456, 4, 0), 120_000);

    // 12.3456 (scale 4) with -1 places -> 10
    assert_eq!(truncate(123_456, 4, -1), 100_000);

    // 12.3456 (scale 4) with -3 places -> 0
    assert_eq!(truncate(123_456, 4, -3), 0);

    // 1234.56 (scale 2) with -3 places -> 1000
    assert_eq!(truncate(123_456, 2, -3), 100_000);

    // Clearing 904 digits is beyond what the type can store, so the result is 0.
    assert_eq!(truncate(123_456, 4, -900), 0);

    // Negative values truncate towards zero: -12.3456 -> -12.345
    assert_eq!(truncate(-123_456, 4, 3), -123_450);
}
}
51 changes: 51 additions & 0 deletions datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1333,6 +1333,57 @@ from small_floats;
0.836 0.8 0.836
1 1 1

# trunc with decimals
query RT
select trunc(arrow_cast(3.1415, 'Decimal128(10,4)')), arrow_typeof(trunc(arrow_cast(3.1415, 'Decimal128(10,4)')));
----
3 Decimal128(10, 4)

# trunc with precision - decimals
query RRRRR rowsort
select
trunc(arrow_cast(4.267, 'Decimal32(8,3)'), 3),
trunc(arrow_cast(1.1234, 'Decimal64(18,6)'), 2),
trunc(arrow_cast(-1.1231, 'Decimal128(15,4)'), 6),
trunc(arrow_cast(1.2837284, 'Decimal256(35,7)'), 2),
trunc(arrow_cast(1.1, 'Decimal128(10,1)'), 0);
----
4.267 1.12 -1.1231 1.28 1

# trunc with negative precision should truncate digits left of decimal - decimal types
query RT
select trunc(arrow_cast(12345.678, 'Decimal128(10,3)'), -3),
arrow_typeof(trunc(arrow_cast(12345.678, 'Decimal128(10,3)'), -3));
----
12000 Decimal128(10, 3)

# trunc: coercion with a decimal argument and a non-int64 precision argument
query RT
select trunc(arrow_cast(1.2345678, 'Decimal128(20,14)'), arrow_cast(2, 'Int32')),
arrow_typeof(trunc(arrow_cast(1.2345678, 'Decimal128(20,14)'), arrow_cast(2, 'Int32')));
----
1.23 Decimal128(20, 14)

# trunc with columns and precision - decimal128
query RRR rowsort
select
trunc(arrow_cast(a, 'Decimal128(10,4)'), 0) as a0,
trunc(arrow_cast(b, 'Decimal128(10,4)'), 0) as b0,
trunc(arrow_cast(c, 'Decimal128(10,4)'), 0) as c0
from small_floats;
----
-1 NULL NULL
0 0 -1
0 0 0
0 0 1

# trunc issue #22512
query R
select trunc(CAST(9007199254740993 AS DECIMAL(20,0)));
----
9007199254740993


## bitwise and

# bitwise and with column and scalar
Expand Down
Loading