You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/08/09 19:38:59 UTC

[arrow-datafusion] branch master updated: Add cast function for creating cast expression (#3084)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 01202d60e Add cast function for creating cast expression (#3084)
01202d60e is described below

commit 01202d60ea7dde68fe97c28b08af648c9ac32c67
Author: gorkem <yu...@gmail.com>
AuthorDate: Tue Aug 9 12:38:52 2022 -0700

    Add cast function for creating cast expression (#3084)
    
    * Add cast function for creating cast expression
    
    * Add cast function to prelude
---
 datafusion/core/src/dataframe.rs             | 25 ++++++++++++++++++++++++-
 datafusion/core/src/logical_plan/mod.rs      |  8 ++++----
 datafusion/core/src/prelude.rs               | 10 +++++-----
 datafusion/core/tests/dataframe_functions.rs | 21 ++++++++++++++++++++-
 datafusion/expr/src/expr_fn.rs               |  8 ++++++++
 5 files changed, 61 insertions(+), 11 deletions(-)

diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 46ea63882..66a5671f3 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -786,7 +786,7 @@ mod tests {
     use crate::{assert_batches_sorted_eq, execution::context::SessionContext};
     use crate::{logical_plan::*, test_util};
     use arrow::datatypes::DataType;
-    use datafusion_expr::Volatility;
+    use datafusion_expr::{cast, Volatility};
     use datafusion_expr::{
         BuiltInWindowFunction, ScalarFunctionImplementation, WindowFunction,
     };
@@ -1267,6 +1267,29 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn cast_expr_test() -> Result<()> {
+        let df = test_table()
+            .await?
+            .select_columns(&["c2", "c3"])?
+            .limit(None, Some(1))?
+            .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;
+        // Execute the plan and check the single expected row: `sum` is `c2 + c3` cast to Int64.
+        let df_results = df.collect().await?;
+        assert_batches_sorted_eq!(
+            vec![
+                "+----+----+-----+",
+                "| c2 | c3 | sum |",
+                "+----+----+-----+",
+                "| 2  | 1  | 3   |",
+                "+----+----+-----+",
+            ],
+            &df_results
+        );
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn row_writer_resize_test() -> Result<()> {
         let schema = Arc::new(Schema::new(vec![arrow::datatypes::Field::new(
diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs
index 9b3919837..87a02ae01 100644
--- a/datafusion/core/src/logical_plan/mod.rs
+++ b/datafusion/core/src/logical_plan/mod.rs
@@ -28,10 +28,10 @@ pub use datafusion_common::{
 };
 pub use datafusion_expr::{
     abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan,
-    atan2, avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, coalesce,
-    col, combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos, count,
-    count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exists, exp,
-    expr_rewriter,
+    atan2, avg, bit_length, btrim, call_fn, case, cast, ceil, character_length, chr,
+    coalesce, col, combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos,
+    count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest,
+    exists, exp, expr_rewriter,
     expr_rewriter::{
         normalize_col, normalize_col_with_schemas, normalize_cols, replace_col,
         rewrite_sort_cols_by_aggs, unnormalize_col, unnormalize_cols, ExprRewritable,
diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs
index eae96ce48..edae225d8 100644
--- a/datafusion/core/src/prelude.rs
+++ b/datafusion/core/src/prelude.rs
@@ -31,11 +31,11 @@ pub use crate::execution::options::{
     AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions,
 };
 pub use crate::logical_plan::{
-    approx_percentile_cont, array, ascii, avg, bit_length, btrim, character_length, chr,
-    coalesce, col, concat, concat_ws, count, create_udf, date_part, date_trunc, digest,
-    exists, from_unixtime, in_list, in_subquery, initcap, left, length, lit, lower, lpad,
-    ltrim, max, md5, min, not_exists, not_in_subquery, now, octet_length, random,
-    regexp_match, regexp_replace, repeat, replace, reverse, right, rpad, rtrim,
+    approx_percentile_cont, array, ascii, avg, bit_length, btrim, cast, character_length,
+    chr, coalesce, col, concat, concat_ws, count, create_udf, date_part, date_trunc,
+    digest, exists, from_unixtime, in_list, in_subquery, initcap, left, length, lit,
+    lower, lpad, ltrim, max, md5, min, not_exists, not_in_subquery, now, octet_length,
+    random, regexp_match, regexp_replace, repeat, replace, reverse, right, rpad, rtrim,
     scalar_subquery, sha224, sha256, sha384, sha512, split_part, starts_with, strpos,
     substr, sum, to_hex, translate, trim, upper, Column, Expr, JoinType, Partitioning,
 };
diff --git a/datafusion/core/tests/dataframe_functions.rs b/datafusion/core/tests/dataframe_functions.rs
index cefdaa777..19694285c 100644
--- a/datafusion/core/tests/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe_functions.rs
@@ -33,7 +33,7 @@ use datafusion::prelude::*;
 use datafusion::execution::context::SessionContext;
 
 use datafusion::assert_batches_eq;
-use datafusion_expr::approx_median;
+use datafusion_expr::{approx_median, cast};
 
 fn create_test_table() -> Result<Arc<DataFrame>> {
     let schema = Arc::new(Schema::new(vec![
@@ -663,6 +663,25 @@ async fn test_fn_substr() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn test_cast() -> Result<()> {
+    let expr = cast(col("b"), DataType::Float64);
+    let expected = vec![
+        "+-------------------------+",
+        "| CAST(test.b AS Float64) |",
+        "+-------------------------+",
+        "| 1                       |",
+        "| 10                      |",
+        "| 10                      |",
+        "| 100                     |",
+        "+-------------------------+",
+    ];
+    // Hand `expr` and the expected table to `assert_fn_batches!` (macro defined outside this hunk).
+    assert_fn_batches!(expr, expected);
+
+    Ok(())
+}
+
 #[tokio::test]
 async fn test_fn_to_hex() -> Result<()> {
     let expr = to_hex(col("b"));
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 171bd70a9..75abe44f9 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -251,6 +251,14 @@ pub fn rollup(exprs: Vec<Expr>) -> Expr {
     Expr::GroupingSet(GroupingSet::Rollup(exprs))
 }
 
+/// Create an [`Expr::Cast`] expression that casts `expr` to `data_type`.
+pub fn cast(expr: Expr, data_type: DataType) -> Expr {
+    Expr::Cast {
+        expr: Box::new(expr),
+        data_type,
+    }
+}
+
 /// Create an convenience function representing a unary scalar function
 macro_rules! unary_scalar_expr {
     ($ENUM:ident, $FUNC:ident, $DOC:expr) => {