Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 82 additions & 8 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ use arrow::datatypes::{
use arrow::error::ArrowError;
use arrow_buffer::i256;
use datafusion_common::types::NativeType;
use datafusion_common::{
Result, ScalarValue, exec_err, internal_err, plan_datafusion_err, plan_err,
};
use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
Expand Down Expand Up @@ -100,6 +98,17 @@ impl LogFunc {
}
}

/// Ranks float types by width so the wider one drives the result type.
/// Non-float types (e.g. decimals, which are computed in f64) rank as Float64.
#[inline]
fn float_rank(data_type: &DataType) -> u8 {
match data_type {
DataType::Float16 => 1,
DataType::Float32 => 2,
_ => 3,
}
}

/// Checks if the base is valid for the efficient integer logarithm algorithm.
#[inline]
fn is_valid_integer_base(base: f64) -> bool {
Expand Down Expand Up @@ -195,12 +204,17 @@ impl ScalarUDFImpl for LogFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
// Check last argument (value)
match &arg_types.last().ok_or(plan_datafusion_err!("No args"))? {
DataType::Float16 => Ok(DataType::Float16),
DataType::Float32 => Ok(DataType::Float32),
_ => Ok(DataType::Float64),
if arg_types.is_empty() {
return plan_err!("No args");
}
// The result type is the widest float among the arguments so that, e.g.,
// log(Float64, Float32) is computed and returned in Float64 instead of
// narrowing the base to Float32 and losing precision.
Ok(match arg_types.iter().map(float_rank).max().unwrap_or(0) {
1 => DataType::Float16,
2 => DataType::Float32,
_ => DataType::Float64,
})
}

fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
Expand Down Expand Up @@ -246,6 +260,19 @@ impl ScalarUDFImpl for LogFunc {
};
let value = value.to_array(args.number_rows)?;

// If the base is a wider float than the value (e.g. log(Float64, Float32)),
// widen the value to the result type so the base is not narrowed to the
// value's type, which would lose precision. Decimal values are always
// computed in f64, so they are left untouched here.
let return_type = args.return_type();
let value = if matches!(value.data_type(), DataType::Float16 | DataType::Float32)
&& float_rank(return_type) > float_rank(value.data_type())
{
arrow::compute::cast(&value, return_type)?
} else {
value
};

let output: ArrayRef = match value.data_type() {
DataType::Float16 => {
calculate_binary_math::<Float16Type, Float16Type, Float16Type, _>(
Expand Down Expand Up @@ -745,6 +772,53 @@ mod tests {
}
}
}
#[test]
fn test_log_f64_base_f32_value() {
// log(Float64 base, Float32 value) must widen the value to f64 and return
// Float64 instead of narrowing the base to f32 (see issue #22581).
assert_eq!(
LogFunc::new()
.return_type(&[DataType::Float64, DataType::Float32])
.unwrap(),
DataType::Float64
);

let arg_fields = vec![
Field::new("b", DataType::Float64, false).into(),
Field::new("x", DataType::Float32, false).into(),
];
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Scalar(ScalarValue::Float64(Some(16777217.0))), // base
ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num
],
arg_fields,
number_rows: 1,
return_field: Field::new("f", DataType::Float64, true).into(),
config_options: Arc::new(ConfigOptions::default()),
};
let result = LogFunc::new()
.invoke_with_args(args)
.expect("failed to initialize function log");

match result {
ColumnarValue::Array(arr) => {
let floats = as_float64_array(&arr)
.expect("failed to convert result to a Float64Array");
assert_eq!(floats.len(), 1);
let expected = 2.0_f64.log(16777217.0);
assert!(
(floats.value(0) - expected).abs() < 1e-15,
"got {}, expected {expected}",
floats.value(0)
);
}
ColumnarValue::Scalar(_) => {
panic!("Expected an array value")
}
}
}

#[test]
// Test log() simplification errors
fn test_log_simplify_errors() {
Expand Down
21 changes: 19 additions & 2 deletions datafusion/sqllogictest/test_files/math.slt
Original file line number Diff line number Diff line change
Expand Up @@ -924,18 +924,35 @@ physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/
query RT
SELECT log(2.5, arrow_cast(10.9, 'Float16')), arrow_typeof(log(2.5, arrow_cast(10.9, 'Float16')));
----
2.6074219 Float16
2.606835742462 Float64

query RT
SELECT log(2.5, 10.9::float), arrow_typeof(log(2.5, 10.9::float));
----
2.606992 Float32
2.606992159958 Float64

query RT
SELECT log(2.5, 10.9::double), arrow_typeof(log(2.5, 10.9::double));
----
2.606992198152 Float64

# A Float64 (or integer) base must not narrow a Float32 value to Float32; the
# value is widened to Float64 and a Float64 result is returned (issue #22581).
query RT
SELECT log(arrow_cast(16777217, 'Float64'), arrow_cast(2.0, 'Float32')), arrow_typeof(log(arrow_cast(16777217, 'Float64'), arrow_cast(2.0, 'Float32')));
----
0.041666666517 Float64

query RT
SELECT log(arrow_cast(16777217, 'Float64'), arrow_cast(2.0, 'Float16')), arrow_typeof(log(arrow_cast(16777217, 'Float64'), arrow_cast(2.0, 'Float16')));
----
0.041666666517 Float64

query RT
SELECT log(2, arrow_cast(2.0, 'Float32')), arrow_typeof(log(2, arrow_cast(2.0, 'Float32')));
----
1 Float64

# lcm with array and scalar

query I
Expand Down