You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/10/12 00:09:46 UTC

[GitHub] [arrow-datafusion] liukun4515 commented on a diff in pull request #3768: move type coercion of agg and agg_udaf to logical phase

liukun4515 commented on code in PR #3768:
URL: https://github.com/apache/arrow-datafusion/pull/3768#discussion_r992870967


##########
datafusion/optimizer/src/type_coercion.rs:
##########
@@ -596,6 +659,123 @@ mod test {
         Ok(())
     }
 
+    /// Verifies that the `TypeCoercion` optimizer rule coerces the arguments of a
+    /// user-defined aggregate (`Expr::AggregateUDF`).
+    ///
+    /// `MY_AVG` is built with `create_udaf` using `Float64` types and an
+    /// `AvgAccumulator`. It is called with an `Int64` literal, and the optimized
+    /// plan is expected to wrap that argument in `CAST(... AS Float64)`.
+    #[test]
+    fn agg_udaf() -> Result<()> {
+        let empty = empty();
+        let my_avg = create_udaf(
+            "MY_AVG",
+            DataType::Float64,
+            Arc::new(DataType::Float64),
+            Volatility::Immutable,
+            Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
+            Arc::new(vec![DataType::UInt64, DataType::Float64]),
+        );
+        let udaf = Expr::AggregateUDF {
+            fun: Arc::new(my_avg),
+            // Int64 argument: not Float64, so the rule must insert a cast.
+            args: vec![lit(10i64)],
+            filter: None,
+        };
+        let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty, None)?);
+        let rule = TypeCoercion::new();
+        let mut config = OptimizerConfig::default();
+        let plan = rule.optimize(&plan, &mut config)?;
+        // The plan's Debug rendering shows the inserted cast on the UDAF argument.
+        assert_eq!(
+            "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n  EmptyRelation",
+            &format!("{:?}", plan)
+        );
+        Ok(())
+    }
+
+    /// Verifies that the `TypeCoercion` rule rejects a user-defined aggregate call
+    /// whose argument cannot be coerced to the UDAF's signature.
+    ///
+    /// `MY_AVG` is built directly with `AggregateUDF::new` and
+    /// `Signature::uniform(1, [Float64])`. Calling it with a `Utf8` literal must
+    /// make `optimize` return a `Plan` error describing the failed coercion.
+    #[test]
+    fn agg_udaf_invalid_input() -> Result<()> {
+        let empty = empty();
+        // NOTE(review): these closures capture nothing, so `move` is redundant.
+        let return_type: ReturnTypeFunction =
+            Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
+        let state_type: StateTypeFunction =
+            Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64])));
+        let accumulator: AccumulatorFunctionImplementation =
+            Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?)));
+        let my_avg = AggregateUDF::new(
+            "MY_AVG",
+            &Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
+            &return_type,
+            &accumulator,
+            &state_type,
+        );
+        let udaf = Expr::AggregateUDF {
+            fun: Arc::new(my_avg),
+            // Utf8 argument: the test expects no coercion path to Float64.
+            args: vec![lit("10")],
+            filter: None,
+        };
+        let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty, None)?);
+        let rule = TypeCoercion::new();
+        let mut config = OptimizerConfig::default();
+        // Keep the Result (no `?`) so the error itself can be inspected below.
+        let plan = rule.optimize(&plan, &mut config);
+        assert!(plan.is_err());
+        assert_eq!(
+            "Plan(\"Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed.\")",
+            &format!("{:?}", plan.err().unwrap())
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn agg_function_case() -> Result<()> {
+        let empty = empty();
+        let fun: AggregateFunction = AggregateFunction::Avg;
+        let agg_expr = Expr::AggregateFunction {
+            fun,
+            args: vec![lit(12i64)],
+            distinct: false,
+            filter: None,
+        };
+        let plan =
+            LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty, None)?);
+        let rule = TypeCoercion::new();
+        let mut config = OptimizerConfig::default();
+        let plan = rule.optimize(&plan, &mut config)?;
+        assert_eq!(
+            "Projection: AVG(Int64(12))\n  EmptyRelation",

Review Comment:
   Yes, you're not missing anything.
   
   Take a look at the `type_coercion::aggregates::coerce_types` function. It only checks the input data types and does not perform any coercion for this function, which is why `AVG(Int64(12))` is left without a cast.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org