You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/04/06 15:48:08 UTC

[GitHub] [arrow-datafusion] Ted-Jiang commented on a diff in pull request #2130: Serialize scalar UDFs in physical plan

Ted-Jiang commented on code in PR #2130:
URL: https://github.com/apache/arrow-datafusion/pull/2130#discussion_r844106903


##########
ballista/rust/core/src/serde/physical_plan/from_proto.rs:
##########
@@ -54,125 +56,174 @@ impl From<&protobuf::PhysicalColumn> for Column {
     }
 }
 
-impl TryFrom<&protobuf::PhysicalExprNode> for Arc<dyn PhysicalExpr> {
-    type Error = BallistaError;
+pub(crate) fn parse_physical_expr(
+    proto: &protobuf::PhysicalExprNode,
+    registry: &dyn FunctionRegistry,
+) -> Result<Arc<dyn PhysicalExpr>, BallistaError> {
+    let expr_type = proto
+        .expr_type
+        .as_ref()
+        .ok_or_else(|| proto_error("Unexpected empty physical expression"))?;
+
+    let pexpr: Arc<dyn PhysicalExpr> = match expr_type {
+        ExprType::Column(c) => {
+            let pcol: Column = c.into();
+            Arc::new(pcol)
+        }
+        ExprType::Literal(scalar) => {
+            Arc::new(Literal::new(convert_required!(scalar.value)?))
+        }
+        ExprType::BinaryExpr(binary_expr) => Arc::new(BinaryExpr::new(
+            parse_required_physical_box_expr(&binary_expr.l, registry, "left")?,
+            from_proto_binary_op(&binary_expr.op)?,
+            parse_required_physical_box_expr(&binary_expr.r, registry, "right")?,
+        )),
+        ExprType::AggregateExpr(_) => {
+            return Err(BallistaError::General(
+                "Cannot convert aggregate expr node to physical expression".to_owned(),
+            ));
+        }
+        ExprType::WindowExpr(_) => {
+            return Err(BallistaError::General(
+                "Cannot convert window expr node to physical expression".to_owned(),
+            ));
+        }
+        ExprType::Sort(_) => {
+            return Err(BallistaError::General(
+                "Cannot convert sort expr node to physical expression".to_owned(),
+            ));
+        }
+        ExprType::IsNullExpr(e) => Arc::new(IsNullExpr::new(
+            parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+        )),
+        ExprType::IsNotNullExpr(e) => Arc::new(IsNotNullExpr::new(
+            parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+        )),
+        ExprType::NotExpr(e) => Arc::new(NotExpr::new(parse_required_physical_box_expr(
+            &e.expr, registry, "expr",
+        )?)),
+        ExprType::Negative(e) => Arc::new(NegativeExpr::new(
+            parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+        )),
+        ExprType::InList(e) => Arc::new(InListExpr::new(
+            parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+            e.list
+                .iter()
+                .map(|x| parse_physical_expr(x, registry))
+                .collect::<Result<Vec<_>, _>>()?,
+            e.negated,
+        )),
+        ExprType::Case(e) => Arc::new(CaseExpr::try_new(
+            e.expr
+                .as_ref()
+                .map(|e| parse_physical_expr(e.as_ref(), registry))
+                .transpose()?,
+            e.when_then_expr
+                .iter()
+                .map(|e| {
+                    Ok((
+                        parse_required_physical_expr(
+                            &e.when_expr,
+                            registry,
+                            "when_expr",
+                        )?,
+                        parse_required_physical_expr(
+                            &e.then_expr,
+                            registry,
+                            "then_expr",
+                        )?,
+                    ))
+                })
+                .collect::<Result<Vec<_>, BallistaError>>()?
+                .as_slice(),
+            e.else_expr
+                .as_ref()
+                .map(|e| parse_physical_expr(e.as_ref(), registry))
+                .transpose()?,
+        )?),
+        ExprType::Cast(e) => Arc::new(CastExpr::new(
+            parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+            convert_required!(e.arrow_type)?,
+            DEFAULT_DATAFUSION_CAST_OPTIONS,
+        )),
+        ExprType::TryCast(e) => Arc::new(TryCastExpr::new(
+            parse_required_physical_box_expr(&e.expr, registry, "expr")?,
+            convert_required!(e.arrow_type)?,
+        )),
+        ExprType::ScalarFunction(e) => {
+            let scalar_function = datafusion_proto::protobuf::ScalarFunction::from_i32(
+                e.fun,
+            )
+            .ok_or_else(|| {
+                proto_error(format!("Received an unknown scalar function: {}", e.fun,))
+            })?;
+
+            let args = e
+                .args
+                .iter()
+                .map(|x| parse_physical_expr(x, registry))
+                .collect::<Result<Vec<_>, _>>()?;
+
+            // TODO Do not create a new ExecutionProps
+            let execution_props = ExecutionProps::new();
+
+            let fun_expr = functions::create_physical_fun(
+                &(&scalar_function).into(),
+                &execution_props,
+            )?;
+
+            Arc::new(ScalarFunctionExpr::new(
+                &e.name,
+                fun_expr,
+                args,
+                &convert_required!(e.return_type)?,
+            ))
+        }
+        ExprType::ScalarUdf(e) => {
+            let scalar_fun = registry.udf(e.name.as_str())?.deref().clone().fun;

Review Comment:
   👍



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org