You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2023/01/26 15:36:26 UTC

[arrow-datafusion] branch master updated: Add decimal support to substrait serde (#5054)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 552eea719 Add decimal support to substrait serde (#5054)
552eea719 is described below

commit 552eea719cba7e331a7955e15a3438ae4aa06ed0
Author: Andy Grove <an...@gmail.com>
AuthorDate: Thu Jan 26 08:36:19 2023 -0700

    Add decimal support to substrait serde (#5054)
---
 datafusion/substrait/src/consumer.rs         | 100 +++++++++++++++++----------
 datafusion/substrait/src/producer.rs         |   9 ++-
 datafusion/substrait/tests/roundtrip.rs      |  18 ++++-
 datafusion/substrait/tests/testdata/data.csv |   4 +-
 4 files changed, 89 insertions(+), 42 deletions(-)

diff --git a/datafusion/substrait/src/consumer.rs b/datafusion/substrait/src/consumer.rs
index 80b400d52..7293ad967 100644
--- a/datafusion/substrait/src/consumer.rs
+++ b/datafusion/substrait/src/consumer.rs
@@ -606,44 +606,70 @@ pub async fn from_substrait_rex(
                 ))),
             }
         }
-        Some(RexType::Literal(lit)) => match &lit.literal_type {
-            Some(LiteralType::I8(n)) => {
-                Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8)))))
-            }
-            Some(LiteralType::I16(n)) => {
-                Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16)))))
-            }
-            Some(LiteralType::I32(n)) => {
-                Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n)))))
-            }
-            Some(LiteralType::I64(n)) => {
-                Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n)))))
-            }
-            Some(LiteralType::Boolean(b)) => {
-                Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b)))))
-            }
-            Some(LiteralType::Date(d)) => {
-                Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d)))))
-            }
-            Some(LiteralType::Fp32(f)) => {
-                Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f)))))
-            }
-            Some(LiteralType::Fp64(f)) => {
-                Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f)))))
-            }
-            Some(LiteralType::String(s)) => {
-                Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(Some(s.clone())))))
-            }
-            Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(
-                ScalarValue::Binary(Some(b.clone())),
-            ))),
-            _ => {
-                return Err(DataFusionError::NotImplemented(format!(
-                    "Unsupported literal_type: {:?}",
-                    lit.literal_type
-                )))
+        Some(RexType::Literal(lit)) => {
+            match &lit.literal_type {
+                Some(LiteralType::I8(n)) => {
+                    Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8)))))
+                }
+                Some(LiteralType::I16(n)) => {
+                    Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16)))))
+                }
+                Some(LiteralType::I32(n)) => {
+                    Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n)))))
+                }
+                Some(LiteralType::I64(n)) => {
+                    Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n)))))
+                }
+                Some(LiteralType::Boolean(b)) => {
+                    Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b)))))
+                }
+                Some(LiteralType::Date(d)) => {
+                    Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d)))))
+                }
+                Some(LiteralType::Fp32(f)) => {
+                    Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f)))))
+                }
+                Some(LiteralType::Fp64(f)) => {
+                    Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f)))))
+                }
+                Some(LiteralType::Decimal(d)) => {
+                    let value: [u8; 16] = d.value.clone().try_into().or(Err(
+                        DataFusionError::Substrait(
+                            "Failed to parse decimal value".to_string(),
+                        ),
+                    ))?;
+                    let p = d.precision.try_into().map_err(|e| {
+                        DataFusionError::Substrait(format!(
+                            "Failed to parse decimal precision: {}",
+                            e
+                        ))
+                    })?;
+                    let s = d.scale.try_into().map_err(|e| {
+                        DataFusionError::Substrait(format!(
+                            "Failed to parse decimal scale: {}",
+                            e
+                        ))
+                    })?;
+                    Ok(Arc::new(Expr::Literal(ScalarValue::Decimal128(
+                        Some(std::primitive::i128::from_le_bytes(value)),
+                        p,
+                        s,
+                    ))))
+                }
+                Some(LiteralType::String(s)) => {
+                    Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(Some(s.clone())))))
+                }
+                Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(
+                    ScalarValue::Binary(Some(b.clone())),
+                ))),
+                _ => {
+                    return Err(DataFusionError::NotImplemented(format!(
+                        "Unsupported literal_type: {:?}",
+                        lit.literal_type
+                    )))
+                }
             }
-        },
+        }
         _ => Err(DataFusionError::NotImplemented(
             "unsupported rex_type".to_string(),
         )),
diff --git a/datafusion/substrait/src/producer.rs b/datafusion/substrait/src/producer.rs
index a1748c3ff..163abbaa9 100644
--- a/datafusion/substrait/src/producer.rs
+++ b/datafusion/substrait/src/producer.rs
@@ -35,7 +35,7 @@ use substrait::proto::{
     expression::{
         field_reference::ReferenceType,
         if_then::IfClause,
-        literal::LiteralType,
+        literal::{Decimal, LiteralType},
         mask_expression::{StructItem, StructSelect},
         reference_segment, FieldReference, IfThen, Literal, MaskExpression,
         ReferenceSegment, RexType, ScalarFunction,
@@ -579,6 +579,13 @@ pub fn to_substrait_rex(
                 ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)),
                 ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
                 ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
+                ScalarValue::Decimal128(v, p, s) if v.is_some() => {
+                    Some(LiteralType::Decimal(Decimal {
+                        value: v.unwrap().to_le_bytes().to_vec(),
+                        precision: *p as i32,
+                        scale: *s as i32,
+                    }))
+                }
                 ScalarValue::Utf8(Some(s)) => Some(LiteralType::String(s.clone())),
                 ScalarValue::LargeUtf8(Some(s)) => Some(LiteralType::String(s.clone())),
                 ScalarValue::Binary(Some(b)) => Some(LiteralType::Binary(b.clone())),
diff --git a/datafusion/substrait/tests/roundtrip.rs b/datafusion/substrait/tests/roundtrip.rs
index a819b2ba5..141f4eb6b 100644
--- a/datafusion/substrait/tests/roundtrip.rs
+++ b/datafusion/substrait/tests/roundtrip.rs
@@ -22,6 +22,7 @@ use datafusion_substrait::producer;
 mod tests {
 
     use crate::{consumer::from_substrait_plan, producer::to_substrait_plan};
+    use datafusion::arrow::datatypes::{DataType, Field, Schema};
     use datafusion::error::Result;
     use datafusion::prelude::*;
     use substrait::proto::extensions::simple_extension_declaration::MappingType;
@@ -95,6 +96,11 @@ mod tests {
         roundtrip("SELECT a, c, avg(b) FROM data GROUP BY a, c").await
     }
 
+    #[tokio::test]
+    async fn decimal_literal() -> Result<()> {
+        roundtrip("SELECT * FROM data WHERE b > 2.5").await
+    }
+
     #[tokio::test]
     async fn simple_distinct() -> Result<()> {
         test_alias(
@@ -290,9 +296,17 @@ mod tests {
 
     async fn create_context() -> Result<SessionContext> {
         let ctx = SessionContext::new();
-        ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new())
+        let mut explicit_options = CsvReadOptions::new();
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Int64, true),
+            Field::new("b", DataType::Decimal128(5, 2), true),
+            Field::new("c", DataType::Date32, true),
+            Field::new("d", DataType::Boolean, true),
+        ]);
+        explicit_options.schema = Some(&schema);
+        ctx.register_csv("data", "tests/testdata/data.csv", explicit_options.clone())
             .await?;
-        ctx.register_csv("data2", "tests/testdata/data.csv", CsvReadOptions::new())
+        ctx.register_csv("data2", "tests/testdata/data.csv", explicit_options)
             .await?;
         Ok(ctx)
     }
diff --git a/datafusion/substrait/tests/testdata/data.csv b/datafusion/substrait/tests/testdata/data.csv
index 4394789bc..b0fc71024 100644
--- a/datafusion/substrait/tests/testdata/data.csv
+++ b/datafusion/substrait/tests/testdata/data.csv
@@ -1,3 +1,3 @@
 a,b,c,d
-1,2,2020-01-01,false
-3,4,2020-01-01,true
\ No newline at end of file
+1,2.0,2020-01-01,false
+3,4.5,2020-01-01,true
\ No newline at end of file