You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/02/21 10:37:30 UTC
[arrow] branch master updated: ARROW-11651: [Rust][DataFusion]
Implement Postgres String Functions: Length Functions
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new abf967d ARROW-11651: [Rust][DataFusion] Implement Postgres String Functions: Length Functions
abf967d is described below
commit abf967dcaff34f0a7663dec2cad67a25b6bf04ee
Author: Mike Seddon <se...@gmail.com>
AuthorDate: Sun Feb 21 05:35:40 2021 -0500
ARROW-11651: [Rust][DataFusion] Implement Postgres String Functions: Length Functions
Splitting up https://github.com/apache/arrow/pull/9243
This implements the following functions:
- String functions
- [x] bit_length
- [x] char_length
- [x] character_length
- [x] length
- [x] octet_length
Closes #9509 from seddonm1/length-functions
Lead-authored-by: Mike Seddon <se...@gmail.com>
Co-authored-by: Jorge C. Leitao <jo...@gmail.com>
Signed-off-by: Andrew Lamb <an...@nerdnetworks.org>
---
rust/arrow/Cargo.toml | 4 +
.../benches/bit_length_kernel.rs} | 51 ++--
rust/arrow/src/compute/kernels/length.rs | 268 +++++++++++++++++----
rust/datafusion/Cargo.toml | 1 +
rust/datafusion/README.md | 6 +-
rust/datafusion/src/logical_plan/expr.rs | 22 +-
rust/datafusion/src/logical_plan/mod.rs | 11 +-
rust/datafusion/src/physical_plan/functions.rs | 263 +++++++++++++++-----
.../src/physical_plan/string_expressions.rs | 29 ++-
rust/datafusion/src/prelude.rs | 6 +-
10 files changed, 513 insertions(+), 148 deletions(-)
diff --git a/rust/arrow/Cargo.toml b/rust/arrow/Cargo.toml
index 0b14b5b..5ab1f8c 100644
--- a/rust/arrow/Cargo.toml
+++ b/rust/arrow/Cargo.toml
@@ -115,6 +115,10 @@ name = "length_kernel"
harness = false
[[bench]]
+name = "bit_length_kernel"
+harness = false
+
+[[bench]]
name = "sort_kernel"
harness = false
diff --git a/rust/datafusion/src/prelude.rs b/rust/arrow/benches/bit_length_kernel.rs
similarity index 50%
copy from rust/datafusion/src/prelude.rs
copy to rust/arrow/benches/bit_length_kernel.rs
index 4575de1..51d3134 100644
--- a/rust/datafusion/src/prelude.rs
+++ b/rust/arrow/benches/bit_length_kernel.rs
@@ -13,23 +13,34 @@
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
-// under the License.pub},
-
-//! A "prelude" for users of the datafusion crate.
-//!
-//! Like the standard library's prelude, this module simplifies importing of
-//! common items. Unlike the standard prelude, the contents of this module must
-//! be imported manually:
-//!
-//! ```
-//! use datafusion::prelude::*;
-//! ```
-
-pub use crate::dataframe::DataFrame;
-pub use crate::execution::context::{ExecutionConfig, ExecutionContext};
-pub use crate::logical_plan::{
- array, avg, col, concat, count, create_udf, in_list, length, lit, lower, ltrim, max,
- md5, min, rtrim, sha224, sha256, sha384, sha512, sum, trim, upper, JoinType,
- Partitioning,
-};
-pub use crate::physical_plan::csv::CsvReadOptions;
+// under the License.
+
+#[macro_use]
+extern crate criterion;
+use criterion::Criterion;
+
+extern crate arrow;
+
+use arrow::{array::*, compute::kernels::length::bit_length};
+
+fn bench_bit_length(array: &StringArray) {
+ criterion::black_box(bit_length(array).unwrap());
+}
+
+fn add_benchmark(c: &mut Criterion) {
+ fn double_vec<T: Clone>(v: Vec<T>) -> Vec<T> {
+ [&v[..], &v[..]].concat()
+ }
+
+ // double ["one", "on", "o", ""] 10 times: 4 * 2^10 = 4096 strings
+ let mut values = vec!["one", "on", "o", ""];
+ for _ in 0..10 {
+ values = double_vec(values);
+ }
+ let array = StringArray::from(values);
+
+ c.bench_function("bit_length", |b| b.iter(|| bench_bit_length(&array)));
+}
+
+criterion_group!(benches, add_benchmark);
+criterion_main!(benches);
diff --git a/rust/arrow/src/compute/kernels/length.rs b/rust/arrow/src/compute/kernels/length.rs
index 740bb2b..ed1fda4 100644
--- a/rust/arrow/src/compute/kernels/length.rs
+++ b/rust/arrow/src/compute/kernels/length.rs
@@ -17,26 +17,33 @@
//! Defines kernel for length of a string array
-use crate::{array::*, buffer::Buffer};
use crate::{
- datatypes::DataType,
+ array::*,
+ buffer::Buffer,
+ datatypes::{ArrowNativeType, ArrowPrimitiveType},
+};
+use crate::{
+ datatypes::{DataType, Int32Type, Int64Type},
error::{ArrowError, Result},
};
use std::sync::Arc;
-#[allow(clippy::unnecessary_wraps)]
-fn length_string<OffsetSize>(array: &Array, data_type: DataType) -> Result<ArrayRef>
+fn unary_offsets_string<O, F>(
+ array: &GenericStringArray<O>,
+ data_type: DataType,
+ op: F,
+) -> ArrayRef
where
- OffsetSize: OffsetSizeTrait,
+ O: StringOffsetSizeTrait + ArrowNativeType,
+ F: Fn(O) -> O,
{
// note: offsets are stored as u8, but they can be interpreted as OffsetSize
let offsets = &array.data_ref().buffers()[0];
// this is a 30% improvement over iterating over u8s and building OffsetSize, which
// justifies the usage of `unsafe`.
- let slice: &[OffsetSize] =
- &unsafe { offsets.typed_data::<OffsetSize>() }[array.offset()..];
+ let slice: &[O] = &unsafe { offsets.typed_data::<O>() }[array.offset()..];
- let lengths = slice.windows(2).map(|offset| offset[1] - offset[0]);
+ let lengths = slice.windows(2).map(|offset| op(offset[1] - offset[0]));
// JUSTIFICATION
// Benefit
@@ -60,18 +67,45 @@ where
vec![buffer],
vec![],
);
- Ok(make_array(Arc::new(data)))
+ make_array(Arc::new(data))
}
-/// Returns an array of Int32/Int64 denoting the number of characters in each string in the array.
+fn octet_length<O: StringOffsetSizeTrait, T: ArrowPrimitiveType>(
+ array: &dyn Array,
+) -> ArrayRef
+where
+ T::Native: StringOffsetSizeTrait,
+{
+ let array = array
+ .as_any()
+ .downcast_ref::<GenericStringArray<O>>()
+ .unwrap();
+ unary_offsets_string::<O, _>(array, T::DATA_TYPE, |x| x)
+}
+
+fn bit_length_impl<O: StringOffsetSizeTrait, T: ArrowPrimitiveType>(
+ array: &dyn Array,
+) -> ArrayRef
+where
+ T::Native: StringOffsetSizeTrait,
+{
+ let array = array
+ .as_any()
+ .downcast_ref::<GenericStringArray<O>>()
+ .unwrap();
+ let bits_in_bytes = O::from_usize(8).unwrap();
+ unary_offsets_string::<O, _>(array, T::DATA_TYPE, |x| x * bits_in_bytes)
+}
+
+/// Returns an array of Int32/Int64 denoting the number of bytes in each string in the array.
///
/// * this only accepts StringArray/Utf8 and LargeString/LargeUtf8
/// * length of null is null.
/// * length is in number of bytes
pub fn length(array: &Array) -> Result<ArrayRef> {
match array.data_type() {
- DataType::Utf8 => length_string::<i32>(array, DataType::Int32),
- DataType::LargeUtf8 => length_string::<i64>(array, DataType::Int64),
+ DataType::Utf8 => Ok(octet_length::<i32, Int32Type>(array)),
+ DataType::LargeUtf8 => Ok(octet_length::<i64, Int64Type>(array)),
_ => Err(ArrowError::ComputeError(format!(
"length not supported for {:?}",
array.data_type()
@@ -79,11 +113,27 @@ pub fn length(array: &Array) -> Result<ArrayRef> {
}
}
+/// Returns an array of Int32/Int64 denoting the number of bits in each string in the array.
+///
+/// * this only accepts StringArray/Utf8 and LargeString/LargeUtf8
+/// * bit_length of null is null.
+/// * bit_length is in number of bits
+pub fn bit_length(array: &Array) -> Result<ArrayRef> {
+ match array.data_type() {
+ DataType::Utf8 => Ok(bit_length_impl::<i32, Int32Type>(array)),
+ DataType::LargeUtf8 => Ok(bit_length_impl::<i64, Int64Type>(array)),
+ _ => Err(ArrowError::ComputeError(format!(
+ "bit_length not supported for {:?}",
+ array.data_type()
+ ))),
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
- fn cases() -> Vec<(Vec<&'static str>, usize, Vec<i32>)> {
+ fn length_cases() -> Vec<(Vec<&'static str>, usize, Vec<i32>)> {
fn double_vec<T: Clone>(v: Vec<T>) -> Vec<T> {
[&v[..], &v[..]].concat()
}
@@ -105,34 +155,38 @@ mod tests {
}
#[test]
- fn test_string() -> Result<()> {
- cases().into_iter().try_for_each(|(input, len, expected)| {
- let array = StringArray::from(input);
- let result = length(&array)?;
- assert_eq!(len, result.len());
- let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
- expected.iter().enumerate().for_each(|(i, value)| {
- assert_eq!(*value, result.value(i));
- });
- Ok(())
- })
+ fn length_test_string() -> Result<()> {
+ length_cases()
+ .into_iter()
+ .try_for_each(|(input, len, expected)| {
+ let array = StringArray::from(input);
+ let result = length(&array)?;
+ assert_eq!(len, result.len());
+ let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
+ expected.iter().enumerate().for_each(|(i, value)| {
+ assert_eq!(*value, result.value(i));
+ });
+ Ok(())
+ })
}
#[test]
- fn test_large_string() -> Result<()> {
- cases().into_iter().try_for_each(|(input, len, expected)| {
- let array = LargeStringArray::from(input);
- let result = length(&array)?;
- assert_eq!(len, result.len());
- let result = result.as_any().downcast_ref::<Int64Array>().unwrap();
- expected.iter().enumerate().for_each(|(i, value)| {
- assert_eq!(*value as i64, result.value(i));
- });
- Ok(())
- })
- }
-
- fn null_cases() -> Vec<(Vec<Option<&'static str>>, usize, Vec<Option<i32>>)> {
+ fn length_test_large_string() -> Result<()> {
+ length_cases()
+ .into_iter()
+ .try_for_each(|(input, len, expected)| {
+ let array = LargeStringArray::from(input);
+ let result = length(&array)?;
+ assert_eq!(len, result.len());
+ let result = result.as_any().downcast_ref::<Int64Array>().unwrap();
+ expected.iter().enumerate().for_each(|(i, value)| {
+ assert_eq!(*value as i64, result.value(i));
+ });
+ Ok(())
+ })
+ }
+
+ fn length_null_cases() -> Vec<(Vec<Option<&'static str>>, usize, Vec<Option<i32>>)> {
vec![(
vec![Some("one"), None, Some("three"), Some("four")],
4,
@@ -141,8 +195,8 @@ mod tests {
}
#[test]
- fn null_string() -> Result<()> {
- null_cases()
+ fn length_null_string() -> Result<()> {
+ length_null_cases()
.into_iter()
.try_for_each(|(input, len, expected)| {
let array = StringArray::from(input);
@@ -157,8 +211,8 @@ mod tests {
}
#[test]
- fn null_large_string() -> Result<()> {
- null_cases()
+ fn length_null_large_string() -> Result<()> {
+ length_null_cases()
.into_iter()
.try_for_each(|(input, len, expected)| {
let array = LargeStringArray::from(input);
@@ -179,7 +233,7 @@ mod tests {
/// Tests that length is not valid for u64.
#[test]
- fn wrong_type() {
+ fn length_wrong_type() {
let array: UInt64Array = vec![1u64].into();
assert!(length(&array).is_err());
@@ -187,7 +241,7 @@ mod tests {
/// Tests with an offset
#[test]
- fn offsets() -> Result<()> {
+ fn length_offsets() -> Result<()> {
let a = StringArray::from(vec!["hello", " ", "world"]);
let b = make_array(
ArrayData::builder(DataType::Utf8)
@@ -203,4 +257,130 @@ mod tests {
Ok(())
}
+
+ fn bit_length_cases() -> Vec<(Vec<&'static str>, usize, Vec<i32>)> {
+ fn double_vec<T: Clone>(v: Vec<T>) -> Vec<T> {
+ [&v[..], &v[..]].concat()
+ }
+
+ // a large array
+ let mut values = vec!["one", "on", "o", ""];
+ let mut expected = vec![24, 16, 8, 0];
+ for _ in 0..10 {
+ values = double_vec(values);
+ expected = double_vec(expected);
+ }
+
+ vec![
+ (vec!["hello", " ", "world", "!"], 4, vec![40, 8, 40, 8]),
+ (vec!["💖"], 1, vec![32]),
+ (vec!["josé"], 1, vec![40]),
+ (values, 4096, expected),
+ ]
+ }
+
+ #[test]
+ fn bit_length_test_string() -> Result<()> {
+ bit_length_cases()
+ .into_iter()
+ .try_for_each(|(input, len, expected)| {
+ let array = StringArray::from(input);
+ let result = bit_length(&array)?;
+ assert_eq!(len, result.len());
+ let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
+ expected.iter().enumerate().for_each(|(i, value)| {
+ assert_eq!(*value, result.value(i));
+ });
+ Ok(())
+ })
+ }
+
+ #[test]
+ fn bit_length_test_large_string() -> Result<()> {
+ bit_length_cases()
+ .into_iter()
+ .try_for_each(|(input, len, expected)| {
+ let array = LargeStringArray::from(input);
+ let result = bit_length(&array)?;
+ assert_eq!(len, result.len());
+ let result = result.as_any().downcast_ref::<Int64Array>().unwrap();
+ expected.iter().enumerate().for_each(|(i, value)| {
+ assert_eq!(*value as i64, result.value(i));
+ });
+ Ok(())
+ })
+ }
+
+ fn bit_length_null_cases() -> Vec<(Vec<Option<&'static str>>, usize, Vec<Option<i32>>)>
+ {
+ vec![(
+ vec![Some("one"), None, Some("three"), Some("four")],
+ 4,
+ vec![Some(24), None, Some(40), Some(32)],
+ )]
+ }
+
+ #[test]
+ fn bit_length_null_string() -> Result<()> {
+ bit_length_null_cases()
+ .into_iter()
+ .try_for_each(|(input, len, expected)| {
+ let array = StringArray::from(input);
+ let result = bit_length(&array)?;
+ assert_eq!(len, result.len());
+ let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
+
+ let expected: Int32Array = expected.into();
+ assert_eq!(expected.data(), result.data());
+ Ok(())
+ })
+ }
+
+ #[test]
+ fn bit_length_null_large_string() -> Result<()> {
+ bit_length_null_cases()
+ .into_iter()
+ .try_for_each(|(input, len, expected)| {
+ let array = LargeStringArray::from(input);
+ let result = bit_length(&array)?;
+ assert_eq!(len, result.len());
+ let result = result.as_any().downcast_ref::<Int64Array>().unwrap();
+
+ // convert to i64
+ let expected: Int64Array = expected
+ .iter()
+ .map(|e| e.map(|e| e as i64))
+ .collect::<Vec<_>>()
+ .into();
+ assert_eq!(expected.data(), result.data());
+ Ok(())
+ })
+ }
+
+ /// Tests that bit_length is not valid for u64.
+ #[test]
+ fn bit_length_wrong_type() {
+ let array: UInt64Array = vec![1u64].into();
+
+ assert!(bit_length(&array).is_err());
+ }
+
+ /// Tests with an offset
+ #[test]
+ fn bit_length_offsets() -> Result<()> {
+ let a = StringArray::from(vec!["hello", " ", "world"]);
+ let b = make_array(
+ ArrayData::builder(DataType::Utf8)
+ .len(2)
+ .offset(1)
+ .buffers(a.data_ref().buffers().to_vec())
+ .build(),
+ );
+ let result = bit_length(b.as_ref())?;
+
+ let expected = Int32Array::from(vec![8, 40]);
+ assert_eq!(expected.data(), result.data());
+
+ Ok(())
+ }
}
diff --git a/rust/datafusion/Cargo.toml b/rust/datafusion/Cargo.toml
index ea55666..11cc63b 100644
--- a/rust/datafusion/Cargo.toml
+++ b/rust/datafusion/Cargo.toml
@@ -64,6 +64,7 @@ log = "^0.4"
md-5 = "^0.9.1"
sha2 = "^0.9.1"
ordered-float = "2.0"
+unicode-segmentation = "^1.7.1"
[dev-dependencies]
rand = "0.8"
diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md
index 7a12250..b4cb043 100644
--- a/rust/datafusion/README.md
+++ b/rust/datafusion/README.md
@@ -57,7 +57,11 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI
- [x] UDAFs (user-defined aggregate functions)
- [x] Common math functions
- String functions
- - [x] Length
+ - [x] bit_length
+ - [x] char_length
+ - [x] character_length
+ - [x] length
+ - [x] octet_length
- [x] Concatenate
- Miscellaneous/Boolean functions
- [x] nullif
diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs
index ffed843..16c01ed 100644
--- a/rust/datafusion/src/logical_plan/expr.rs
+++ b/rust/datafusion/src/logical_plan/expr.rs
@@ -850,6 +850,8 @@ macro_rules! unary_scalar_expr {
}
// generate methods for creating the supported unary expressions
+
+// math functions
unary_scalar_expr!(Sqrt, sqrt);
unary_scalar_expr!(Sin, sin);
unary_scalar_expr!(Cos, cos);
@@ -867,24 +869,22 @@ unary_scalar_expr!(Exp, exp);
unary_scalar_expr!(Log, ln);
unary_scalar_expr!(Log2, log2);
unary_scalar_expr!(Log10, log10);
+
+// string functions
+unary_scalar_expr!(BitLength, bit_length);
+unary_scalar_expr!(CharacterLength, character_length);
+unary_scalar_expr!(CharacterLength, length);
unary_scalar_expr!(Lower, lower);
-unary_scalar_expr!(Trim, trim);
unary_scalar_expr!(Ltrim, ltrim);
-unary_scalar_expr!(Rtrim, rtrim);
-unary_scalar_expr!(Upper, upper);
unary_scalar_expr!(MD5, md5);
+unary_scalar_expr!(OctetLength, octet_length);
+unary_scalar_expr!(Rtrim, rtrim);
unary_scalar_expr!(SHA224, sha224);
unary_scalar_expr!(SHA256, sha256);
unary_scalar_expr!(SHA384, sha384);
unary_scalar_expr!(SHA512, sha512);
-
-/// returns the length of a string in bytes
-pub fn length(e: Expr) -> Expr {
- Expr::ScalarFunction {
- fun: functions::BuiltinScalarFunction::Length,
- args: vec![e],
- }
-}
+unary_scalar_expr!(Trim, trim);
+unary_scalar_expr!(Upper, upper);
/// returns the concatenation of string expressions
pub fn concat(args: Vec<Expr>) -> Expr {
diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs
index fbad5e2..6244387 100644
--- a/rust/datafusion/src/logical_plan/mod.rs
+++ b/rust/datafusion/src/logical_plan/mod.rs
@@ -34,11 +34,12 @@ pub use builder::LogicalPlanBuilder;
pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema};
pub use display::display_schema;
pub use expr::{
- abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col,
- combine_filters, concat, cos, count, count_distinct, create_udaf, create_udf, exp,
- exprlist_to_fields, floor, in_list, length, lit, ln, log10, log2, lower, ltrim, max,
- md5, min, or, round, rtrim, sha224, sha256, sha384, sha512, signum, sin, sqrt, sum,
- tan, trim, trunc, upper, when, Expr, ExpressionVisitor, Literal, Recursion,
+ abs, acos, and, array, asin, atan, avg, binary_expr, bit_length, case, ceil,
+ character_length, col, combine_filters, concat, cos, count, count_distinct,
+ create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list, length, lit, ln,
+ log10, log2, lower, ltrim, max, md5, min, octet_length, or, round, rtrim, sha224,
+ sha256, sha384, sha512, signum, sin, sqrt, sum, tan, trim, trunc, upper, when, Expr,
+ ExpressionVisitor, Literal, Recursion,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs
index c5cd01f..baacf94 100644
--- a/rust/datafusion/src/physical_plan/functions.rs
+++ b/rust/datafusion/src/physical_plan/functions.rs
@@ -45,9 +45,9 @@ use crate::{
};
use arrow::{
array::ArrayRef,
- compute::kernels::length::length,
+ compute::kernels::length::{bit_length, length},
datatypes::TimeUnit,
- datatypes::{DataType, Field, Schema},
+ datatypes::{DataType, Field, Int32Type, Int64Type, Schema},
record_batch::RecordBatch,
};
use fmt::{Debug, Formatter};
@@ -118,8 +118,6 @@ pub enum BuiltinScalarFunction {
Abs,
/// signum
Signum,
- /// length
- Length,
/// concat
Concat,
/// lower
@@ -150,6 +148,12 @@ pub enum BuiltinScalarFunction {
SHA384,
/// SHA512,
SHA512,
+ /// bit_length
+ BitLength,
+ /// character_length
+ CharacterLength,
+ /// octet_length
+ OctetLength,
}
impl fmt::Display for BuiltinScalarFunction {
@@ -180,9 +184,6 @@ impl FromStr for BuiltinScalarFunction {
"truc" => BuiltinScalarFunction::Trunc,
"abs" => BuiltinScalarFunction::Abs,
"signum" => BuiltinScalarFunction::Signum,
- "length" => BuiltinScalarFunction::Length,
- "char_length" => BuiltinScalarFunction::Length,
- "character_length" => BuiltinScalarFunction::Length,
"concat" => BuiltinScalarFunction::Concat,
"lower" => BuiltinScalarFunction::Lower,
"trim" => BuiltinScalarFunction::Trim,
@@ -198,6 +199,11 @@ impl FromStr for BuiltinScalarFunction {
"sha256" => BuiltinScalarFunction::SHA256,
"sha384" => BuiltinScalarFunction::SHA384,
"sha512" => BuiltinScalarFunction::SHA512,
+ "bit_length" => BuiltinScalarFunction::BitLength,
+ "octet_length" => BuiltinScalarFunction::OctetLength,
+ "length" => BuiltinScalarFunction::CharacterLength,
+ "char_length" => BuiltinScalarFunction::CharacterLength,
+ "character_length" => BuiltinScalarFunction::CharacterLength,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
@@ -231,16 +237,6 @@ pub fn return_type(
// the return type of the built in function.
// Some built-in functions' return type depends on the incoming type.
match fun {
- BuiltinScalarFunction::Length => Ok(match arg_types[0] {
- DataType::LargeUtf8 => DataType::Int64,
- DataType::Utf8 => DataType::Int32,
- _ => {
- // this error is internal as `data_types` should have captured this.
- return Err(DataFusionError::Internal(
- "The length function can only accept strings.".to_string(),
- ));
- }
- }),
BuiltinScalarFunction::Concat => Ok(DataType::Utf8),
BuiltinScalarFunction::Lower => Ok(match arg_types[0] {
DataType::LargeUtf8 => DataType::LargeUtf8,
@@ -357,6 +353,36 @@ pub fn return_type(
));
}
}),
+ BuiltinScalarFunction::BitLength => Ok(match arg_types[0] {
+ DataType::LargeUtf8 => DataType::Int64,
+ DataType::Utf8 => DataType::Int32,
+ _ => {
+ // this error is internal as `data_types` should have captured this.
+ return Err(DataFusionError::Internal(
+ "The bit_length function can only accept strings.".to_string(),
+ ));
+ }
+ }),
+ BuiltinScalarFunction::CharacterLength => Ok(match arg_types[0] {
+ DataType::LargeUtf8 => DataType::Int64,
+ DataType::Utf8 => DataType::Int32,
+ _ => {
+ // this error is internal as `data_types` should have captured this.
+ return Err(DataFusionError::Internal(
+ "The character_length function can only accept strings.".to_string(),
+ ));
+ }
+ }),
+ BuiltinScalarFunction::OctetLength => Ok(match arg_types[0] {
+ DataType::LargeUtf8 => DataType::Int64,
+ DataType::Utf8 => DataType::Int32,
+ _ => {
+ // this error is internal as `data_types` should have captured this.
+ return Err(DataFusionError::Internal(
+ "The octet_length function can only accept strings.".to_string(),
+ ));
+ }
+ }),
_ => Ok(DataType::Float64),
}
}
@@ -392,7 +418,41 @@ pub fn create_physical_expr(
BuiltinScalarFunction::SHA256 => crypto_expressions::sha256,
BuiltinScalarFunction::SHA384 => crypto_expressions::sha384,
BuiltinScalarFunction::SHA512 => crypto_expressions::sha512,
- BuiltinScalarFunction::Length => |args| match &args[0] {
+ BuiltinScalarFunction::Concat => string_expressions::concatenate,
+ BuiltinScalarFunction::Lower => string_expressions::lower,
+ BuiltinScalarFunction::Trim => string_expressions::trim,
+ BuiltinScalarFunction::Ltrim => string_expressions::ltrim,
+ BuiltinScalarFunction::Rtrim => string_expressions::rtrim,
+ BuiltinScalarFunction::Upper => string_expressions::upper,
+ BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp,
+ BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc,
+ BuiltinScalarFunction::Array => array_expressions::array,
+ BuiltinScalarFunction::BitLength => |args| match &args[0] {
+ ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)),
+ ColumnarValue::Scalar(v) => match v {
+ ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32(
+ v.as_ref().map(|x| (x.len() * 8) as i32),
+ ))),
+ ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar(
+ ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)),
+ )),
+ _ => unreachable!(),
+ },
+ },
+ BuiltinScalarFunction::CharacterLength => |args| match args[0].data_type() {
+ DataType::Utf8 => make_scalar_function(
+ string_expressions::character_length::<Int32Type>,
+ )(args),
+ DataType::LargeUtf8 => make_scalar_function(
+ string_expressions::character_length::<Int64Type>,
+ )(args),
+ other => Err(DataFusionError::Internal(format!(
+ "Unsupported data type {:?} for function character_length",
+ other,
+ ))),
+ },
+ BuiltinScalarFunction::OctetLength => |args| match &args[0] {
+ ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)),
ColumnarValue::Scalar(v) => match v {
ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32(
v.as_ref().map(|x| x.len() as i32),
@@ -402,17 +462,7 @@ pub fn create_physical_expr(
)),
_ => unreachable!(),
},
- ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)),
},
- BuiltinScalarFunction::Concat => string_expressions::concatenate,
- BuiltinScalarFunction::Lower => string_expressions::lower,
- BuiltinScalarFunction::Trim => string_expressions::trim,
- BuiltinScalarFunction::Ltrim => string_expressions::ltrim,
- BuiltinScalarFunction::Rtrim => string_expressions::rtrim,
- BuiltinScalarFunction::Upper => string_expressions::upper,
- BuiltinScalarFunction::ToTimestamp => datetime_expressions::to_timestamp,
- BuiltinScalarFunction::DateTrunc => datetime_expressions::date_trunc,
- BuiltinScalarFunction::Array => array_expressions::array,
});
// coerce
let args = coerce(args, input_schema, &signature(fun))?;
@@ -439,7 +489,9 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]),
BuiltinScalarFunction::Upper
| BuiltinScalarFunction::Lower
- | BuiltinScalarFunction::Length
+ | BuiltinScalarFunction::BitLength
+ | BuiltinScalarFunction::CharacterLength
+ | BuiltinScalarFunction::OctetLength
| BuiltinScalarFunction::Trim
| BuiltinScalarFunction::Ltrim
| BuiltinScalarFunction::Rtrim
@@ -617,48 +669,135 @@ mod tests {
};
use arrow::{
array::{
- ArrayRef, FixedSizeListArray, Float64Array, Int32Array, StringArray,
+ Array, ArrayRef, FixedSizeListArray, Float64Array, Int32Array, StringArray,
UInt32Array, UInt64Array,
},
datatypes::Field,
record_batch::RecordBatch,
};
- fn generic_test_math(value: ScalarValue, expected: &str) -> Result<()> {
- // any type works here: we evaluate against a literal of `value`
- let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
- let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1]))];
-
- let arg = lit(value);
-
- let expr = create_physical_expr(&BuiltinScalarFunction::Exp, &[arg], &schema)?;
-
- // type is correct
- assert_eq!(expr.data_type(&schema)?, DataType::Float64);
-
- // evaluate works
- let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?;
- let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
-
- // downcast works
- let result = result.as_any().downcast_ref::<Float64Array>().unwrap();
-
- // value is correct
- assert_eq!(result.value(0).to_string(), expected);
-
- Ok(())
+ /// $FUNC: the `BuiltinScalarFunction` variant to test
+ /// $ARGS: a slice of argument expressions passed to the function, e.g. `&[lit(...)]`
+ /// $EXPECTED: a Result<Option<$EXPECTED_TYPE>>; Err tests an error, None tests a null result
+ /// $EXPECTED_TYPE: the Rust type of the expected value
+ /// $DATA_TYPE: the `DataType` variant the function is expected to return
+ /// $ARRAY_TYPE: the array type the evaluated result column is downcast to
+ macro_rules! test_function {
+ ($FUNC:ident, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $DATA_TYPE: ident, $ARRAY_TYPE:ident) => {
+ // used to provide type annotation
+ let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
+
+ // any input schema works here: the function is evaluated against the literals in $ARGS
+ let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+ let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1]))];
+
+ let expr =
+ create_physical_expr(&BuiltinScalarFunction::$FUNC, $ARGS, &schema)?;
+
+ // type is correct
+ assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TYPE);
+
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?;
+
+ match expected {
+ Ok(expected) => {
+ let result = expr.evaluate(&batch)?;
+ let result = result.into_array(batch.num_rows());
+ let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap();
+
+ // value is correct
+ match expected {
+ Some(v) => assert_eq!(result.value(0), v),
+ None => assert!(result.is_null(0)),
+ };
+ }
+ Err(expected_error) => {
+ // evaluate is expected error - cannot use .expect_err() due to Debug not being implemented
+ match expr.evaluate(&batch) {
+ Ok(_) => panic!("expected error"),
+ Err(error) => {
+ assert_eq!(error.to_string(), expected_error.to_string());
+ }
+ }
+ }
+ };
+ };
}
#[test]
- fn test_math_function() -> Result<()> {
- // 2.71828182845904523536... : https://oeis.org/A001113
- let exp_f64 = "2.718281828459045";
- let exp_f32 = "2.7182817459106445";
- generic_test_math(ScalarValue::from(1i32), exp_f64)?;
- generic_test_math(ScalarValue::from(1u32), exp_f64)?;
- generic_test_math(ScalarValue::from(1u64), exp_f64)?;
- generic_test_math(ScalarValue::from(1f64), exp_f64)?;
- generic_test_math(ScalarValue::from(1f32), exp_f32)?;
+ fn test_functions() -> Result<()> {
+ test_function!(
+ CharacterLength,
+ &[lit(ScalarValue::Utf8(Some("chars".to_string())))],
+ Ok(Some(5)),
+ i32,
+ Int32,
+ Int32Array
+ );
+ test_function!(
+ CharacterLength,
+ &[lit(ScalarValue::Utf8(Some("josé".to_string())))],
+ Ok(Some(4)),
+ i32,
+ Int32,
+ Int32Array
+ );
+ test_function!(
+ CharacterLength,
+ &[lit(ScalarValue::Utf8(Some("".to_string())))],
+ Ok(Some(0)),
+ i32,
+ Int32,
+ Int32Array
+ );
+ test_function!(
+ CharacterLength,
+ &[lit(ScalarValue::Utf8(None))],
+ Ok(None),
+ i32,
+ Int32,
+ Int32Array
+ );
+ test_function!(
+ Exp,
+ &[lit(ScalarValue::Int32(Some(1)))],
+ Ok(Some((1.0_f64).exp())),
+ f64,
+ Float64,
+ Float64Array
+ );
+ test_function!(
+ Exp,
+ &[lit(ScalarValue::UInt32(Some(1)))],
+ Ok(Some((1.0_f64).exp())),
+ f64,
+ Float64,
+ Float64Array
+ );
+ test_function!(
+ Exp,
+ &[lit(ScalarValue::UInt64(Some(1)))],
+ Ok(Some((1.0_f64).exp())),
+ f64,
+ Float64,
+ Float64Array
+ );
+ test_function!(
+ Exp,
+ &[lit(ScalarValue::Float64(Some(1.0)))],
+ Ok(Some((1.0_f64).exp())),
+ f64,
+ Float64,
+ Float64Array
+ );
+ test_function!(
+ Exp,
+ &[lit(ScalarValue::Float32(Some(1.0)))],
+ Ok(Some((1.0_f32).exp() as f64)),
+ f64,
+ Float64,
+ Float64Array
+ );
Ok(())
}
diff --git a/rust/datafusion/src/physical_plan/string_expressions.rs b/rust/datafusion/src/physical_plan/string_expressions.rs
index a4ccef0..81d2c67 100644
--- a/rust/datafusion/src/physical_plan/string_expressions.rs
+++ b/rust/datafusion/src/physical_plan/string_expressions.rs
@@ -24,9 +24,13 @@ use crate::{
scalar::ScalarValue,
};
use arrow::{
- array::{Array, GenericStringArray, StringArray, StringOffsetSizeTrait},
- datatypes::DataType,
+ array::{
+ Array, ArrayRef, GenericStringArray, PrimitiveArray, StringArray,
+ StringOffsetSizeTrait,
+ },
+ datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType},
};
+use unicode_segmentation::UnicodeSegmentation;
use super::ColumnarValue;
@@ -115,6 +119,27 @@ where
}
}
+/// Returns the number of extended grapheme clusters (user-perceived characters) per string.
+/// e.g. character_length('josé') = 4; a null input yields a null output.
+pub fn character_length<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
+where
+ T::Native: StringOffsetSizeTrait,
+{
+ let string_array: &GenericStringArray<T::Native> = args[0]
+ .as_any()
+ .downcast_ref::<GenericStringArray<T::Native>>()
+ .unwrap();
+
+ let result = string_array
+ .iter()
+ .map(|x| {
+ x.map(|x: &str| T::Native::from_usize(x.graphemes(true).count()).unwrap())
+ })
+ .collect::<PrimitiveArray<T>>();
+
+ Ok(Arc::new(result) as ArrayRef)
+}
+
/// concatenate string columns together.
pub fn concatenate(args: &[ColumnarValue]) -> Result<ColumnarValue> {
// downcast all arguments to strings
diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs
index 4575de1..26e03c7 100644
--- a/rust/datafusion/src/prelude.rs
+++ b/rust/datafusion/src/prelude.rs
@@ -28,8 +28,8 @@
pub use crate::dataframe::DataFrame;
pub use crate::execution::context::{ExecutionConfig, ExecutionContext};
pub use crate::logical_plan::{
- array, avg, col, concat, count, create_udf, in_list, length, lit, lower, ltrim, max,
- md5, min, rtrim, sha224, sha256, sha384, sha512, sum, trim, upper, JoinType,
- Partitioning,
+ array, avg, bit_length, character_length, col, concat, count, create_udf, in_list,
+ length, lit, lower, ltrim, max, md5, min, octet_length, rtrim, sha224, sha256,
+ sha384, sha512, sum, trim, upper, JoinType, Partitioning,
};
pub use crate::physical_plan::csv::CsvReadOptions;