You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2020/09/07 11:04:14 UTC

[GitHub] [arrow] alamb commented on a change in pull request #7967: ARROW-9751: [Rust] [DataFusion] Allow UDFs to accept multiple data types per argument

alamb commented on a change in pull request #7967:
URL: https://github.com/apache/arrow/pull/7967#discussion_r484350796



##########
File path: rust/datafusion/examples/simple_udf.rs
##########
@@ -0,0 +1,138 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::{
+    array::{Array, ArrayRef, Float32Array, Float64Array, Float64Builder},
+    datatypes::DataType,
+    record_batch::RecordBatch,
+    util::pretty,
+};
+
+use datafusion::error::Result;
+use datafusion::{physical_plan::functions::ScalarFunctionImplementation, prelude::*};
+use std::sync::Arc;
+
+// create local execution context with an in-memory table
+fn create_context() -> Result<ExecutionContext> {
+    use arrow::datatypes::{Field, Schema};
+    use datafusion::datasource::MemTable;
+    // define a schema.
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Float32, false),
+        Field::new("b", DataType::Float64, false),
+    ]));
+
+    // define data.
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])),
+            Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
+        ],
+    )?;
+
+    // declare a new context. In spark API, this corresponds to a new spark SQLsession
+    let mut ctx = ExecutionContext::new();
+
+    // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
+    let provider = MemTable::new(schema, vec![vec![batch]])?;
+    ctx.register_table("t", Box::new(provider));
+    Ok(ctx)
+}
+
+/// In this example we will declare a single-type, single return type UDF that exponentiates f64, a^b
+fn main() -> Result<()> {
+    let mut ctx = create_context()?;
+
+    // First, declare the actual implementation of the calculation
+    let pow: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| {
+        // in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to:
+        // 1. cast the values to the type we want
+        // 2. perform the computation for every element in the array (using a loop or SIMD)
+        // 3. construct the resulting array
+
+        // this is guaranteed by DataFusion based on the function's signature.
+        assert_eq!(args.len(), 2);
+
+        // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics!
+        let base = &args[0]
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .expect("cast failed");
+        let exponent = &args[1]
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .expect("cast failed");
+
+        // this is guaranteed by DataFusion. We place it just to make it obvious.
+        assert_eq!(exponent.len(), base.len());
+
+        // 2. Arrow's builder is used to construct an Arrow array.
+        let mut builder = Float64Builder::new(base.len());
+        for index in 0..base.len() {
+            // in arrow, any value can be null.
+            // Here we decide to make our UDF to return null when either base or exponent is null.
+            if base.is_null(index) || exponent.is_null(index) {
+                builder.append_null()?;
+            } else {
+                // 3. computation. Since we do not have any SIMD `pow` operation at our hands,
+                // we loop over each entry. Array's values are obtained via `.value(index)`.
+                let value = base.value(index).powf(exponent.value(index));
+                builder.append_value(value)?;
+            }
+        }
+        Ok(Arc::new(builder.finish()))
+    });
+
+    // Next:
+    // * git it a name (so that it shows nicely when the plan is printed)

Review comment:
       ```suggestion
       // * give it a name so that it shows nicely when the plan is printed
       //   and `pow` can be used in expressions
   ```

##########
File path: rust/datafusion/examples/simple_udf.rs
##########
@@ -0,0 +1,138 @@
+// Licensed to the Apache Software Foundation (ASF) under one

Review comment:
       This example is really nice -- and the comments throughout make it easy for me to follow

##########
File path: rust/datafusion/src/execution/dataframe_impl.rs
##########
@@ -232,6 +241,50 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn registry() -> Result<()> {
+        let mut ctx = ExecutionContext::new();
+        register_aggregate_csv(&mut ctx)?;
+
+        // declare the udf
+        let my_add: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| {

Review comment:
       This doesn't seem right -- the implementation uses two arguments, but the call to `create_udf` only registers a single float64 arg. Given that the `my_add` function never actually gets executed, this doesn't cause a problem in this test.
   
   I suggest changing the body of `my_add` to be `unimplemented!("my_add is not implemented")` to make it clear the code is not executed during this test.

##########
File path: rust/datafusion/src/physical_plan/functions.rs
##########
@@ -232,6 +246,80 @@ fn signature(fun: &ScalarFunction) -> Signature {
     }
 }
 
+/// Physical expression of a scalar function
+pub struct ScalarFunctionExpr {
+    fun: ScalarFunctionImplementation,
+    name: String,
+    args: Vec<Arc<dyn PhysicalExpr>>,
+    return_type: DataType,
+}
+
+impl Debug for ScalarFunctionExpr {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ScalarFunctionExpr")
+            .field("fun", &"<FUNC>")
+            .field("name", &self.name)
+            .field("args", &self.args)
+            .field("return_type", &self.return_type)
+            .finish()
+    }
+}
+
+impl ScalarFunctionExpr {
+    /// Create a new Scalar function
+    pub fn new(
+        name: &str,
+        fun: ScalarFunctionImplementation,
+        args: Vec<Arc<dyn PhysicalExpr>>,
+        return_type: &DataType,
+    ) -> Self {
+        Self {
+            fun,
+            name: name.to_owned(),
+            args,
+            return_type: return_type.clone(),
+        }
+    }
+}
+
+impl fmt::Display for ScalarFunctionExpr {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "{}({})",
+            self.name,
+            self.args
+                .iter()
+                .map(|e| format!("{}", e))
+                .collect::<Vec<String>>()
+                .join(", ")
+        )
+    }
+}
+
+impl PhysicalExpr for ScalarFunctionExpr {
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(self.return_type.clone())
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {

Review comment:
       Allowing the definition of a user-defined function to specify whether its result is nullable is probably something we should do in a future PR.

##########
File path: rust/datafusion/src/physical_plan/planner.rs
##########
@@ -368,32 +368,22 @@ impl DefaultPhysicalPlanner {
                     .collect::<Result<Vec<_>>>()?;
                 functions::create_physical_expr(fun, &physical_args, input_schema)
             }
-            Expr::ScalarUDF {
-                name,
-                args,
-                return_type,
-            } => match ctx_state.scalar_functions.get(name) {
-                Some(f) => {
-                    let mut physical_args = vec![];
-                    for e in args {
-                        physical_args.push(self.create_physical_expr(
-                            e,
-                            input_schema,
-                            ctx_state,
-                        )?);
-                    }
-                    Ok(Arc::new(ScalarFunctionExpr::new(
-                        name,
-                        f.fun.clone(),
-                        physical_args,
-                        return_type,
-                    )))
+            Expr::ScalarUDF { fun, args } => {
+                let mut physical_args = vec![];
+                for e in args {
+                    physical_args.push(self.create_physical_expr(
+                        e,
+                        input_schema,
+                        ctx_state,
+                    )?);
                 }
-                _ => Err(ExecutionError::General(format!(
-                    "Invalid scalar function '{:?}'",
-                    name
-                ))),
-            },
+
+                udf::create_physical_expr(

Review comment:
       And an error will occur here if the inputs can't be coerced to the types required by the UDF's signature.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org