You are viewing a plain text version of this content. The canonical link to the original message was not preserved in this plain text copy.
Posted to commits@arrow.apache.org by ag...@apache.org on 2023/01/26 15:36:26 UTC
[arrow-datafusion] branch master updated: Add decimal support to substrait serde (#5054)
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 552eea719 Add decimal support to substrait serde (#5054)
552eea719 is described below
commit 552eea719cba7e331a7955e15a3438ae4aa06ed0
Author: Andy Grove <an...@gmail.com>
AuthorDate: Thu Jan 26 08:36:19 2023 -0700
Add decimal support to substrait serde (#5054)
---
datafusion/substrait/src/consumer.rs | 100 +++++++++++++++++----------
datafusion/substrait/src/producer.rs | 9 ++-
datafusion/substrait/tests/roundtrip.rs | 18 ++++-
datafusion/substrait/tests/testdata/data.csv | 4 +-
4 files changed, 89 insertions(+), 42 deletions(-)
diff --git a/datafusion/substrait/src/consumer.rs b/datafusion/substrait/src/consumer.rs
index 80b400d52..7293ad967 100644
--- a/datafusion/substrait/src/consumer.rs
+++ b/datafusion/substrait/src/consumer.rs
@@ -606,44 +606,70 @@ pub async fn from_substrait_rex(
))),
}
}
- Some(RexType::Literal(lit)) => match &lit.literal_type {
- Some(LiteralType::I8(n)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8)))))
- }
- Some(LiteralType::I16(n)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16)))))
- }
- Some(LiteralType::I32(n)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n)))))
- }
- Some(LiteralType::I64(n)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n)))))
- }
- Some(LiteralType::Boolean(b)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b)))))
- }
- Some(LiteralType::Date(d)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d)))))
- }
- Some(LiteralType::Fp32(f)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f)))))
- }
- Some(LiteralType::Fp64(f)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f)))))
- }
- Some(LiteralType::String(s)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(Some(s.clone())))))
- }
- Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(
- ScalarValue::Binary(Some(b.clone())),
- ))),
- _ => {
- return Err(DataFusionError::NotImplemented(format!(
- "Unsupported literal_type: {:?}",
- lit.literal_type
- )))
+ Some(RexType::Literal(lit)) => {
+ match &lit.literal_type {
+ Some(LiteralType::I8(n)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8)))))
+ }
+ Some(LiteralType::I16(n)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16)))))
+ }
+ Some(LiteralType::I32(n)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n)))))
+ }
+ Some(LiteralType::I64(n)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n)))))
+ }
+ Some(LiteralType::Boolean(b)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b)))))
+ }
+ Some(LiteralType::Date(d)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d)))))
+ }
+ Some(LiteralType::Fp32(f)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f)))))
+ }
+ Some(LiteralType::Fp64(f)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f)))))
+ }
+ Some(LiteralType::Decimal(d)) => {
+ let value: [u8; 16] = d.value.clone().try_into().or(Err(
+ DataFusionError::Substrait(
+ "Failed to parse decimal value".to_string(),
+ ),
+ ))?;
+ let p = d.precision.try_into().map_err(|e| {
+ DataFusionError::Substrait(format!(
+ "Failed to parse decimal precision: {}",
+ e
+ ))
+ })?;
+ let s = d.scale.try_into().map_err(|e| {
+ DataFusionError::Substrait(format!(
+ "Failed to parse decimal scale: {}",
+ e
+ ))
+ })?;
+ Ok(Arc::new(Expr::Literal(ScalarValue::Decimal128(
+ Some(std::primitive::i128::from_le_bytes(value)),
+ p,
+ s,
+ ))))
+ }
+ Some(LiteralType::String(s)) => {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(Some(s.clone())))))
+ }
+ Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(
+ ScalarValue::Binary(Some(b.clone())),
+ ))),
+ _ => {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Unsupported literal_type: {:?}",
+ lit.literal_type
+ )))
+ }
}
- },
+ }
_ => Err(DataFusionError::NotImplemented(
"unsupported rex_type".to_string(),
)),
diff --git a/datafusion/substrait/src/producer.rs b/datafusion/substrait/src/producer.rs
index a1748c3ff..163abbaa9 100644
--- a/datafusion/substrait/src/producer.rs
+++ b/datafusion/substrait/src/producer.rs
@@ -35,7 +35,7 @@ use substrait::proto::{
expression::{
field_reference::ReferenceType,
if_then::IfClause,
- literal::LiteralType,
+ literal::{Decimal, LiteralType},
mask_expression::{StructItem, StructSelect},
reference_segment, FieldReference, IfThen, Literal, MaskExpression,
ReferenceSegment, RexType, ScalarFunction,
@@ -579,6 +579,13 @@ pub fn to_substrait_rex(
ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)),
ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
+ ScalarValue::Decimal128(v, p, s) if v.is_some() => {
+ Some(LiteralType::Decimal(Decimal {
+ value: v.unwrap().to_le_bytes().to_vec(),
+ precision: *p as i32,
+ scale: *s as i32,
+ }))
+ }
ScalarValue::Utf8(Some(s)) => Some(LiteralType::String(s.clone())),
ScalarValue::LargeUtf8(Some(s)) => Some(LiteralType::String(s.clone())),
ScalarValue::Binary(Some(b)) => Some(LiteralType::Binary(b.clone())),
diff --git a/datafusion/substrait/tests/roundtrip.rs b/datafusion/substrait/tests/roundtrip.rs
index a819b2ba5..141f4eb6b 100644
--- a/datafusion/substrait/tests/roundtrip.rs
+++ b/datafusion/substrait/tests/roundtrip.rs
@@ -22,6 +22,7 @@ use datafusion_substrait::producer;
mod tests {
use crate::{consumer::from_substrait_plan, producer::to_substrait_plan};
+ use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::prelude::*;
use substrait::proto::extensions::simple_extension_declaration::MappingType;
@@ -95,6 +96,11 @@ mod tests {
roundtrip("SELECT a, c, avg(b) FROM data GROUP BY a, c").await
}
+ #[tokio::test]
+ async fn decimal_literal() -> Result<()> {
+ roundtrip("SELECT * FROM data WHERE b > 2.5").await
+ }
+
#[tokio::test]
async fn simple_distinct() -> Result<()> {
test_alias(
@@ -290,9 +296,17 @@ mod tests {
async fn create_context() -> Result<SessionContext> {
let ctx = SessionContext::new();
- ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new())
+ let mut explicit_options = CsvReadOptions::new();
+ let schema = Schema::new(vec![
+ Field::new("a", DataType::Int64, true),
+ Field::new("b", DataType::Decimal128(5, 2), true),
+ Field::new("c", DataType::Date32, true),
+ Field::new("d", DataType::Boolean, true),
+ ]);
+ explicit_options.schema = Some(&schema);
+ ctx.register_csv("data", "tests/testdata/data.csv", explicit_options.clone())
.await?;
- ctx.register_csv("data2", "tests/testdata/data.csv", CsvReadOptions::new())
+ ctx.register_csv("data2", "tests/testdata/data.csv", explicit_options)
.await?;
Ok(ctx)
}
diff --git a/datafusion/substrait/tests/testdata/data.csv b/datafusion/substrait/tests/testdata/data.csv
index 4394789bc..b0fc71024 100644
--- a/datafusion/substrait/tests/testdata/data.csv
+++ b/datafusion/substrait/tests/testdata/data.csv
@@ -1,3 +1,3 @@
a,b,c,d
-1,2,2020-01-01,false
-3,4,2020-01-01,true
\ No newline at end of file
+1,2.0,2020-01-01,false
+3,4.5,2020-01-01,true
\ No newline at end of file