You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/11/15 22:59:49 UTC

[arrow-datafusion] branch master updated: Add parser option for parsing SQL numeric literals as decimal (#4102)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 406c1087b Add parser option for parsing SQL numeric literals as decimal (#4102)
406c1087b is described below

commit 406c1087bc16f8d2a49e5a9b05d2a0e1b67f7aa5
Author: Andy Grove <an...@gmail.com>
AuthorDate: Tue Nov 15 15:59:42 2022 -0700

    Add parser option for parsing SQL numeric literals as decimal (#4102)
---
 datafusion/sql/src/planner.rs | 133 ++++++++++++++++++++++++++++++++++++------
 1 file changed, 114 insertions(+), 19 deletions(-)

diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index bd698b505..58d221e07 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -92,9 +92,16 @@ pub trait ContextProvider {
     fn get_config_option(&self, variable: &str) -> Option<ScalarValue>;
 }
 
+/// SQL parser options
+#[derive(Debug, Default)]
+pub struct ParserOptions {
+    parse_float_as_decimal: bool,
+}
+
 /// SQL query planner
 pub struct SqlToRel<'a, S: ContextProvider> {
     schema_provider: &'a S,
+    options: ParserOptions,
 }
 
 fn plan_key(key: SQLExpr) -> Result<ScalarValue> {
@@ -137,7 +144,15 @@ fn plan_indexed(expr: Expr, mut keys: Vec<SQLExpr>) -> Result<Expr> {
 impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     /// Create a new query planner
     pub fn new(schema_provider: &'a S) -> Self {
-        SqlToRel { schema_provider }
+        Self::new_with_options(schema_provider, ParserOptions::default())
+    }
+
+    /// Create a new query planner that uses the given parser options
+    pub fn new_with_options(schema_provider: &'a S, options: ParserOptions) -> Self {
+        SqlToRel {
+            schema_provider,
+            options,
+        }
     }
 
     /// Generate a logical plan from an DataFusion SQL statement
@@ -1699,7 +1714,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             .map(|row| {
                 row.into_iter()
                     .map(|v| match v {
-                        SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n),
+                        SQLExpr::Value(Value::Number(n, _)) => self.parse_sql_number(&n),
                         SQLExpr::Value(
                             Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
                         ) => Ok(lit(s)),
@@ -1753,7 +1768,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         ctes: &mut HashMap<String, LogicalPlan>,
     ) -> Result<Expr> {
         match sql {
-            SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n),
+            SQLExpr::Value(Value::Number(n, _)) => self.parse_sql_number(&n),
             SQLExpr::Value(Value::SingleQuotedString(ref s) | Value::DoubleQuotedString(ref s)) => Ok(lit(s.clone())),
             SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
             SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)),
@@ -2668,6 +2683,51 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         }
     }
 
+    /// Parse a SQL numeric literal into an `Expr::Literal`, honoring `parse_float_as_decimal`
+    fn parse_sql_number(&self, n: &str) -> Result<Expr> {
+        if n.find('E').is_some() {
+            // not implemented yet
+            // https://github.com/apache/arrow-datafusion/issues/3448
+            Err(DataFusionError::NotImplemented(
+                "sql numeric literals in scientific notation are not supported"
+                    .to_string(),
+            ))
+        } else if let Ok(n) = n.parse::<i64>() {
+            Ok(lit(n))
+        } else if self.options.parse_float_as_decimal {
+            // remove leading zeroes
+            let str = n.trim_start_matches('0');
+            if let Some(i) = str.find('.') {
+                let p = str.len() - 1;
+                let s = str.len() - i - 1;
+                let str = str.replace('.', "");
+                let n = str.parse::<i128>().map_err(|_| {
+                    DataFusionError::from(ParserError(format!(
+                        "Cannot parse {} as i128 when building decimal",
+                        str
+                    )))
+                })?;
+                Ok(Expr::Literal(ScalarValue::Decimal128(
+                    Some(n),
+                    p as u8,
+                    s as u8,
+                )))
+            } else {
+                let number = n.parse::<i128>().map_err(|_| {
+                    DataFusionError::from(ParserError(format!(
+                        "Cannot parse {} as i128 when building decimal",
+                        n
+                    )))
+                })?;
+                Ok(Expr::Literal(ScalarValue::Decimal128(Some(number), 38, 0)))
+            }
+        } else {
+            n.parse::<f64>().map(lit).map_err(|_| {
+                DataFusionError::from(ParserError(format!("Cannot parse {} as f64", n)))
+            })
+        }
+    }
+
     fn convert_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> {
         match sql_type {
             SQLDataType::Array(inner_sql_type) => {
@@ -2919,21 +2979,6 @@ fn extract_possible_join_keys(
     }
 }
 
-// Parse number in sql string, convert to Expr::Literal
-fn parse_sql_number(n: &str) -> Result<Expr> {
-    // parse first as i64
-    n.parse::<i64>()
-        .map(lit)
-        // if parsing as i64 fails try f64
-        .or_else(|_| n.parse::<f64>().map(lit))
-        .map_err(|_| {
-            DataFusionError::from(ParserError(format!(
-                "Cannot parse {} as i64 or f64",
-                n
-            )))
-        })
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -2941,6 +2986,33 @@ mod tests {
     use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
     use std::any::Any;
 
+    #[test]
+    fn parse_decimals() {
+        let test_data = [
+            ("1", "Int64(1)"),
+            ("001", "Int64(1)"),
+            ("0.1", "Decimal128(Some(1),1,1)"),
+            ("0.01", "Decimal128(Some(1),2,2)"),
+            ("1.0", "Decimal128(Some(10),2,1)"),
+            ("10.01", "Decimal128(Some(1001),4,2)"),
+            (
+                "10000000000000000000.00",
+                "Decimal128(Some(1000000000000000000000),22,2)",
+            ),
+        ];
+        for (a, b) in test_data {
+            let sql = format!("SELECT {}", a);
+            let expected = format!("Projection: {}\n  EmptyRelation", b);
+            quick_test_with_options(
+                &sql,
+                &expected,
+                ParserOptions {
+                    parse_float_as_decimal: true,
+                },
+            );
+        }
+    }
+
     #[test]
     fn select_no_relation() {
         quick_test(
@@ -4913,8 +4985,15 @@ mod tests {
     }
 
     fn logical_plan(sql: &str) -> Result<LogicalPlan> {
+        logical_plan_with_options(sql, ParserOptions::default())
+    }
+
+    fn logical_plan_with_options(
+        sql: &str,
+        options: ParserOptions,
+    ) -> Result<LogicalPlan> {
         let dialect = &GenericDialect {};
-        logical_plan_with_dialect(sql, dialect)
+        logical_plan_with_dialect_and_options(sql, dialect, options)
     }
 
     fn logical_plan_with_dialect(
@@ -4927,12 +5006,28 @@ mod tests {
         planner.statement_to_plan(ast.pop_front().unwrap())
     }
 
+    fn logical_plan_with_dialect_and_options(
+        sql: &str,
+        dialect: &dyn Dialect,
+        options: ParserOptions,
+    ) -> Result<LogicalPlan> {
+        let planner = SqlToRel::new_with_options(&MockContextProvider {}, options);
+        let result = DFParser::parse_sql_with_dialect(sql, dialect);
+        let mut ast = result?;
+        planner.statement_to_plan(ast.pop_front().unwrap())
+    }
+
     /// Create logical plan, write with formatter, compare to expected output
     fn quick_test(sql: &str, expected: &str) {
         let plan = logical_plan(sql).unwrap();
         assert_eq!(format!("{:?}", plan), expected);
     }
 
+    fn quick_test_with_options(sql: &str, expected: &str, options: ParserOptions) {
+        let plan = logical_plan_with_options(sql, options).unwrap();
+        assert_eq!(format!("{:?}", plan), expected);
+    }
+
     struct MockContextProvider {}
 
     impl ContextProvider for MockContextProvider {