You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by "stuartcarnie (via GitHub)" <gi...@apache.org> on 2023/06/13 05:21:11 UTC

[GitHub] [arrow-datafusion] stuartcarnie commented on a diff in pull request #6617: RFC: User Defined Window Functions

stuartcarnie commented on code in PR #6617:
URL: https://github.com/apache/arrow-datafusion/pull/6617#discussion_r1227543327


##########
datafusion-examples/examples/simple_udwf.rs:
##########
@@ -0,0 +1,210 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::{
+    array::{AsArray, Float64Array},
+    datatypes::Float64Type,
+};
+use arrow_schema::DataType;
+use datafusion::datasource::file_format::options::CsvReadOptions;
+
+use datafusion::error::Result;
+use datafusion::prelude::*;
+use datafusion_common::DataFusionError;
+use datafusion_expr::{
+    partition_evaluator::PartitionEvaluator, Signature, Volatility, WindowUDF,
+};
+
+/// Creates a local execution context with the `cars.csv` test file registered
+/// as a table named `cars`.
+///
+/// The CSV path is relative, so it only resolves when the program is started
+/// from a directory that contains it; the working directory is printed first.
+///
+/// # Errors
+///
+/// Returns an error if the current working directory cannot be determined or
+/// if registering the CSV file fails.
+async fn create_context() -> Result<SessionContext> {
+    // declare a new context. In spark API, this corresponds to a new spark SQLsession
+    let ctx = SessionContext::new();
+
+    // show where we are running from; propagate the I/O error instead of panicking
+    println!("pwd: {}", std::env::current_dir()?.display());
+
+    // register the CSV file (read with a header row) as the table `cars`
+    let csv_path = "datafusion/core/tests/data/cars.csv";
+    let read_options = CsvReadOptions::default().has_header(true);
+
+    ctx.register_csv("cars", csv_path, read_options).await?;
+    Ok(ctx)
+}
+
+/// Demonstrates a user defined window function: registers `my_average` with the
+/// session context and then invokes it from SQL over the `cars` table.
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx = create_context().await?;
+
+    // make `my_average` callable from SQL by registering it with the context
+    ctx.register_udwf(my_average());
+
+    // first, print the raw contents of the `cars` table
+    ctx.sql("SELECT * from cars").await?.show().await?;
+
+    // Now call the new window function next to the built-in `lag`:
+    // `PARTITION BY car`: each distinct value of car (red and green) is treated separately
+    // `ORDER BY time`: within each group (red or green) the values are ordered by time
+    let window_query = "SELECT car, \
+                      speed, \
+                      lag(speed, 1) OVER (PARTITION BY car ORDER BY time),\
+                      my_average(speed) OVER (PARTITION BY car ORDER BY time),\
+                      time \
+                      from cars";
+    // run the query and print the results
+    ctx.sql(window_query).await?.show().await?;
+
+    // // ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING: Run the window function so that each
+    // // invocation only sees 5 rows (the current row, the 2 before and the 2 after)
+    // let df = ctx.sql("SELECT car, \
+    //                   speed, \
+    //                   lag(speed, 1) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING),\
+    //                   time \
+    //                   from cars").await?;
+    // // print the results
+    // df.show().await?;
+
+    // TODO: show how to run the same thing through the DataFrame API as well
+
+    Ok(())
+}
+
+// TODO make a helper function like `create_udf` that helps to make these signatures
+
+/// Builds the `my_average` user defined window function.
+///
+/// The function is wired to [`return_type`] for type resolution and to
+/// [`make_partition_evaluator`] for creating its evaluators.
+fn my_average() -> WindowUDF {
+    // exactly one argument (the column to aggregate) is accepted
+    // NOTE(review): the declared type is Int32 while the file imports Float64
+    // array helpers -- confirm this matches what the evaluator expects
+    let signature = Signature::exact(vec![DataType::Int32], Volatility::Immutable);
+
+    WindowUDF {
+        name: "my_average".to_owned(),
+        signature,
+        return_type: Arc::new(return_type),
+        partition_evaluator: Arc::new(make_partition_evaluator),
+    }
+}
+
+/// Computes the return type of `my_average` from its argument types.
+///
+/// The function takes exactly one argument and returns a value of that same type.
+///
+/// # Errors
+///
+/// Returns [`DataFusionError::Plan`] if `arg_types` does not contain exactly one type.
+fn return_type(arg_types: &[DataType]) -> Result<Arc<DataType>> {
+    match arg_types {
+        // the output has the same type as the single input
+        [arg_type] => Ok(Arc::new(arg_type.clone())),
+        // name the function as it is registered so the message matches the SQL
+        _ => Err(DataFusionError::Plan(format!(
+            "my_average expects 1 argument, got {}: {:?}",
+            arg_types.len(),
+            arg_types
+        ))),
+    }
+}
+
+/// Builds a new, boxed `MyPartitionEvaluator` for the window function.
+///
+/// This is the factory handed to `WindowUDF::partition_evaluator`.
+fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
+    let evaluator: Box<dyn PartitionEvaluator> = Box::new(MyPartitionEvaluator::new());
+    Ok(evaluator)
+}

Review Comment:
   Is it possible we could support passing scalar arguments when creating an instance of the function, similar to the built-in functions?
   
   For example, the `lag` function takes an optional scalar value for the second argument, which is the shift offset:
   
   https://github.com/apache/arrow-datafusion/blob/a42cc8d98b6e875c485e7e9b106d30803a32b00a/datafusion/core/src/physical_plan/windows/mod.rs#L148-L152



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org