You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/04/10 14:15:58 UTC
[arrow-datafusion] branch main updated: feat: extend substrait type support, including type variations (#5775)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 98197d5403 feat: extend substrait type support, including type variations (#5775)
98197d5403 is described below
commit 98197d5403a59bf3417aa9366d4a1b2f0adf232e
Author: Ruihang Xia <wa...@gmail.com>
AuthorDate: Mon Apr 10 22:15:50 2023 +0800
feat: extend substrait type support, including type variations (#5775)
* feat: extend types
Signed-off-by: Ruihang Xia <wa...@gmail.com>
* feat: extend literal
Signed-off-by: Ruihang Xia <wa...@gmail.com>
* fix clippy warnings
Signed-off-by: Ruihang Xia <wa...@gmail.com>
* add some cases
Signed-off-by: Ruihang Xia <wa...@gmail.com>
* clean up
Signed-off-by: Ruihang Xia <wa...@gmail.com>
* add documentations and accomplish test
Signed-off-by: Ruihang Xia <wa...@gmail.com>
---------
Signed-off-by: Ruihang Xia <wa...@gmail.com>
---
datafusion/substrait/src/lib.rs | 1 +
datafusion/substrait/src/logical_plan/consumer.rs | 404 +++++++++++++++------
datafusion/substrait/src/logical_plan/producer.rs | 397 +++++++++++++++++---
datafusion/substrait/src/variation_const.rs | 39 ++
.../substrait/tests/roundtrip_logical_plan.rs | 143 +++++++-
datafusion/substrait/tests/testdata/empty.csv | 0
6 files changed, 811 insertions(+), 173 deletions(-)
diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs
index 34c9c46edf..432553ec79 100644
--- a/datafusion/substrait/src/lib.rs
+++ b/datafusion/substrait/src/lib.rs
@@ -18,6 +18,7 @@
pub mod logical_plan;
pub mod physical_plan;
pub mod serializer;
+pub mod variation_const;
// Re-export substrait crate
pub use substrait;
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 767c4a3937..607012bfd6 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -16,7 +16,7 @@
// under the License.
use async_recursion::async_recursion;
-use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
use datafusion::common::{DFField, DFSchema, DFSchemaRef};
use datafusion::logical_expr::{
aggregate_function, window_function::find_df_window_func, BinaryExpr, Case, Expr,
@@ -32,6 +32,7 @@ use datafusion::{
prelude::{Column, SessionContext},
scalar::ScalarValue,
};
+use substrait::proto::expression::Literal;
use substrait::proto::{
aggregate_function::AggregationInvocation,
expression::{
@@ -56,6 +57,13 @@ use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
+use crate::variation_const::{
+ DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF,
+ DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF,
+ TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF,
+ TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
+};
+
pub fn name_to_op(name: &str) -> Result<Operator> {
match name {
"equal" => Ok(Operator::Eq),
@@ -682,109 +690,8 @@ pub async fn from_substrait_rex(
}
}
Some(RexType::Literal(lit)) => {
- match &lit.literal_type {
- Some(LiteralType::I8(n)) => {
- if lit.type_variation_reference == 0 {
- Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8)))))
- } else if lit.type_variation_reference == 1 {
- Ok(Arc::new(Expr::Literal(ScalarValue::UInt8(Some(*n as u8)))))
- } else {
- Err(DataFusionError::Substrait(format!(
- "Unknown type variation reference {}",
- lit.type_variation_reference
- )))
- }
- }
- Some(LiteralType::I16(n)) => {
- if lit.type_variation_reference == 0 {
- Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16)))))
- } else if lit.type_variation_reference == 1 {
- Ok(Arc::new(Expr::Literal(ScalarValue::UInt16(Some(
- *n as u16,
- )))))
- } else {
- Err(DataFusionError::Substrait(format!(
- "Unknown type variation reference {}",
- lit.type_variation_reference
- )))
- }
- }
- Some(LiteralType::I32(n)) => {
- if lit.type_variation_reference == 0 {
- Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n)))))
- } else if lit.type_variation_reference == 1 {
- Ok(Arc::new(Expr::Literal(ScalarValue::UInt32(Some(unsafe {
- std::mem::transmute_copy::<i32, u32>(n)
- })))))
- } else {
- Err(DataFusionError::Substrait(format!(
- "Unknown type variation reference {}",
- lit.type_variation_reference
- )))
- }
- }
- Some(LiteralType::I64(n)) => {
- if lit.type_variation_reference == 0 {
- Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n)))))
- } else if lit.type_variation_reference == 1 {
- Ok(Arc::new(Expr::Literal(ScalarValue::UInt64(Some(unsafe {
- std::mem::transmute_copy::<i64, u64>(n)
- })))))
- } else {
- Err(DataFusionError::Substrait(format!(
- "Unknown type variation reference {}",
- lit.type_variation_reference
- )))
- }
- }
- Some(LiteralType::Boolean(b)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b)))))
- }
- Some(LiteralType::Date(d)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Date32(Some(*d)))))
- }
- Some(LiteralType::Fp32(f)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Float32(Some(*f)))))
- }
- Some(LiteralType::Fp64(f)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Float64(Some(*f)))))
- }
- Some(LiteralType::Decimal(d)) => {
- let value: [u8; 16] = d.value.clone().try_into().or(Err(
- DataFusionError::Substrait(
- "Failed to parse decimal value".to_string(),
- ),
- ))?;
- let p = d.precision.try_into().map_err(|e| {
- DataFusionError::Substrait(format!(
- "Failed to parse decimal precision: {e}"
- ))
- })?;
- let s = d.scale.try_into().map_err(|e| {
- DataFusionError::Substrait(format!(
- "Failed to parse decimal scale: {e}"
- ))
- })?;
- Ok(Arc::new(Expr::Literal(ScalarValue::Decimal128(
- Some(std::primitive::i128::from_le_bytes(value)),
- p,
- s,
- ))))
- }
- Some(LiteralType::String(s)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Utf8(Some(s.clone())))))
- }
- Some(LiteralType::Binary(b)) => Ok(Arc::new(Expr::Literal(
- ScalarValue::Binary(Some(b.clone())),
- ))),
- Some(LiteralType::Null(ntype)) => {
- Ok(Arc::new(Expr::Literal(from_substrait_null(ntype)?)))
- }
- _ => Err(DataFusionError::NotImplemented(format!(
- "Unsupported literal_type: {:?}",
- lit.literal_type
- ))),
- }
+ let scalar_value = from_substrait_literal(lit)?;
+ Ok(Arc::new(Expr::Literal(scalar_value)))
}
Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() {
Some(output_type) => Ok(Arc::new(Expr::Cast(Cast::new(
@@ -855,13 +762,104 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataType> {
match &dt.kind {
Some(s_kind) => match s_kind {
r#type::Kind::Bool(_) => Ok(DataType::Boolean),
- r#type::Kind::I8(_) => Ok(DataType::Int8),
- r#type::Kind::I16(_) => Ok(DataType::Int16),
- r#type::Kind::I32(_) => Ok(DataType::Int32),
- r#type::Kind::I64(_) => Ok(DataType::Int64),
- r#type::Kind::Decimal(d) => {
- Ok(DataType::Decimal128(d.precision as u8, d.scale as i8))
+ r#type::Kind::I8(integer) => match integer.type_variation_reference {
+ DEFAULT_TYPE_REF => Ok(DataType::Int8),
+ UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt8),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {s_kind:?}"
+ ))),
+ },
+ r#type::Kind::I16(integer) => match integer.type_variation_reference {
+ DEFAULT_TYPE_REF => Ok(DataType::Int16),
+ UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt16),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {s_kind:?}"
+ ))),
+ },
+ r#type::Kind::I32(integer) => match integer.type_variation_reference {
+ DEFAULT_TYPE_REF => Ok(DataType::Int32),
+ UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt32),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {s_kind:?}"
+ ))),
+ },
+ r#type::Kind::I64(integer) => match integer.type_variation_reference {
+ DEFAULT_TYPE_REF => Ok(DataType::Int64),
+ UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt64),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {s_kind:?}"
+ ))),
+ },
+ r#type::Kind::Fp32(_) => Ok(DataType::Float32),
+ r#type::Kind::Fp64(_) => Ok(DataType::Float64),
+ r#type::Kind::Timestamp(ts) => match ts.type_variation_reference {
+ TIMESTAMP_SECOND_TYPE_REF => {
+ Ok(DataType::Timestamp(TimeUnit::Second, None))
+ }
+ TIMESTAMP_MILLI_TYPE_REF => {
+ Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
+ }
+ TIMESTAMP_MICRO_TYPE_REF => {
+ Ok(DataType::Timestamp(TimeUnit::Microsecond, None))
+ }
+ TIMESTAMP_NANO_TYPE_REF => {
+ Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
+ }
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {s_kind:?}"
+ ))),
+ },
+ r#type::Kind::Date(date) => match date.type_variation_reference {
+ DATE_32_TYPE_REF => Ok(DataType::Date32),
+ DATE_64_TYPE_REF => Ok(DataType::Date64),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {s_kind:?}"
+ ))),
+ },
+ r#type::Kind::Binary(binary) => match binary.type_variation_reference {
+ DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::Binary),
+ LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeBinary),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {s_kind:?}"
+ ))),
+ },
+ r#type::Kind::FixedBinary(fixed) => {
+ Ok(DataType::FixedSizeBinary(fixed.length))
}
+ r#type::Kind::String(string) => match string.type_variation_reference {
+ DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::Utf8),
+ LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeUtf8),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {s_kind:?}"
+ ))),
+ },
+ r#type::Kind::List(list) => {
+ let inner_type =
+ from_substrait_type(list.r#type.as_ref().ok_or_else(|| {
+ DataFusionError::Substrait(
+ "List type must have inner type".to_string(),
+ )
+ })?)?;
+ let field = Box::new(Field::new("list_item", inner_type, true));
+ match list.type_variation_reference {
+ DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::List(field)),
+ LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeList(field)),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {s_kind:?}"
+ )))?,
+ }
+ }
+ r#type::Kind::Decimal(d) => match d.type_variation_reference {
+ DECIMAL_128_TYPE_REF => {
+ Ok(DataType::Decimal128(d.precision as u8, d.scale as i8))
+ }
+ DECIMAL_256_TYPE_REF => {
+ Ok(DataType::Decimal256(d.precision as u8, d.scale as i8))
+ }
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {s_kind:?}"
+ ))),
+ },
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported Substrait type: {s_kind:?}"
))),
@@ -910,20 +908,196 @@ fn from_substrait_bound(
}
}
+fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
+ let scalar_value = match &lit.literal_type {
+ Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)),
+ Some(LiteralType::I8(n)) => match lit.type_variation_reference {
+ DEFAULT_TYPE_REF => ScalarValue::Int8(Some(*n as i8)),
+ UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt8(Some(*n as u8)),
+ others => {
+ return Err(DataFusionError::Substrait(format!(
+ "Unknown type variation reference {others}",
+ )));
+ }
+ },
+ Some(LiteralType::I16(n)) => match lit.type_variation_reference {
+ DEFAULT_TYPE_REF => ScalarValue::Int16(Some(*n as i16)),
+ UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt16(Some(*n as u16)),
+ others => {
+ return Err(DataFusionError::Substrait(format!(
+ "Unknown type variation reference {others}",
+ )));
+ }
+ },
+ Some(LiteralType::I32(n)) => match lit.type_variation_reference {
+ DEFAULT_TYPE_REF => ScalarValue::Int32(Some(*n)),
+ UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt32(Some(unsafe {
+ std::mem::transmute_copy::<i32, u32>(n)
+ })),
+ others => {
+ return Err(DataFusionError::Substrait(format!(
+ "Unknown type variation reference {others}",
+ )));
+ }
+ },
+ Some(LiteralType::I64(n)) => match lit.type_variation_reference {
+ DEFAULT_TYPE_REF => ScalarValue::Int64(Some(*n)),
+ UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt64(Some(unsafe {
+ std::mem::transmute_copy::<i64, u64>(n)
+ })),
+ others => {
+ return Err(DataFusionError::Substrait(format!(
+ "Unknown type variation reference {others}",
+ )));
+ }
+ },
+ Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)),
+ Some(LiteralType::Fp64(f)) => ScalarValue::Float64(Some(*f)),
+ Some(LiteralType::Timestamp(t)) => match lit.type_variation_reference {
+ TIMESTAMP_SECOND_TYPE_REF => ScalarValue::TimestampSecond(Some(*t), None),
+ TIMESTAMP_MILLI_TYPE_REF => ScalarValue::TimestampMillisecond(Some(*t), None),
+ TIMESTAMP_MICRO_TYPE_REF => ScalarValue::TimestampMicrosecond(Some(*t), None),
+ TIMESTAMP_NANO_TYPE_REF => ScalarValue::TimestampNanosecond(Some(*t), None),
+ others => {
+ return Err(DataFusionError::Substrait(format!(
+ "Unknown type variation reference {others}",
+ )));
+ }
+ },
+ Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)),
+ Some(LiteralType::String(s)) => match lit.type_variation_reference {
+ DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Utf8(Some(s.clone())),
+ LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeUtf8(Some(s.clone())),
+ others => {
+ return Err(DataFusionError::Substrait(format!(
+ "Unknown type variation reference {others}",
+ )));
+ }
+ },
+ Some(LiteralType::Binary(b)) => match lit.type_variation_reference {
+ DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Binary(Some(b.clone())),
+ LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeBinary(Some(b.clone())),
+ others => {
+ return Err(DataFusionError::Substrait(format!(
+ "Unknown type variation reference {others}",
+ )));
+ }
+ },
+ Some(LiteralType::FixedBinary(b)) => {
+ ScalarValue::FixedSizeBinary(b.len() as _, Some(b.clone()))
+ }
+ Some(LiteralType::Decimal(d)) => {
+ let value: [u8; 16] =
+ d.value
+ .clone()
+ .try_into()
+ .or(Err(DataFusionError::Substrait(
+ "Failed to parse decimal value".to_string(),
+ )))?;
+ let p = d.precision.try_into().map_err(|e| {
+ DataFusionError::Substrait(format!(
+ "Failed to parse decimal precision: {e}"
+ ))
+ })?;
+ let s = d.scale.try_into().map_err(|e| {
+ DataFusionError::Substrait(format!("Failed to parse decimal scale: {e}"))
+ })?;
+ ScalarValue::Decimal128(
+ Some(std::primitive::i128::from_le_bytes(value)),
+ p,
+ s,
+ )
+ }
+ Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
+ _ => {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Unsupported literal_type: {:?}",
+ lit.literal_type
+ )))
+ }
+ };
+
+ Ok(scalar_value)
+}
+
fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
if let Some(kind) = &null_type.kind {
match kind {
- r#type::Kind::I8(_) => Ok(ScalarValue::Int8(None)),
- r#type::Kind::I16(_) => Ok(ScalarValue::Int16(None)),
- r#type::Kind::I32(_) => Ok(ScalarValue::Int32(None)),
- r#type::Kind::I64(_) => Ok(ScalarValue::Int64(None)),
+ r#type::Kind::Bool(_) => Ok(ScalarValue::Boolean(None)),
+ r#type::Kind::I8(integer) => match integer.type_variation_reference {
+ DEFAULT_TYPE_REF => Ok(ScalarValue::Int8(None)),
+ UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt8(None)),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {kind:?}"
+ ))),
+ },
+ r#type::Kind::I16(integer) => match integer.type_variation_reference {
+ DEFAULT_TYPE_REF => Ok(ScalarValue::Int16(None)),
+ UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt16(None)),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {kind:?}"
+ ))),
+ },
+ r#type::Kind::I32(integer) => match integer.type_variation_reference {
+ DEFAULT_TYPE_REF => Ok(ScalarValue::Int32(None)),
+ UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt32(None)),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {kind:?}"
+ ))),
+ },
+ r#type::Kind::I64(integer) => match integer.type_variation_reference {
+ DEFAULT_TYPE_REF => Ok(ScalarValue::Int64(None)),
+ UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt64(None)),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {kind:?}"
+ ))),
+ },
+ r#type::Kind::Fp32(_) => Ok(ScalarValue::Float32(None)),
+ r#type::Kind::Fp64(_) => Ok(ScalarValue::Float64(None)),
+ r#type::Kind::Timestamp(ts) => match ts.type_variation_reference {
+ TIMESTAMP_SECOND_TYPE_REF => Ok(ScalarValue::TimestampSecond(None, None)),
+ TIMESTAMP_MILLI_TYPE_REF => {
+ Ok(ScalarValue::TimestampMillisecond(None, None))
+ }
+ TIMESTAMP_MICRO_TYPE_REF => {
+ Ok(ScalarValue::TimestampMicrosecond(None, None))
+ }
+ TIMESTAMP_NANO_TYPE_REF => {
+ Ok(ScalarValue::TimestampNanosecond(None, None))
+ }
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {kind:?}"
+ ))),
+ },
+ r#type::Kind::Date(date) => match date.type_variation_reference {
+ DATE_32_TYPE_REF => Ok(ScalarValue::Date32(None)),
+ DATE_64_TYPE_REF => Ok(ScalarValue::Date64(None)),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {kind:?}"
+ ))),
+ },
+ r#type::Kind::Binary(binary) => match binary.type_variation_reference {
+ DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::Binary(None)),
+ LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeBinary(None)),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {kind:?}"
+ ))),
+ },
+ // FixedBinary is not supported because `None` doesn't have length
+ r#type::Kind::String(string) => match string.type_variation_reference {
+ DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::Utf8(None)),
+ LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeUtf8(None)),
+ v => Err(DataFusionError::NotImplemented(format!(
+ "Unsupported Substrait type variation {v} of type {kind:?}"
+ ))),
+ },
r#type::Kind::Decimal(d) => Ok(ScalarValue::Decimal128(
None,
d.precision as u8,
d.scale as i8,
)),
_ => Err(DataFusionError::NotImplemented(format!(
- "Unsupported null kind: {kind:?}"
+ "Unsupported Substrait type: {kind:?}"
))),
}
} else {
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index c0a4dd04f3..9ad9645ffc 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -18,7 +18,7 @@
use std::{collections::HashMap, mem, sync::Arc};
use datafusion::{
- arrow::datatypes::DataType,
+ arrow::datatypes::{DataType, TimeUnit},
error::{DataFusionError, Result},
logical_expr::{WindowFrame, WindowFrameBound},
prelude::JoinType,
@@ -63,6 +63,13 @@ use substrait::{
version,
};
+use crate::variation_const::{
+ DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF,
+ DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF,
+ TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF,
+ TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
+};
+
/// Convert DataFusion LogicalPlan to Substrait Plan
pub fn to_substrait_plan(plan: &LogicalPlan) -> Result<Box<Plan>> {
// Parse relation nodes
@@ -637,48 +644,7 @@ pub fn to_substrait_rex(
))),
})
}
- Expr::Literal(value) => {
- let literal_type = match value {
- ScalarValue::Int8(Some(n)) => Some(LiteralType::I8(*n as i32)),
- ScalarValue::UInt8(Some(n)) => Some(LiteralType::I8(*n as i32)),
- ScalarValue::Int16(Some(n)) => Some(LiteralType::I16(*n as i32)),
- ScalarValue::UInt16(Some(n)) => Some(LiteralType::I16(*n as i32)),
- ScalarValue::Int32(Some(n)) => Some(LiteralType::I32(*n)),
- ScalarValue::UInt32(Some(n)) => Some(LiteralType::I32(unsafe {
- mem::transmute_copy::<u32, i32>(n)
- })),
- ScalarValue::Int64(Some(n)) => Some(LiteralType::I64(*n)),
- ScalarValue::UInt64(Some(n)) => Some(LiteralType::I64(unsafe {
- mem::transmute_copy::<u64, i64>(n)
- })),
- ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)),
- ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
- ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
- ScalarValue::Decimal128(v, p, s) if v.is_some() => {
- Some(LiteralType::Decimal(Decimal {
- value: v.unwrap().to_le_bytes().to_vec(),
- precision: *p as i32,
- scale: *s as i32,
- }))
- }
- ScalarValue::Utf8(Some(s)) => Some(LiteralType::String(s.clone())),
- ScalarValue::LargeUtf8(Some(s)) => Some(LiteralType::String(s.clone())),
- ScalarValue::Binary(Some(b)) => Some(LiteralType::Binary(b.clone())),
- ScalarValue::LargeBinary(Some(b)) => Some(LiteralType::Binary(b.clone())),
- ScalarValue::Date32(Some(d)) => Some(LiteralType::Date(*d)),
- _ => Some(try_to_substrait_null(value)?),
- };
-
- let type_variation_reference = if value.is_unsigned() { 1 } else { 0 };
-
- Ok(Expression {
- rex_type: Some(RexType::Literal(Literal {
- nullable: true,
- type_variation_reference,
- literal_type,
- })),
- })
- }
+ Expr::Literal(value) => to_substrait_literal(value),
Expr::Alias(expr, _alias) => to_substrait_rex(expr, schema, extension_info),
Expr::WindowFunction(WindowFunction {
fun,
@@ -728,7 +694,6 @@ pub fn to_substrait_rex(
}
fn to_substrait_type(dt: &DataType) -> Result<substrait::proto::Type> {
- let default_type_ref = 0;
let default_nullability = r#type::Nullability::Required as i32;
match dt {
DataType::Null => Err(DataFusionError::Internal(
@@ -736,37 +701,173 @@ fn to_substrait_type(dt: &DataType) -> Result<substrait::proto::Type> {
)),
DataType::Boolean => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Bool(r#type::Boolean {
- type_variation_reference: default_type_ref,
+ type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::Int8 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I8(r#type::I8 {
- type_variation_reference: default_type_ref,
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }),
+ DataType::UInt8 => Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::I8(r#type::I8 {
+ type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::Int16 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I16(r#type::I16 {
- type_variation_reference: default_type_ref,
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }),
+ DataType::UInt16 => Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::I16(r#type::I16 {
+ type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::Int32 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
- type_variation_reference: default_type_ref,
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }),
+ DataType::UInt32 => Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::I32(r#type::I32 {
+ type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
}),
DataType::Int64 => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
- type_variation_reference: default_type_ref,
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }),
+ DataType::UInt64 => Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::I64(r#type::I64 {
+ type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }),
+ // Float16 is not supported in Substrait
+ DataType::Float32 => Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::Fp32(r#type::Fp32 {
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }),
+ DataType::Float64 => Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::Fp64(r#type::Fp64 {
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }),
+ // Timezone is ignored.
+ DataType::Timestamp(unit, _) => {
+ let type_variation_reference = match unit {
+ TimeUnit::Second => TIMESTAMP_SECOND_TYPE_REF,
+ TimeUnit::Millisecond => TIMESTAMP_MILLI_TYPE_REF,
+ TimeUnit::Microsecond => TIMESTAMP_MICRO_TYPE_REF,
+ TimeUnit::Nanosecond => TIMESTAMP_NANO_TYPE_REF,
+ };
+ Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
+ type_variation_reference,
+ nullability: default_nullability,
+ })),
+ })
+ }
+ DataType::Date32 => Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::Date(r#type::Date {
+ type_variation_reference: DATE_32_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }),
+ DataType::Date64 => Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::Date(r#type::Date {
+ type_variation_reference: DATE_64_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }),
+ DataType::Binary => Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::Binary(r#type::Binary {
+ type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }),
+ DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary {
+ length: *length,
+ type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}),
+ DataType::LargeBinary => Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::Binary(r#type::Binary {
+ type_variation_reference: LARGE_CONTAINER_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }),
+ DataType::Utf8 => Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::String(r#type::String {
+ type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }),
+ DataType::LargeUtf8 => Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::String(r#type::String {
+ type_variation_reference: LARGE_CONTAINER_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }),
+ DataType::List(inner) => {
+ let inner_type = to_substrait_type(inner.data_type())?;
+ Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::List(Box::new(r#type::List {
+ r#type: Some(Box::new(inner_type)),
+ type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
+ nullability: default_nullability,
+ }))),
+ })
+ }
+ DataType::LargeList(inner) => {
+ let inner_type = to_substrait_type(inner.data_type())?;
+ Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::List(Box::new(r#type::List {
+ r#type: Some(Box::new(inner_type)),
+ type_variation_reference: LARGE_CONTAINER_TYPE_REF,
+ nullability: default_nullability,
+ }))),
+ })
+ }
+ DataType::Struct(fields) => {
+ let field_types = fields
+ .iter()
+ .map(|field| to_substrait_type(field.data_type()))
+ .collect::<Result<Vec<_>>>()?;
+ Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::Struct(r#type::Struct {
+ types: field_types,
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })
+ }
DataType::Decimal128(p, s) => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Decimal(r#type::Decimal {
- type_variation_reference: default_type_ref,
+ type_variation_reference: DECIMAL_128_TYPE_REF,
+ nullability: default_nullability,
+ scale: *s as i32,
+ precision: *p as i32,
+ })),
+ }),
+ DataType::Decimal256(p, s) => Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::Decimal(r#type::Decimal {
+ type_variation_reference: DECIMAL_256_TYPE_REF,
nullability: default_nullability,
scale: *s as i32,
precision: *p as i32,
@@ -908,31 +1009,215 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> {
))
}
+fn to_substrait_literal(value: &ScalarValue) -> Result<Expression> {
+ let (literal_type, type_variation_reference) = match value {
+ ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_REF),
+ ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_REF),
+ ScalarValue::UInt8(Some(n)) => {
+ (LiteralType::I8(*n as i32), UNSIGNED_INTEGER_TYPE_REF)
+ }
+ ScalarValue::Int16(Some(n)) => (LiteralType::I16(*n as i32), DEFAULT_TYPE_REF),
+ ScalarValue::UInt16(Some(n)) => {
+ (LiteralType::I16(*n as i32), UNSIGNED_INTEGER_TYPE_REF)
+ }
+ ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_REF),
+ ScalarValue::UInt32(Some(n)) => (
+ LiteralType::I32(unsafe { mem::transmute_copy::<u32, i32>(n) }),
+ UNSIGNED_INTEGER_TYPE_REF,
+ ),
+ ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_REF),
+ ScalarValue::UInt64(Some(n)) => (
+ LiteralType::I64(unsafe { mem::transmute_copy::<u64, i64>(n) }),
+ UNSIGNED_INTEGER_TYPE_REF,
+ ),
+ ScalarValue::Float32(Some(f)) => (LiteralType::Fp32(*f), DEFAULT_TYPE_REF),
+ ScalarValue::Float64(Some(f)) => (LiteralType::Fp64(*f), DEFAULT_TYPE_REF),
+ ScalarValue::TimestampSecond(Some(t), _) => {
+ (LiteralType::Timestamp(*t), TIMESTAMP_SECOND_TYPE_REF)
+ }
+ ScalarValue::TimestampMillisecond(Some(t), _) => {
+ (LiteralType::Timestamp(*t), TIMESTAMP_MILLI_TYPE_REF)
+ }
+ ScalarValue::TimestampMicrosecond(Some(t), _) => {
+ (LiteralType::Timestamp(*t), TIMESTAMP_MICRO_TYPE_REF)
+ }
+ ScalarValue::TimestampNanosecond(Some(t), _) => {
+ (LiteralType::Timestamp(*t), TIMESTAMP_NANO_TYPE_REF)
+ }
+ ScalarValue::Date32(Some(d)) => (LiteralType::Date(*d), DATE_32_TYPE_REF),
+ // Date64 literal is not supported in Substrait
+ ScalarValue::Binary(Some(b)) => {
+ (LiteralType::Binary(b.clone()), DEFAULT_CONTAINER_TYPE_REF)
+ }
+ ScalarValue::LargeBinary(Some(b)) => {
+ (LiteralType::Binary(b.clone()), LARGE_CONTAINER_TYPE_REF)
+ }
+ ScalarValue::FixedSizeBinary(_, Some(b)) => {
+ (LiteralType::FixedBinary(b.clone()), DEFAULT_TYPE_REF)
+ }
+ ScalarValue::Utf8(Some(s)) => {
+ (LiteralType::String(s.clone()), DEFAULT_CONTAINER_TYPE_REF)
+ }
+ ScalarValue::LargeUtf8(Some(s)) => {
+ (LiteralType::String(s.clone()), LARGE_CONTAINER_TYPE_REF)
+ }
+ ScalarValue::Decimal128(v, p, s) if v.is_some() => (
+ LiteralType::Decimal(Decimal {
+ value: v.unwrap().to_le_bytes().to_vec(),
+ precision: *p as i32,
+ scale: *s as i32,
+ }),
+ DECIMAL_128_TYPE_REF,
+ ),
+ _ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF),
+ };
+
+ Ok(Expression {
+ rex_type: Some(RexType::Literal(Literal {
+ nullable: true,
+ type_variation_reference,
+ literal_type: Some(literal_type),
+ })),
+ })
+}
+
fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
- let default_type_ref = 0;
+ // let default_type_ref = 0;
let default_nullability = r#type::Nullability::Nullable as i32;
match v {
ScalarValue::Int8(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I8(r#type::I8 {
- type_variation_reference: default_type_ref,
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })),
+ ScalarValue::UInt8(None) => Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::I8(r#type::I8 {
+ type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Int16(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I16(r#type::I16 {
- type_variation_reference: default_type_ref,
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })),
+ ScalarValue::UInt16(None) => Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::I16(r#type::I16 {
+ type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Int32(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I32(r#type::I32 {
- type_variation_reference: default_type_ref,
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })),
+ ScalarValue::UInt32(None) => Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::I32(r#type::I32 {
+ type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
nullability: default_nullability,
})),
})),
ScalarValue::Int64(None) => Ok(LiteralType::Null(substrait::proto::Type {
kind: Some(r#type::Kind::I64(r#type::I64 {
- type_variation_reference: default_type_ref,
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })),
+ ScalarValue::UInt64(None) => Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::I64(r#type::I64 {
+ type_variation_reference: UNSIGNED_INTEGER_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })),
+ ScalarValue::Float32(None) => Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::Fp32(r#type::Fp32 {
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })),
+ ScalarValue::Float64(None) => Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::Fp64(r#type::Fp64 {
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })),
+ ScalarValue::TimestampSecond(None, _) => {
+ Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
+ type_variation_reference: TIMESTAMP_SECOND_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }))
+ }
+ ScalarValue::TimestampMillisecond(None, _) => {
+ Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
+ type_variation_reference: TIMESTAMP_MILLI_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }))
+ }
+ ScalarValue::TimestampMicrosecond(None, _) => {
+ Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
+ type_variation_reference: TIMESTAMP_MICRO_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }))
+ }
+ ScalarValue::TimestampNanosecond(None, _) => {
+ Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::Timestamp(r#type::Timestamp {
+ type_variation_reference: TIMESTAMP_NANO_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }))
+ }
+ ScalarValue::Date32(None) => Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::Date(r#type::Date {
+ type_variation_reference: DATE_32_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })),
+ ScalarValue::Date64(None) => Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::Date(r#type::Date {
+ type_variation_reference: DATE_64_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })),
+ ScalarValue::Binary(None) => Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::Binary(r#type::Binary {
+ type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })),
+ ScalarValue::LargeBinary(None) => Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::Binary(r#type::Binary {
+ type_variation_reference: LARGE_CONTAINER_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })),
+ ScalarValue::FixedSizeBinary(_, None) => {
+ Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::Binary(r#type::Binary {
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ }))
+ }
+ ScalarValue::Utf8(None) => Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::String(r#type::String {
+ type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })),
+ ScalarValue::LargeUtf8(None) => Ok(LiteralType::Null(substrait::proto::Type {
+ kind: Some(r#type::Kind::String(r#type::String {
+ type_variation_reference: LARGE_CONTAINER_TYPE_REF,
nullability: default_nullability,
})),
})),
@@ -941,7 +1226,7 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result<LiteralType> {
kind: Some(r#type::Kind::Decimal(r#type::Decimal {
scale: *s as i32,
precision: *p as i32,
- type_variation_reference: default_type_ref,
+ type_variation_reference: DEFAULT_TYPE_REF,
nullability: default_nullability,
})),
}))
diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs
new file mode 100644
index 0000000000..27ef15153b
--- /dev/null
+++ b/datafusion/substrait/src/variation_const.rs
@@ -0,0 +1,39 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Type variation constants
+//!
+//! To add support for types not in the [core specification](https://substrait.io/types/type_classes/),
+//! we make use of the [simple extensions](https://substrait.io/extensions/#simple-extensions) of substrait
+//! type. This module contains the constants used to identify the type variation.
+//!
+//! The rules of type variations here are:
+//! - Default type reference is 0. It is used when the actual type is the same as the original type.
+//! - Extended variant type references start from 1, and usually increase by 1.
+
+pub const DEFAULT_TYPE_REF: u32 = 0;
+pub const UNSIGNED_INTEGER_TYPE_REF: u32 = 1;
+pub const TIMESTAMP_SECOND_TYPE_REF: u32 = 0;
+pub const TIMESTAMP_MILLI_TYPE_REF: u32 = 1;
+pub const TIMESTAMP_MICRO_TYPE_REF: u32 = 2;
+pub const TIMESTAMP_NANO_TYPE_REF: u32 = 3;
+pub const DATE_32_TYPE_REF: u32 = 0;
+pub const DATE_64_TYPE_REF: u32 = 1;
+pub const DEFAULT_CONTAINER_TYPE_REF: u32 = 0;
+pub const LARGE_CONTAINER_TYPE_REF: u32 = 1;
+pub const DECIMAL_128_TYPE_REF: u32 = 0;
+pub const DECIMAL_256_TYPE_REF: u32 = 1;
diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs
index 3389658d2a..965c007e98 100644
--- a/datafusion/substrait/tests/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs
@@ -21,7 +21,7 @@ use datafusion_substrait::logical_plan::{consumer, producer};
mod tests {
use crate::{consumer::from_substrait_plan, producer::to_substrait_plan};
- use datafusion::arrow::datatypes::{DataType, Field, Schema};
+ use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::error::Result;
use datafusion::prelude::*;
use substrait::proto::extensions::simple_extension_declaration::MappingType;
@@ -262,7 +262,65 @@ mod tests {
#[tokio::test]
async fn qualified_catalog_schema_table_reference() -> Result<()> {
- roundtrip("SELECT * FROM datafusion.public.data;").await
+ roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await
+ }
+
+ /// Construct a plan that contains several literals of types that are currently supported.
+ /// This case ignores:
+ /// - Date64, for this literal is not supported
+ /// - FixedSizeBinary, for converting UTF-8 literal to FixedSizeBinary is not supported
+ /// - List, this nested type is not supported in arrow_cast
+ /// - Decimal128 and Decimal256, they fall back to a UTF8 cast expr rather than a plain literal.
+ #[tokio::test]
+ async fn all_type_literal() -> Result<()> {
+ roundtrip_all_types(
+ "select * from data where
+ bool_col = TRUE AND
+ int8_col = arrow_cast('0', 'Int8') AND
+ uint8_col = arrow_cast('0', 'UInt8') AND
+ int16_col = arrow_cast('0', 'Int16') AND
+ uint16_col = arrow_cast('0', 'UInt16') AND
+ int32_col = arrow_cast('0', 'Int32') AND
+ uint32_col = arrow_cast('0', 'UInt32') AND
+ int64_col = arrow_cast('0', 'Int64') AND
+ uint64_col = arrow_cast('0', 'UInt64') AND
+ float32_col = arrow_cast('0', 'Float32') AND
+ float64_col = arrow_cast('0', 'Float64') AND
+ sec_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Second, None)') AND
+ ms_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Millisecond, None)') AND
+ us_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Microsecond, None)') AND
+ ns_timestamp_col = arrow_cast('2020-01-01 00:00:00', 'Timestamp (Nanosecond, None)') AND
+ date32_col = arrow_cast('2020-01-01', 'Date32') AND
+ binary_col = arrow_cast('binary', 'Binary') AND
+ large_binary_col = arrow_cast('large_binary', 'LargeBinary') AND
+ utf8_col = arrow_cast('utf8', 'Utf8') AND
+ large_utf8_col = arrow_cast('large_utf8', 'LargeUtf8');",
+ )
+ .await
+ }
+
+ /// Construct a plan that casts columns. Only these SQL types are supported for now.
+ #[tokio::test]
+ async fn new_test_grammar() -> Result<()> {
+ roundtrip_all_types(
+ "select
+ bool_col::boolean,
+ int8_col::tinyint,
+ uint8_col::tinyint unsigned,
+ int16_col::smallint,
+ uint16_col::smallint unsigned,
+ int32_col::integer,
+ uint32_col::integer unsigned,
+ int64_col::bigint,
+ uint64_col::bigint unsigned,
+ float32_col::float,
+ float64_col::double,
+ decimal_128_col::decimal(10, 2),
+ date32_col::date,
+ binary_col::bytea
+ from data",
+ )
+ .await
}
async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> {
@@ -333,6 +391,23 @@ mod tests {
Ok(())
}
+ async fn roundtrip_all_types(sql: &str) -> Result<()> {
+ let mut ctx = create_all_type_context().await?;
+ let df = ctx.sql(sql).await?;
+ let plan = df.into_optimized_plan()?;
+ let proto = to_substrait_plan(&plan)?;
+ let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
+ let plan2 = ctx.state().optimize(&plan2)?;
+
+ println!("{plan:#?}");
+ println!("{plan2:#?}");
+
+ let plan1str = format!("{plan:?}");
+ let plan2str = format!("{plan2:?}");
+ assert_eq!(plan1str, plan2str);
+ Ok(())
+ }
+
async fn function_extension_info(sql: &str) -> Result<(Vec<String>, Vec<u32>)> {
let ctx = create_context().await?;
let df = ctx.sql(sql).await?;
@@ -373,4 +448,68 @@ mod tests {
.await?;
Ok(ctx)
}
+
+ /// Cover all supported types
+ async fn create_all_type_context() -> Result<SessionContext> {
+ let ctx = SessionContext::new();
+ let mut explicit_options = CsvReadOptions::new();
+ let schema = Schema::new(vec![
+ Field::new("bool_col", DataType::Boolean, true),
+ Field::new("int8_col", DataType::Int8, true),
+ Field::new("uint8_col", DataType::UInt8, true),
+ Field::new("int16_col", DataType::Int16, true),
+ Field::new("uint16_col", DataType::UInt16, true),
+ Field::new("int32_col", DataType::Int32, true),
+ Field::new("uint32_col", DataType::UInt32, true),
+ Field::new("int64_col", DataType::Int64, true),
+ Field::new("uint64_col", DataType::UInt64, true),
+ Field::new("float32_col", DataType::Float32, true),
+ Field::new("float64_col", DataType::Float64, true),
+ Field::new(
+ "sec_timestamp_col",
+ DataType::Timestamp(TimeUnit::Second, None),
+ true,
+ ),
+ Field::new(
+ "ms_timestamp_col",
+ DataType::Timestamp(TimeUnit::Millisecond, None),
+ true,
+ ),
+ Field::new(
+ "us_timestamp_col",
+ DataType::Timestamp(TimeUnit::Microsecond, None),
+ true,
+ ),
+ Field::new(
+ "ns_timestamp_col",
+ DataType::Timestamp(TimeUnit::Nanosecond, None),
+ true,
+ ),
+ Field::new("date32_col", DataType::Date32, true),
+ Field::new("date64_col", DataType::Date64, true),
+ Field::new("binary_col", DataType::Binary, true),
+ Field::new("large_binary_col", DataType::LargeBinary, true),
+ Field::new("fixed_size_binary_col", DataType::FixedSizeBinary(42), true),
+ Field::new("utf8_col", DataType::Utf8, true),
+ Field::new("large_utf8_col", DataType::LargeUtf8, true),
+ Field::new(
+ "list_col",
+ DataType::List(Box::new(Field::new("item", DataType::Int64, true))),
+ true,
+ ),
+ Field::new(
+ "large_list_col",
+ DataType::LargeList(Box::new(Field::new("item", DataType::Int64, true))),
+ true,
+ ),
+ Field::new("decimal_128_col", DataType::Decimal128(10, 2), true),
+ Field::new("decimal_256_col", DataType::Decimal256(10, 2), true),
+ ]);
+ explicit_options.schema = Some(&schema);
+ explicit_options.has_header = false;
+ ctx.register_csv("data", "tests/testdata/empty.csv", explicit_options)
+ .await?;
+
+ Ok(ctx)
+ }
}
diff --git a/datafusion/substrait/tests/testdata/empty.csv b/datafusion/substrait/tests/testdata/empty.csv
new file mode 100644
index 0000000000..e69de29bb2