You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/21 21:59:17 UTC

[arrow-datafusion] branch master updated: Support `NTILE` window function (#4676)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 1e9b4c949 Support `NTILE` window function (#4676)
1e9b4c949 is described below

commit 1e9b4c9493df97e2cd7083a71f28e2365191d67d
Author: Berkay Şahin <96...@users.noreply.github.com>
AuthorDate: Thu Dec 22 00:59:11 2022 +0300

    Support `NTILE` window function (#4676)
    
    * Ntile support for window functions
    
    Ntile window function is implemented. The expected data type as argument is changed to Int64 in the signature of Ntile. That needs to be converted to UInt64 later.
    
    * Minor changes
    
    * Better error handling
    
    * Ntile function parses its arguments correctly as UInt64
---
 datafusion/core/src/physical_plan/windows/mod.rs | 20 ++++--
 datafusion/core/tests/sql/window.rs              | 26 +++++++
 datafusion/expr/src/window_function.rs           |  4 +-
 datafusion/physical-expr/src/expressions/mod.rs  |  1 +
 datafusion/physical-expr/src/window/mod.rs       |  1 +
 datafusion/physical-expr/src/window/ntile.rs     | 86 ++++++++++++++++++++++++
 6 files changed, 128 insertions(+), 10 deletions(-)

diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs
index 76d39a199..473e7437e 100644
--- a/datafusion/core/src/physical_plan/windows/mod.rs
+++ b/datafusion/core/src/physical_plan/windows/mod.rs
@@ -21,7 +21,7 @@ use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{
     aggregates,
     expressions::{
-        cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue,
+        cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile,
         PhysicalSortExpr, RowNumber,
     },
     type_coercion::coerce,
@@ -107,6 +107,18 @@ fn create_built_in_window_expr(
         BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name)),
         BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name)),
         BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name)),
+        BuiltInWindowFunction::Ntile => {
+            let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
+            let n: i64 = get_scalar_value_from_args(&coerced_args, 0)?
+                .ok_or_else(|| {
+                    DataFusionError::Execution(
+                        "NTILE requires at least 1 argument".to_string(),
+                    )
+                })?
+                .try_into()?;
+            let n: u64 = n as u64;
+            Arc::new(Ntile::new(name, n))
+        }
         BuiltInWindowFunction::Lag => {
             let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
             let arg = coerced_args[0].clone();
@@ -155,12 +167,6 @@ fn create_built_in_window_expr(
             let data_type = args[0].data_type(input_schema)?;
             Arc::new(NthValue::last(name, arg, data_type))
         }
-        _ => {
-            return Err(DataFusionError::NotImplemented(format!(
-                "Window function with {:?} not yet implemented",
-                fun
-            )))
-        }
     })
 }
 
diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs
index dba71f223..c9ef64212 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -889,6 +889,32 @@ async fn window_frame_ranges_preceding_following() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn window_frame_ranges_ntile() -> Result<()> { // NTILE(n) under two different window orderings; only the first 5 rows by c7 are asserted
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    let sql = "SELECT \
+               NTILE(8) OVER (ORDER BY C4) as ntile1,\
+               NTILE(12) OVER (ORDER BY C12 DESC) as ntile2 \
+               FROM aggregate_test_100 \
+               ORDER BY c7 \
+               LIMIT 5";
+    let actual = execute_to_batches(&ctx, sql).await; // each row's bucket is computed over the whole table, then rows are re-sorted by c7
+    let expected = vec![ // NOTE(review): these values depend on how NTILE distributes rows across buckets; re-verify against PostgreSQL ntile() on the same data
+        "+--------+--------+",
+        "| ntile1 | ntile2 |",
+        "+--------+--------+",
+        "| 8      | 12     |",
+        "| 5      | 11     |",
+        "| 3      | 11     |",
+        "| 2      | 7      |",
+        "| 7      | 12     |",
+        "+--------+--------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
 #[tokio::test]
 async fn window_frame_ranges_string_check() -> Result<()> {
     let ctx = SessionContext::new();
diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs
index 038091ac7..86c63e376 100644
--- a/datafusion/expr/src/window_function.rs
+++ b/datafusion/expr/src/window_function.rs
@@ -211,9 +211,7 @@ pub fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature {
         BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => {
             Signature::any(1, Volatility::Immutable)
         }
-        BuiltInWindowFunction::Ntile => {
-            Signature::exact(vec![DataType::UInt64], Volatility::Immutable)
-        }
+        BuiltInWindowFunction::Ntile => Signature::any(1, Volatility::Immutable),
         BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable),
     }
 }
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index 8222fb664..fc91d91cf 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -66,6 +66,7 @@ pub use crate::window::cume_dist::CumeDist;
 pub use crate::window::lead_lag::WindowShift;
 pub use crate::window::lead_lag::{lag, lead};
 pub use crate::window::nth_value::{NthValue, NthValueKind};
+pub use crate::window::ntile::Ntile;
 pub use crate::window::rank::{dense_rank, percent_rank, rank};
 pub use crate::window::rank::{Rank, RankType};
 pub use crate::window::row_number::RowNumber;
diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs
index 40ed658ee..05c7e3d4a 100644
--- a/datafusion/physical-expr/src/window/mod.rs
+++ b/datafusion/physical-expr/src/window/mod.rs
@@ -21,6 +21,7 @@ mod built_in_window_function_expr;
 pub(crate) mod cume_dist;
 pub(crate) mod lead_lag;
 pub(crate) mod nth_value;
+pub(crate) mod ntile;
 pub(crate) mod partition_evaluator;
 pub(crate) mod rank;
 pub(crate) mod row_number;
diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs
new file mode 100644
index 000000000..ed00c3c86
--- /dev/null
+++ b/datafusion/physical-expr/src/window/ntile.rs
@@ -0,0 +1,125 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expression for `ntile` that can be evaluated
+//! at runtime during query execution
+
+use crate::window::partition_evaluator::PartitionEvaluator;
+use crate::window::BuiltInWindowFunctionExpr;
+use crate::PhysicalExpr;
+use arrow::array::{ArrayRef, UInt64Array};
+use arrow::datatypes::Field;
+use arrow_schema::DataType;
+use datafusion_common::{DataFusionError, Result};
+use std::any::Any;
+use std::ops::Range;
+use std::sync::Arc;
+
+/// The `NTILE(n)` window function: numbers the rows of each partition
+/// with a bucket id in `1..=n`, splitting the partition as evenly as
+/// possible.
+#[derive(Debug)]
+pub struct Ntile {
+    /// Output column name.
+    name: String,
+    /// Requested number of buckets.
+    n: u64,
+}
+
+impl Ntile {
+    /// Creates an `NTILE` expression named `name` with `n` buckets.
+    ///
+    /// `n` is not validated here; a zero `n` is rejected when a
+    /// partition is evaluated.
+    pub fn new(name: String, n: u64) -> Self {
+        Self { name, n }
+    }
+}
+
+impl BuiltInWindowFunctionExpr for Ntile {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// The output is a non-nullable `UInt64` bucket number.
+    fn field(&self) -> Result<Field> {
+        let nullable = false;
+        let data_type = DataType::UInt64;
+        Ok(Field::new(self.name(), data_type, nullable))
+    }
+
+    /// `NTILE` reads no input columns; `n` is fixed when the
+    /// expression is created.
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
+        Ok(Box::new(NtileEvaluator { n: self.n }))
+    }
+}
+
+/// Per-partition evaluator for [`Ntile`].
+pub(crate) struct NtileEvaluator {
+    /// Requested number of buckets.
+    n: u64,
+}
+
+impl PartitionEvaluator for NtileEvaluator {
+    /// Returns one bucket number (starting at 1) per row of `partition`.
+    ///
+    /// Every bucket gets `num_rows / n` rows and the first
+    /// `num_rows % n` buckets get one extra row, so bucket sizes
+    /// differ by at most one and the larger buckets come first. When
+    /// `n` exceeds the number of rows, rows are numbered
+    /// `1..=num_rows`.
+    ///
+    /// # Errors
+    /// Returns an execution error if `n` is zero.
+    fn evaluate_partition(
+        &self,
+        _values: &[ArrayRef],
+        partition: Range<usize>,
+    ) -> Result<ArrayRef> {
+        if self.n == 0 {
+            return Err(DataFusionError::Execution(
+                "NTILE requires a positive number of buckets".to_string(),
+            ));
+        }
+        let num_rows = partition.len() as u64;
+        let base_size = num_rows / self.n;
+        let remainder = num_rows % self.n;
+        // Rows covered by the first `remainder` buckets, each holding
+        // `base_size + 1` rows. This is at most `num_rows`, so it
+        // cannot overflow.
+        let large_rows = remainder * (base_size + 1);
+        let buckets = (0..num_rows).map(|i| {
+            if i < large_rows {
+                i / (base_size + 1) + 1
+            } else {
+                // Only reachable when `base_size > 0`: if `base_size`
+                // is 0 then `remainder == num_rows == large_rows`.
+                remainder + (i - large_rows) / base_size + 1
+            }
+        });
+        Ok(Arc::new(UInt64Array::from_iter_values(buckets)))
+    }
+}