You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ho...@apache.org on 2021/12/11 20:47:30 UTC

[arrow-datafusion] branch master updated: support decimal data type in create table (#1431)

This is an automated email from the ASF dual-hosted git repository.

houqp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new dc80c11  support decimal data type in create table (#1431)
dc80c11 is described below

commit dc80c11954fa7855f79806cb29c9ea5283a98d01
Author: Kun Liu <li...@apache.org>
AuthorDate: Sun Dec 12 04:47:23 2021 +0800

    support decimal data type in create table (#1431)
    
    * support decimal data type in create table
---
 datafusion/src/sql/planner.rs | 47 +++++++++++++++++++++++++++++++++++---
 datafusion/tests/sql.rs       | 53 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 97 insertions(+), 3 deletions(-)

diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index eed2b96..72c1962 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -371,7 +371,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             SQLDataType::Char(_) | SQLDataType::Varchar(_) | SQLDataType::Text => {
                 Ok(DataType::Utf8)
             }
-            SQLDataType::Decimal(_, _) => Ok(DataType::Float64),
+            SQLDataType::Decimal(precision, scale) => {
+                match (precision, scale) {
+                    (None, _) | (_, None) => {
+                        return Err(DataFusionError::Internal(format!(
+                            "Invalid Decimal type ({:?}), precision or scale can't be empty.",
+                            sql_type
+                        )));
+                    }
+                    (Some(p), Some(s)) => {
+                        // TODO add bound checker in some utils file or function
+                        if *p > 38 || *s > *p {
+                            return Err(DataFusionError::Internal(format!(
+                                "Error Decimal Type ({:?}), precision must be less than or equal to 38 and scale can't be greater than precision",
+                                sql_type
+                            )));
+                        } else {
+                            Ok(DataType::Decimal(*p as usize, *s as usize))
+                        }
+                    }
+                }
+            }
             SQLDataType::Float(_) => Ok(DataType::Float32),
             SQLDataType::Real => Ok(DataType::Float32),
             SQLDataType::Double => Ok(DataType::Float64),
@@ -2022,8 +2042,8 @@ fn extract_possible_join_keys(
 }
 
 /// Convert SQL data type to relational representation of data type
-pub fn convert_data_type(sql: &SQLDataType) -> Result<DataType> {
-    match sql {
+pub fn convert_data_type(sql_type: &SQLDataType) -> Result<DataType> {
+    match sql_type {
         SQLDataType::Boolean => Ok(DataType::Boolean),
         SQLDataType::SmallInt(_) => Ok(DataType::Int16),
         SQLDataType::Int(_) => Ok(DataType::Int32),
@@ -2034,6 +2054,27 @@ pub fn convert_data_type(sql: &SQLDataType) -> Result<DataType> {
         SQLDataType::Char(_) | SQLDataType::Varchar(_) => Ok(DataType::Utf8),
         SQLDataType::Timestamp => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)),
         SQLDataType::Date => Ok(DataType::Date32),
+        SQLDataType::Decimal(precision, scale) => {
+            match (precision, scale) {
+                (None, _) | (_, None) => {
+                    return Err(DataFusionError::Internal(format!(
+                        "Invalid Decimal type ({:?}), precision or scale can't be empty.",
+                        sql_type
+                    )));
+                }
+                (Some(p), Some(s)) => {
+                    // TODO add bound checker in some utils file or function
+                    if *p > 38 || *s > *p {
+                        return Err(DataFusionError::Internal(format!(
+                            "Error Decimal Type ({:?})",
+                            sql_type
+                        )));
+                    } else {
+                        Ok(DataType::Decimal(*p as usize, *s as usize))
+                    }
+                }
+            }
+        }
         other => Err(DataFusionError::NotImplemented(format!(
             "Unsupported SQL type {:?}",
             other
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index f517e45..945bb7e 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -3761,6 +3761,28 @@ async fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> {
     Ok(())
 }
 
+async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &mut ExecutionContext) {
+    let df = ctx
+        .sql(
+            "CREATE EXTERNAL TABLE aggregate_simple (
+            c1  DECIMAL(10,6) NOT NULL,
+            c2  DOUBLE NOT NULL,
+            c3  BOOLEAN NOT NULL
+            )
+            STORED AS CSV
+            WITH HEADER ROW
+            LOCATION 'tests/aggregate_simple.csv'",
+        )
+        .await
+        .expect("Creating dataframe for CREATE EXTERNAL TABLE with decimal data type");
+
+    let results = df.collect().await.expect("Executing CREATE EXTERNAL TABLE");
+    assert!(
+        results.is_empty(),
+        "Expected no rows from executing CREATE EXTERNAL TABLE"
+    );
+}
+
 async fn register_aggregate_simple_csv(ctx: &mut ExecutionContext) -> Result<()> {
     // It's not possible to use aggregate_test_100, not enought similar values to test grouping on floats
     let schema = Arc::new(Schema::new(vec![
@@ -6459,3 +6481,34 @@ async fn test_select_wildcard_without_table() -> Result<()> {
     }
     Ok(())
 }
+
+#[tokio::test]
+async fn csv_query_with_decimal_by_sql() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_simple_aggregate_csv_with_decimal_by_sql(&mut ctx).await;
+    let sql = "SELECT c1 from aggregate_simple";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+----------+",
+        "| c1       |",
+        "+----------+",
+        "| 0.000010 |",
+        "| 0.000020 |",
+        "| 0.000020 |",
+        "| 0.000030 |",
+        "| 0.000030 |",
+        "| 0.000030 |",
+        "| 0.000040 |",
+        "| 0.000040 |",
+        "| 0.000040 |",
+        "| 0.000040 |",
+        "| 0.000050 |",
+        "| 0.000050 |",
+        "| 0.000050 |",
+        "| 0.000050 |",
+        "| 0.000050 |",
+        "+----------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}