Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/07/11 15:48:59 UTC

[GitHub] [arrow] pitrou commented on a diff in pull request #13500: ARROW 16968: [C++] Expand Python-UDF support to Arrow Substrait

pitrou commented on code in PR #13500:
URL: https://github.com/apache/arrow/pull/13500#discussion_r918045167


##########
cpp/src/arrow/compute/registry_util.cc:
##########
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/registry_util.h"
+
+namespace arrow {
+namespace compute {
+
+std::unique_ptr<FunctionRegistry> MakeFunctionRegistry() {
+  return FunctionRegistry::Make(GetFunctionRegistry());

Review Comment:
   What is the point of this mostly trivial function? Why not let the user call `FunctionRegistry::Make` directly?



##########
cpp/src/arrow/compute/exec/options.h:
##########
@@ -438,7 +450,5 @@ class ARROW_EXPORT TableSinkNodeOptions : public ExecNodeOptions {
   std::shared_ptr<Table>* output_table;
 };
 
-/// @}

Review Comment:
   You shouldn't remove this; it matches the opening brace in `\addtogroup execnode-options` above.
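
   For reference, a minimal sketch of how Doxygen's grouping markers pair up. The group name and struct below are hypothetical, not taken from `options.h`: the `@{` following `\addtogroup` opens the group and the trailing `@}` closes it, so dropping either one leaves the group unbalanced.

   ```cpp
   /// \addtogroup example-options
   /// @{

   /// Hypothetical options struct documented as a member of the group.
   struct ExampleOptions {
     int batch_size = 0;
   };

   /// @}
   ```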



##########
cpp/src/arrow/engine/substrait/serde.h:
##########
@@ -115,6 +115,22 @@ ARROW_ENGINE_EXPORT Result<compute::ExecPlan> DeserializePlan(
     const Buffer& buf, const std::shared_ptr<dataset::WriteNodeOptions>& write_options,
     const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR);
 
+/// Factory function type for generating the write options of a node consuming the batches
+/// produced by each toplevel Substrait relation when deserializing a Substrait Plan.
+using WriteOptionsFactory = std::function<std::shared_ptr<dataset::WriteNodeOptions>()>;

Review Comment:
   It seems this is a duplicate declaration.



##########
python/pyarrow/_exec_plan.pxd:
##########
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+
+cdef is_supported_execplan_output_type(output_type)
+
+cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads=*, CFunctionRegistry* c_func_registry=*)

Review Comment:
   `ExecPlan` is already the name of a (C++, Python) class, can we name this function something else? Perhaps `execute_plan`?



##########
python/pyarrow/_exec_plan.pyx:
##########
@@ -214,6 +220,9 @@ def _perform_join(join_type, left_operand not None, left_keys,
         vector[c_string] c_projected_col_names
         CJoinType c_join_type
 
+    if not is_supported_execplan_output_type(output_type):
+        raise TypeError(f"Unsupported output type {output_type}")

Review Comment:
   Is it useful to do this here as it's also done in `execplan`?



##########
python/pyarrow/_substrait.pyx:
##########
@@ -15,35 +15,183 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import base64
+import cloudpickle
+import inspect
+
 # cython: language_level = 3
-from cython.operator cimport dereference as deref
+from cython.operator cimport dereference as deref, preincrement as inc
 
+from pyarrow import compute as pc
 from pyarrow import Buffer
+from pyarrow.lib import frombytes, tobytes
 from pyarrow.lib cimport *
 from pyarrow.includes.libarrow cimport *
 from pyarrow.includes.libarrow_substrait cimport *
+from pyarrow._compute cimport FunctionRegistry
+
+
+from pyarrow._exec_plan cimport is_supported_execplan_output_type, execplan
+from pyarrow._compute import make_function_registry
+
+
+def make_extension_id_registry():
+    cdef:
+        shared_ptr[CExtensionIdRegistry] c_extid_registry
+        ExtensionIdRegistry registry
+
+    with nogil:
+        c_extid_registry = MakeExtensionIdRegistry()
+
+    return pyarrow_wrap_extension_id_registry(c_extid_registry)
+
+
+def _get_udf_code(func):
+    return frombytes(base64.b64encode(cloudpickle.dumps(func)))

Review Comment:
   You can instead cimport `pickle` from `pyarrow.compat`, which will use cloudpickle if available, otherwise regular pickle.



##########
python/pyarrow/lib.pyx:
##########
@@ -169,6 +169,9 @@ include "builder.pxi"
 # Column, Table, Record Batch
 include "table.pxi"
 
+# Compute registries
+include "compute.pxi"

Review Comment:
   For the record, why is this addition necessary? Can Substrait instead directly import these declarations?



##########
python/pyarrow/_substrait.pyx:
##########
@@ -15,35 +15,183 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import base64
+import cloudpickle
+import inspect
+
 # cython: language_level = 3
-from cython.operator cimport dereference as deref
+from cython.operator cimport dereference as deref, preincrement as inc
 
+from pyarrow import compute as pc
 from pyarrow import Buffer
+from pyarrow.lib import frombytes, tobytes
 from pyarrow.lib cimport *
 from pyarrow.includes.libarrow cimport *
 from pyarrow.includes.libarrow_substrait cimport *
+from pyarrow._compute cimport FunctionRegistry
+
+
+from pyarrow._exec_plan cimport is_supported_execplan_output_type, execplan
+from pyarrow._compute import make_function_registry
+
+
+def make_extension_id_registry():
+    cdef:
+        shared_ptr[CExtensionIdRegistry] c_extid_registry
+        ExtensionIdRegistry registry
+
+    with nogil:
+        c_extid_registry = MakeExtensionIdRegistry()
+
+    return pyarrow_wrap_extension_id_registry(c_extid_registry)
+
+
+def _get_udf_code(func):
+    return frombytes(base64.b64encode(cloudpickle.dumps(func)))
+
+
+def get_udf_declarations(plan, extid_registry):
+    cdef:
+        shared_ptr[CBuffer] c_buf_plan
+        shared_ptr[CExtensionIdRegistry] c_extid_registry
+        vector[CUdfDeclaration] c_decls
+        vector[CUdfDeclaration].iterator c_decls_iter
+        vector[pair[shared_ptr[CDataType], c_bool]].iterator c_in_types_iter
+
+    c_buf_plan = pyarrow_unwrap_buffer(plan)
+    c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry)
+    with nogil:
+        c_res_decls = DeserializePlanUdfs(
+            deref(c_buf_plan), c_extid_registry.get())
+    c_decls = GetResultValue(c_res_decls)
+
+    decls = []
+    c_decls_iter = c_decls.begin()
+    while c_decls_iter != c_decls.end():
+        input_types = []
+        c_in_types_iter = deref(c_decls_iter).input_types.begin()
+        while c_in_types_iter != deref(c_decls_iter).input_types.end():
+            input_types.append((pyarrow_wrap_data_type(deref(c_in_types_iter).first),
+                                deref(c_in_types_iter).second))
+            inc(c_in_types_iter)
+        decls.append({
+            "name": frombytes(deref(c_decls_iter).name),
+            "code": frombytes(deref(c_decls_iter).code),
+            "summary": frombytes(deref(c_decls_iter).summary),
+            "description": frombytes(deref(c_decls_iter).description),
+            "input_types": input_types,
+            "output_type": (pyarrow_wrap_data_type(deref(c_decls_iter).output_type.first),
+                            deref(c_decls_iter).output_type.second),
+        })
+        inc(c_decls_iter)
+    return decls
+
+
+def register_function(extid_registry, id_uri, id_name, arrow_function_name):
+    cdef:
+        c_string c_id_uri, c_id_name, c_arrow_function_name
+        shared_ptr[CExtensionIdRegistry] c_extid_registry
+        CStatus c_status
+
+    c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry)
+    c_id_uri = id_uri or default_extension_types_uri()
+    c_id_name = tobytes(id_name)
+    c_arrow_function_name = tobytes(arrow_function_name)
+
+    with nogil:
+        c_status = RegisterFunction(
+            deref(c_extid_registry), c_id_uri, c_id_name, c_arrow_function_name
+        )
+
+    check_status(c_status)
+
 
+def register_udf_declarations(plan, extid_registry, func_registry, udf_decls=None):
+    if udf_decls is None:
+        udf_decls = get_udf_declarations(plan, extid_registry)
+    for udf_decl in udf_decls:
+        udf_name = udf_decl["name"]
+        udf_func = cloudpickle.loads(
+            base64.b64decode(tobytes(udf_decl["code"])))
+        udf_arg_names = list(inspect.signature(udf_func).parameters.keys())
+        udf_arg_types = udf_decl["input_types"]
+        register_function(extid_registry, None, udf_name, udf_name)
+        def udf(ctx, *args):
+            return udf_func(*args)
 
-def run_query(plan):
+        pc.register_scalar_function(
+            udf,
+            udf_name,
+            {"summary": udf_decl["summary"],
+                "description": udf_decl["description"]},
+            # range start from 1 to skip over udf scalar context argument
+            {udf_arg_names[i]: udf_arg_types[i][0]
+                for i in range(0 ,len(udf_arg_types))},
+            udf_decl["output_type"][0],
+            func_registry,
+        )
+
+
+def run_query_as(plan, extid_registry, func_registry, output_type=RecordBatchReader):
+    if output_type == RecordBatchReader:
+        return run_query(plan, extid_registry, func_registry)
+    return _run_query(plan, extid_registry, func_registry, output_type)
+
+
+def _run_query(plan, extid_registry, func_registry, output_type):
+    cdef:
+        shared_ptr[CBuffer] c_buf_plan
+        shared_ptr[CExtensionIdRegistry] c_extid_registry
+        CFunctionRegistry* c_func_registry
+        CResult[vector[CDeclaration]] c_res_decls
+        vector[CDeclaration] c_decls
+
+    if not is_supported_execplan_output_type(output_type):
+        raise TypeError(f"Unsupported output type {output_type}")
+
+    c_buf_plan = pyarrow_unwrap_buffer(plan)
+    c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry)
+    c_func_registry = pyarrow_unwrap_function_registry(func_registry)
+    if c_func_registry == NULL:
+        c_func_registry = (<FunctionRegistry>func_registry).registry
+    with nogil:
+        c_res_decls = DeserializePlans(
+            deref(c_buf_plan), c_extid_registry.get())
+    c_decls = GetResultValue(c_res_decls)
+    return execplan([], output_type, c_decls, True, c_func_registry)
+
+
+def run_query(plan, extid_registry, func_registry):

Review Comment:
   Also, what about `extid_registry`? Are there cases where it could be omitted? @westonpace 



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -52,6 +51,69 @@ Status CheckRelCommon(const RelMessage& rel) {
   return Status::OK();
 }
 
+Result<FieldRef> FromProto(const substrait::Expression& expr, const std::string& what) {
+  int32_t index;
+  switch (expr.rex_type_case()) {
+    case substrait::Expression::RexTypeCase::kSelection: {
+      const auto& selection = expr.selection();
+      switch (selection.root_type_case()) {
+        case substrait::Expression_FieldReference::RootTypeCase::kRootReference: {
+          break;
+        }
+        default: {
+          return Status::NotImplemented(
+              std::string("substrait::Expression with non-root-reference for ") + what);
+        }
+      }
+      switch (selection.reference_type_case()) {
+        case substrait::Expression_FieldReference::ReferenceTypeCase::kDirectReference: {
+          const auto& direct_reference = selection.direct_reference();
+          switch (direct_reference.reference_type_case()) {
+            case substrait::Expression_ReferenceSegment::ReferenceTypeCase::
+                kStructField: {
+              break;
+            }
+            default: {
+              return Status::NotImplemented(
+                  std::string("substrait::Expression with non-struct-field for ") + what);
+            }
+          }
+          const auto& struct_field = direct_reference.struct_field();
+          if (struct_field.has_child()) {
+            return Status::NotImplemented(
+                std::string("substrait::Expression with non-flat struct-field for ") +
+                what);
+          }
+          index = struct_field.field();
+          break;
+        }
+        default: {
+          return Status::NotImplemented(
+              std::string("substrait::Expression with non-direct reference for ") + what);
+        }
+      }
+      break;
+    }
+    default: {
+      return Status::NotImplemented(
+          std::string("substrait::Expression with non-selection for ") + what);
+    }
+  }
+  return FieldRef(FieldPath({index}));
+}
+
+Result<std::vector<FieldRef>> FromProto(
+    const google::protobuf::RepeatedPtrField<substrait::Expression>& exprs,
+    const std::string& what) {
+  std::vector<FieldRef> fields;

Review Comment:
   May want to presize this?
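
   A minimal sketch of what presizing could look like, using `int` as a stand-in for `FieldRef` and a hypothetical `CollectDoubled` helper (neither is code from this PR). `reserve` allocates capacity once up front, so the `push_back` calls in the loop should not need to reallocate:

   ```cpp
   #include <vector>

   // Hypothetical stand-in for the conversion loop: transforms each input
   // element and appends it to a vector whose capacity was reserved up front.
   std::vector<int> CollectDoubled(const std::vector<int>& inputs) {
     std::vector<int> out;
     out.reserve(inputs.size());  // one allocation sized for all elements
     for (int value : inputs) {
       out.push_back(value * 2);
     }
     return out;
   }
   ```

   In the loop above, the equivalent would presumably be a `fields.reserve(exprs.size())` call before iterating.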



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -52,6 +51,69 @@ Status CheckRelCommon(const RelMessage& rel) {
   return Status::OK();
 }
 
+Result<FieldRef> FromProto(const substrait::Expression& expr, const std::string& what) {
+  int32_t index;
+  switch (expr.rex_type_case()) {
+    case substrait::Expression::RexTypeCase::kSelection: {
+      const auto& selection = expr.selection();
+      switch (selection.root_type_case()) {
+        case substrait::Expression_FieldReference::RootTypeCase::kRootReference: {
+          break;
+        }
+        default: {
+          return Status::NotImplemented(
+              std::string("substrait::Expression with non-root-reference for ") + what);
+        }
+      }
+      switch (selection.reference_type_case()) {
+        case substrait::Expression_FieldReference::ReferenceTypeCase::kDirectReference: {
+          const auto& direct_reference = selection.direct_reference();
+          switch (direct_reference.reference_type_case()) {
+            case substrait::Expression_ReferenceSegment::ReferenceTypeCase::
+                kStructField: {
+              break;
+            }
+            default: {
+              return Status::NotImplemented(
+                  std::string("substrait::Expression with non-struct-field for ") + what);
+            }
+          }
+          const auto& struct_field = direct_reference.struct_field();
+          if (struct_field.has_child()) {
+            return Status::NotImplemented(
+                std::string("substrait::Expression with non-flat struct-field for ") +
+                what);
+          }
+          index = struct_field.field();
+          break;
+        }
+        default: {
+          return Status::NotImplemented(
+              std::string("substrait::Expression with non-direct reference for ") + what);
+        }
+      }
+      break;
+    }
+    default: {
+      return Status::NotImplemented(
+          std::string("substrait::Expression with non-selection for ") + what);
+    }
+  }
+  return FieldRef(FieldPath({index}));
+}
+
+Result<std::vector<FieldRef>> FromProto(
+    const google::protobuf::RepeatedPtrField<substrait::Expression>& exprs,
+    const std::string& what) {
+  std::vector<FieldRef> fields;
+  int size = exprs.size();
+  for (int i = 0; i < size; i++) {
+    ARROW_ASSIGN_OR_RAISE(FieldRef field, FromProto(exprs[i], what));
+    fields.push_back(field);

Review Comment:
   ```suggestion
       fields.push_back(std::move(field));
   ```
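
   To illustrate the difference, here is a small sketch with a hypothetical `Tracked` type that counts its copies (a stand-in, not `FieldRef` itself). Pushing a named local copies it, while wrapping it in `std::move` selects the move constructor instead:

   ```cpp
   #include <string>
   #include <utility>
   #include <vector>

   // Hypothetical type that counts how often it is copy-constructed.
   struct Tracked {
     static int copies;
     std::string payload;
     explicit Tracked(std::string p) : payload(std::move(p)) {}
     Tracked(const Tracked& other) : payload(other.payload) { ++copies; }
     Tracked(Tracked&& other) noexcept = default;
   };
   int Tracked::copies = 0;

   // Appends a named local as-is: `item` is an lvalue, so it is copied.
   int CopiesWhenPushingLvalue() {
     Tracked::copies = 0;
     std::vector<Tracked> items;
     items.reserve(1);  // avoid reallocation so only the push itself counts
     Tracked item("field");
     items.push_back(item);
     return Tracked::copies;
   }

   // Appends through std::move: `item` is treated as an rvalue and moved.
   int CopiesWhenPushingMoved() {
     Tracked::copies = 0;
     std::vector<Tracked> items;
     items.reserve(1);
     Tracked item("field");
     items.push_back(std::move(item));
     return Tracked::copies;
   }
   ```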



##########
cpp/src/arrow/dataset/file_base.cc:
##########
@@ -331,6 +331,22 @@ class DatasetWritingSinkNodeConsumer : public compute::SinkNodeConsumer {
     return Status::OK();
   }
 
+  Status Init(compute::ExecNode* node) {
+    if (node == nullptr) {
+      return Status::Invalid("internal error - null node");
+    }
+    auto schema = node->inputs()[0]->output_schema();
+    if (schema.get() == nullptr) {
+      return Status::Invalid("internal error - null schema");
+    }
+    if (schema_.get() == nullptr) {
+      schema_ = schema;
+    } else if (schema_.get() != schema.get()) {

Review Comment:
   Is this really comparing the pointers by value? Don't you want to compare the underlying schemas instead?
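
   A small sketch of the distinction, using a hypothetical `MiniSchema` stand-in rather than `arrow::Schema`: two separately allocated objects with identical contents compare unequal by pointer but equal by value.

   ```cpp
   #include <memory>
   #include <string>
   #include <vector>

   // Hypothetical stand-in for a schema: just a list of field names.
   struct MiniSchema {
     std::vector<std::string> field_names;
     bool Equals(const MiniSchema& other) const {
       return field_names == other.field_names;
     }
   };

   // Identity comparison: true only if both pointers refer to the same object.
   bool SameObject(const std::shared_ptr<MiniSchema>& a,
                   const std::shared_ptr<MiniSchema>& b) {
     return a.get() == b.get();
   }

   // Value comparison: true if the pointed-to schemas have equal contents.
   bool SameContents(const std::shared_ptr<MiniSchema>& a,
                     const std::shared_ptr<MiniSchema>& b) {
     return a->Equals(*b);
   }
   ```

   For the real code, something along the lines of `schema_->Equals(*schema)` is probably what's wanted.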



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -109,6 +171,8 @@ Result<compute::Declaration> FromProto(const substrait::Rel& rel,
           path = item.uri_path_glob();
         }
 
+        util::string_view uri_file{item.uri_file()};

Review Comment:
   This seems unused.



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -16,7 +16,6 @@
 // under the License.
 
 #include "arrow/engine/substrait/relation_internal.h"
-

Review Comment:
   FTR, the convention would be to leave the blank line, as it separates the `.h` corresponding to this `.cc` from the other included headers.



##########
cpp/src/arrow/engine/substrait/serde.h:
##########
@@ -115,6 +115,22 @@ ARROW_ENGINE_EXPORT Result<compute::ExecPlan> DeserializePlan(
     const Buffer& buf, const std::shared_ptr<dataset::WriteNodeOptions>& write_options,
     const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR);
 
+/// Factory function type for generating the write options of a node consuming the batches
+/// produced by each toplevel Substrait relation when deserializing a Substrait Plan.
+using WriteOptionsFactory = std::function<std::shared_ptr<dataset::WriteNodeOptions>()>;
+
+struct ARROW_ENGINE_EXPORT UdfDeclaration {
+  std::string name;
+  std::string code;
+  std::string summary;
+  std::string description;
+  std::vector<std::pair<std::shared_ptr<DataType>, bool>> input_types;
+  std::pair<std::shared_ptr<DataType>, bool> output_type;
+};
+
+ARROW_ENGINE_EXPORT Result<std::vector<UdfDeclaration>> DeserializePlanUdfs(

Review Comment:
   Add a docstring?



##########
cpp/src/arrow/engine/substrait/serde.cc:
##########
@@ -204,6 +203,48 @@ Result<compute::ExecPlan> DeserializePlan(
   return MakeSingleDeclarationPlan(declarations);
 }
 
+Result<std::vector<UdfDeclaration>> DeserializePlanUdfs(
+    const Buffer& buf, const ExtensionIdRegistry* registry) {
+  ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer<substrait::Plan>(buf));
+
+  ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan, registry, true));
+
+  std::vector<UdfDeclaration> decls;
+  /*

Review Comment:
   So, is this code that needs to be debugged and then enabled?
   If this PR is not finished, could you mark it as draft?



##########
cpp/src/arrow/engine/substrait/serde.h:
##########
@@ -115,6 +115,22 @@ ARROW_ENGINE_EXPORT Result<compute::ExecPlan> DeserializePlan(
     const Buffer& buf, const std::shared_ptr<dataset::WriteNodeOptions>& write_options,
     const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR);
 
+/// Factory function type for generating the write options of a node consuming the batches
+/// produced by each toplevel Substrait relation when deserializing a Substrait Plan.
+using WriteOptionsFactory = std::function<std::shared_ptr<dataset::WriteNodeOptions>()>;
+
+struct ARROW_ENGINE_EXPORT UdfDeclaration {
+  std::string name;
+  std::string code;
+  std::string summary;
+  std::string description;
+  std::vector<std::pair<std::shared_ptr<DataType>, bool>> input_types;
+  std::pair<std::shared_ptr<DataType>, bool> output_type;

Review Comment:
   It's not obvious what the `bool` is for. Can you perhaps use a helper struct, e.g.:
   ```suggestion
     struct TypeDeclaration {
       std::shared_ptr<DataType> type;
       bool xxx_some_suitable_name;
     };
     std::vector<TypeDeclaration> input_types;
     TypeDeclaration output_type;
   ```



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -52,6 +51,69 @@ Status CheckRelCommon(const RelMessage& rel) {
   return Status::OK();
 }
 
+Result<FieldRef> FromProto(const substrait::Expression& expr, const std::string& what) {
+  int32_t index;
+  switch (expr.rex_type_case()) {
+    case substrait::Expression::RexTypeCase::kSelection: {
+      const auto& selection = expr.selection();
+      switch (selection.root_type_case()) {
+        case substrait::Expression_FieldReference::RootTypeCase::kRootReference: {
+          break;
+        }
+        default: {
+          return Status::NotImplemented(
+              std::string("substrait::Expression with non-root-reference for ") + what);
+        }
+      }
+      switch (selection.reference_type_case()) {
+        case substrait::Expression_FieldReference::ReferenceTypeCase::kDirectReference: {
+          const auto& direct_reference = selection.direct_reference();
+          switch (direct_reference.reference_type_case()) {
+            case substrait::Expression_ReferenceSegment::ReferenceTypeCase::
+                kStructField: {
+              break;
+            }
+            default: {
+              return Status::NotImplemented(
+                  std::string("substrait::Expression with non-struct-field for ") + what);
+            }
+          }
+          const auto& struct_field = direct_reference.struct_field();
+          if (struct_field.has_child()) {
+            return Status::NotImplemented(
+                std::string("substrait::Expression with non-flat struct-field for ") +
+                what);
+          }
+          index = struct_field.field();
+          break;
+        }
+        default: {
+          return Status::NotImplemented(
+              std::string("substrait::Expression with non-direct reference for ") + what);
+        }
+      }
+      break;
+    }
+    default: {
+      return Status::NotImplemented(
+          std::string("substrait::Expression with non-selection for ") + what);
+    }
+  }
+  return FieldRef(FieldPath({index}));
+}
+
+Result<std::vector<FieldRef>> FromProto(
+    const google::protobuf::RepeatedPtrField<substrait::Expression>& exprs,
+    const std::string& what) {
+  std::vector<FieldRef> fields;
+  int size = exprs.size();
+  for (int i = 0; i < size; i++) {
+    ARROW_ASSIGN_OR_RAISE(FieldRef field, FromProto(exprs[i], what));

Review Comment:
   You can probably use a for-range construct:
   ```suggestion
     for (const auto& expr : exprs) {
       ARROW_ASSIGN_OR_RAISE(FieldRef field, FromProto(expr, what));
   ```



##########
cpp/src/arrow/engine/substrait/util.h:
##########
@@ -39,6 +39,9 @@ ARROW_ENGINE_EXPORT Result<std::shared_ptr<RecordBatchReader>> ExecuteSerialized
 ARROW_ENGINE_EXPORT Result<std::shared_ptr<Buffer>> SerializeJsonPlan(
     const std::string& substrait_json);
 
+ARROW_ENGINE_EXPORT Result<std::vector<compute::Declaration>> DeserializePlans(
+    const Buffer& buf, const ExtensionIdRegistry* registry);

Review Comment:
   There are already functions named `DeserializePlans` in `serde.h`. Isn't it a bit confusing to have another one similarly named here?
   
   Also, can you add a docstring?



##########
cpp/src/arrow/python/pyarrow.h:
##########
@@ -71,6 +77,8 @@ DECLARE_WRAP_FUNCTIONS(tensor, Tensor)
 DECLARE_WRAP_FUNCTIONS(batch, RecordBatch)
 DECLARE_WRAP_FUNCTIONS(table, Table)
 
+DECLARE_WRAP_FUNCTIONS(extension_id_registry, engine::ExtensionIdRegistry)

Review Comment:
   Hmm, this will expose the wrapper functions to C++ code, which doesn't seem to be used anywhere. Instead, you should wrap/unwrap purely on the Cython side, like for most other C++ classes.



##########
python/pyarrow/_compute.pxd:
##########
@@ -27,6 +27,9 @@ cdef class ScalarUdfContext(_Weakrefable):
 
     cdef void init(self, const CScalarUdfContext& c_context)
 
+cdef class FunctionRegistry(_Weakrefable):
+    cdef CFunctionRegistry* registry

Review Comment:
   Do we really want to use a raw pointer here? Who will own the `unique_ptr` of a nested registry?
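
   One possible shape for this, sketched with hypothetical names (`NestedRegistry` and `RegistryHandle` are stand-ins, not Arrow classes): the handle always exposes a raw pointer for use, but keeps a `unique_ptr` next to it when the registry is one it created, so ownership is explicit.

   ```cpp
   #include <memory>
   #include <utility>

   // Hypothetical stand-in for a function registry.
   struct NestedRegistry {
     int id;
   };

   // Hypothetical handle: either borrows a registry that outlives it (such as
   // a global one) or owns a nested one through a unique_ptr.
   class RegistryHandle {
    public:
     static RegistryHandle Borrow(NestedRegistry* global) {
       return RegistryHandle(nullptr, global);
     }
     static RegistryHandle Own(std::unique_ptr<NestedRegistry> nested) {
       NestedRegistry* raw = nested.get();
       return RegistryHandle(std::move(nested), raw);
     }
     NestedRegistry* get() const { return registry_; }
     bool owns() const { return owned_ != nullptr; }

    private:
     RegistryHandle(std::unique_ptr<NestedRegistry> owned,
                    NestedRegistry* registry)
         : owned_(std::move(owned)), registry_(registry) {}

     std::unique_ptr<NestedRegistry> owned_;  // null when merely borrowing
     NestedRegistry* registry_;               // always valid for use
   };
   ```

   This is only one option; a `shared_ptr` member would be another way to make the lifetime explicit.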



##########
python/pyarrow/_compute.pyx:
##########
@@ -513,6 +514,13 @@ def function_registry():
     return _global_func_registry
 
 
+def make_function_registry():
+    up_registry = MakeFunctionRegistry()
+    c_registry = up_registry.get()
+    up_registry.release()

Review Comment:
   AFAICT, this effectively creates a memory leak as there's no `unique_ptr` owning the pointer anymore.
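
   A minimal sketch of the problem, with a hypothetical `CountedRegistry` type that counts live instances (not Arrow code). After `release()`, the `unique_ptr` no longer deletes the object, so it stays alive until someone deletes the raw pointer by hand:

   ```cpp
   #include <memory>

   // Hypothetical resource that counts live instances.
   struct CountedRegistry {
     static int live;
     CountedRegistry() { ++live; }
     ~CountedRegistry() { --live; }
   };
   int CountedRegistry::live = 0;

   // Mirrors the pattern in the diff: release() hands back the raw pointer
   // and leaves nothing responsible for deleting it.
   CountedRegistry* MakeAndRelease() {
     std::unique_ptr<CountedRegistry> up(new CountedRegistry());
     return up.release();
   }

   // Keeps ownership in the unique_ptr, so the object is destroyed together
   // with whoever holds the returned pointer.
   std::unique_ptr<CountedRegistry> MakeOwned() {
     return std::unique_ptr<CountedRegistry>(new CountedRegistry());
   }
   ```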



##########
python/pyarrow/_compute.pyx:
##########
@@ -2601,7 +2612,11 @@ def register_scalar_function(func, function_name, function_doc, in_types,
 
     c_func_name = tobytes(function_name)
 
-    func_spec = inspect.getfullargspec(func)
+    try:
+        func_spec = inspect.getfullargspec(func)
+        is_varargs = func_spec.varargs is not None
+    except:

Review Comment:
   `except` what? Bare `except` statements are discouraged; you should catch the precise exception type here.
   Also, please add a comment explaining when this can happen.



##########
python/pyarrow/_exec_plan.pyx:
##########
@@ -382,6 +391,9 @@ def _filter_table(table, expression, output_type=Table):
         vector[CDeclaration] c_decl_plan
         Expression expr = expression
 
+    if not is_supported_execplan_output_type(output_type):
+        raise TypeError(f"Unsupported output type {output_type}")

Review Comment:
   Why add this check since you're also raising the error below?



##########
cpp/src/arrow/engine/substrait/plan_internal.h:
##########
@@ -49,7 +49,8 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)
 ARROW_ENGINE_EXPORT
 Result<ExtensionSet> GetExtensionSetFromPlan(
     const substrait::Plan& plan,
-    const ExtensionIdRegistry* registry = default_extension_id_registry());
+    const ExtensionIdRegistry* registry = default_extension_id_registry(),
+    bool exclude_functions = false);

Review Comment:
   Can you add documentation for this parameter in the docstring above?



##########
python/pyarrow/_substrait.pyx:
##########
@@ -15,35 +15,183 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import base64
+import cloudpickle
+import inspect
+
 # cython: language_level = 3
-from cython.operator cimport dereference as deref
+from cython.operator cimport dereference as deref, preincrement as inc
 
+from pyarrow import compute as pc
 from pyarrow import Buffer
+from pyarrow.lib import frombytes, tobytes
 from pyarrow.lib cimport *
 from pyarrow.includes.libarrow cimport *
 from pyarrow.includes.libarrow_substrait cimport *
+from pyarrow._compute cimport FunctionRegistry
+
+
+from pyarrow._exec_plan cimport is_supported_execplan_output_type, execplan
+from pyarrow._compute import make_function_registry
+
+
+def make_extension_id_registry():
+    cdef:
+        shared_ptr[CExtensionIdRegistry] c_extid_registry
+        ExtensionIdRegistry registry
+
+    with nogil:
+        c_extid_registry = MakeExtensionIdRegistry()
+
+    return pyarrow_wrap_extension_id_registry(c_extid_registry)
+
+
+def _get_udf_code(func):
+    return frombytes(base64.b64encode(cloudpickle.dumps(func)))
+
+
+def get_udf_declarations(plan, extid_registry):
+    cdef:
+        shared_ptr[CBuffer] c_buf_plan
+        shared_ptr[CExtensionIdRegistry] c_extid_registry
+        vector[CUdfDeclaration] c_decls
+        vector[CUdfDeclaration].iterator c_decls_iter
+        vector[pair[shared_ptr[CDataType], c_bool]].iterator c_in_types_iter
+
+    c_buf_plan = pyarrow_unwrap_buffer(plan)
+    c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry)
+    with nogil:
+        c_res_decls = DeserializePlanUdfs(
+            deref(c_buf_plan), c_extid_registry.get())
+    c_decls = GetResultValue(c_res_decls)
+
+    decls = []
+    c_decls_iter = c_decls.begin()
+    while c_decls_iter != c_decls.end():

Review Comment:
   Hmm, normally, Cython allows you to iterate a `std::vector` the Python way:
   ```suggestion
       for c_decl in c_decls:
   ```
   (you might have to `cdef c_decl` above)



##########
python/pyarrow/_exec_plan.pyx:
##########
@@ -36,7 +36,10 @@ from pyarrow._dataset import InMemoryDataset
 Initialize()  # Initialise support for Datasets in ExecPlan
 
 
-cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads=True):
+cdef is_supported_execplan_output_type(output_type):
+    return output_type in [Table, InMemoryDataset]
+
+cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads=True, CFunctionRegistry* c_func_registry=NULL):

Review Comment:
   Can you please wrap code to a width of at most 80 characters?



##########
python/pyarrow/_substrait.pyx:
##########
@@ -15,35 +15,183 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import base64
+import cloudpickle

Review Comment:
   We don't have a cloudpickle dependency, why is this required?



##########
python/pyarrow/_substrait.pyx:
##########
@@ -15,35 +15,183 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import base64
+import cloudpickle
+import inspect
+
 # cython: language_level = 3
-from cython.operator cimport dereference as deref
+from cython.operator cimport dereference as deref, preincrement as inc
 
+from pyarrow import compute as pc
 from pyarrow import Buffer
+from pyarrow.lib import frombytes, tobytes
 from pyarrow.lib cimport *
 from pyarrow.includes.libarrow cimport *
 from pyarrow.includes.libarrow_substrait cimport *
+from pyarrow._compute cimport FunctionRegistry
+
+
+from pyarrow._exec_plan cimport is_supported_execplan_output_type, execplan
+from pyarrow._compute import make_function_registry
+
+
+def make_extension_id_registry():
+    cdef:
+        shared_ptr[CExtensionIdRegistry] c_extid_registry
+        ExtensionIdRegistry registry
+
+    with nogil:
+        c_extid_registry = MakeExtensionIdRegistry()
+
+    return pyarrow_wrap_extension_id_registry(c_extid_registry)
+
+
+def _get_udf_code(func):
+    return frombytes(base64.b64encode(cloudpickle.dumps(func)))
+
+
+def get_udf_declarations(plan, extid_registry):
+    cdef:
+        shared_ptr[CBuffer] c_buf_plan
+        shared_ptr[CExtensionIdRegistry] c_extid_registry
+        vector[CUdfDeclaration] c_decls
+        vector[CUdfDeclaration].iterator c_decls_iter
+        vector[pair[shared_ptr[CDataType], c_bool]].iterator c_in_types_iter
+
+    c_buf_plan = pyarrow_unwrap_buffer(plan)
+    c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry)
+    with nogil:
+        c_res_decls = DeserializePlanUdfs(
+            deref(c_buf_plan), c_extid_registry.get())
+    c_decls = GetResultValue(c_res_decls)
+
+    decls = []
+    c_decls_iter = c_decls.begin()
+    while c_decls_iter != c_decls.end():
+        input_types = []
+        c_in_types_iter = deref(c_decls_iter).input_types.begin()
+        while c_in_types_iter != deref(c_decls_iter).input_types.end():
+            input_types.append((pyarrow_wrap_data_type(deref(c_in_types_iter).first),
+                                deref(c_in_types_iter).second))
+            inc(c_in_types_iter)
+        decls.append({
+            "name": frombytes(deref(c_decls_iter).name),
+            "code": frombytes(deref(c_decls_iter).code),
+            "summary": frombytes(deref(c_decls_iter).summary),
+            "description": frombytes(deref(c_decls_iter).description),
+            "input_types": input_types,
+            "output_type": (pyarrow_wrap_data_type(deref(c_decls_iter).output_type.first),
+                            deref(c_decls_iter).output_type.second),
+        })
+        inc(c_decls_iter)
+    return decls
+
+
+def register_function(extid_registry, id_uri, id_name, arrow_function_name):
+    cdef:
+        c_string c_id_uri, c_id_name, c_arrow_function_name
+        shared_ptr[CExtensionIdRegistry] c_extid_registry
+        CStatus c_status
+
+    c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry)
+    c_id_uri = id_uri or default_extension_types_uri()
+    c_id_name = tobytes(id_name)
+    c_arrow_function_name = tobytes(arrow_function_name)
+
+    with nogil:
+        c_status = RegisterFunction(
+            deref(c_extid_registry), c_id_uri, c_id_name, c_arrow_function_name
+        )
+
+    check_status(c_status)
+
 
+def register_udf_declarations(plan, extid_registry, func_registry, udf_decls=None):
+    if udf_decls is None:
+        udf_decls = get_udf_declarations(plan, extid_registry)
+    for udf_decl in udf_decls:
+        udf_name = udf_decl["name"]
+        udf_func = cloudpickle.loads(
+            base64.b64decode(tobytes(udf_decl["code"])))
+        udf_arg_names = list(inspect.signature(udf_func).parameters.keys())
+        udf_arg_types = udf_decl["input_types"]
+        register_function(extid_registry, None, udf_name, udf_name)
+        def udf(ctx, *args):
+            return udf_func(*args)
 
-def run_query(plan):
+        pc.register_scalar_function(
+            udf,
+            udf_name,
+            {"summary": udf_decl["summary"],
+                "description": udf_decl["description"]},
+            # range start from 1 to skip over udf scalar context argument
+            {udf_arg_names[i]: udf_arg_types[i][0]
+                for i in range(0 ,len(udf_arg_types))},
+            udf_decl["output_type"][0],
+            func_registry,
+        )
+
+
+def run_query_as(plan, extid_registry, func_registry, output_type=RecordBatchReader):
+    if output_type == RecordBatchReader:
+        return run_query(plan, extid_registry, func_registry)
+    return _run_query(plan, extid_registry, func_registry, output_type)
+
+
+def _run_query(plan, extid_registry, func_registry, output_type):
+    cdef:
+        shared_ptr[CBuffer] c_buf_plan
+        shared_ptr[CExtensionIdRegistry] c_extid_registry
+        CFunctionRegistry* c_func_registry
+        CResult[vector[CDeclaration]] c_res_decls
+        vector[CDeclaration] c_decls
+
+    if not is_supported_execplan_output_type(output_type):
+        raise TypeError(f"Unsupported output type {output_type}")
+
+    c_buf_plan = pyarrow_unwrap_buffer(plan)
+    c_extid_registry = pyarrow_unwrap_extension_id_registry(extid_registry)
+    c_func_registry = pyarrow_unwrap_function_registry(func_registry)
+    if c_func_registry == NULL:
+        c_func_registry = (<FunctionRegistry>func_registry).registry
+    with nogil:
+        c_res_decls = DeserializePlans(
+            deref(c_buf_plan), c_extid_registry.get())
+    c_decls = GetResultValue(c_res_decls)
+    return execplan([], output_type, c_decls, True, c_func_registry)
+
+
+def run_query(plan, extid_registry, func_registry):

Review Comment:
   Should `func_registry` be optional?
   ```suggestion
   def run_query(plan, extid_registry, func_registry=None):
   ```



##########
python/pyarrow/compute.pxi:
##########
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# separating out this base class is easier than unifying it into
+# FunctionRegistry, which lives outside libarrow
+cdef class BaseFunctionRegistry(_Weakrefable):
+    cdef CFunctionRegistry* registry
+
+cdef class ExtensionIdRegistry(_Weakrefable):
+    def __cinit__(self):
+        self.registry = NULL
+
+    def __init__(self):
+        raise TypeError("Do not call ExtensionIdRegistry's constructor directly, use "
+                        "the `MakeExtensionIdRegistry` function instead.")
+
+    cdef void init(self, shared_ptr[CExtensionIdRegistry]& registry):
+        self.sp_registry = registry

Review Comment:
   Hmm, I don't see `sp_registry` defined anywhere. Did I miss something?



##########
python/pyarrow/public-api.pxi:
##########
@@ -416,3 +416,42 @@ cdef api object pyarrow_wrap_batch(
     cdef RecordBatch batch = RecordBatch.__new__(RecordBatch)
     batch.init(cbatch)
     return batch
+
+
+cdef api bint pyarrow_is_function_registry(object registry):

Review Comment:
   Are these meant to be exposed to third-party C++ or Cython code? Otherwise we shouldn't define such public (un)wrapping functions.



##########
python/pyarrow/tests/test_substrait.py:
##########
@@ -74,7 +75,9 @@ def test_run_serialized_query(tmpdir):
 
     buf = pa._substrait._parse_json_plan(query)
 
-    reader = substrait.run_query(buf)
+    extid_registry = substrait.make_extension_id_registry()
+    func_registry = substrait.make_function_registry()
+    reader = substrait.run_query(buf, extid_registry, func_registry)

Review Comment:
   Can there also be a test with `func_registry` omitted or None?



##########
python/pyarrow/_substrait.pyx:
##########
@@ -15,35 +15,183 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import base64
+import cloudpickle
+import inspect
+
 # cython: language_level = 3
-from cython.operator cimport dereference as deref
+from cython.operator cimport dereference as deref, preincrement as inc
 
+from pyarrow import compute as pc
 from pyarrow import Buffer
+from pyarrow.lib import frombytes, tobytes
 from pyarrow.lib cimport *
 from pyarrow.includes.libarrow cimport *
 from pyarrow.includes.libarrow_substrait cimport *
+from pyarrow._compute cimport FunctionRegistry
+
+
+from pyarrow._exec_plan cimport is_supported_execplan_output_type, execplan
+from pyarrow._compute import make_function_registry
+
+
+def make_extension_id_registry():

Review Comment:
   Public APIs should get a docstring.
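
   For instance, a numpydoc-style docstring could look like the sketch below. The wording of the description is an assumption about what the function does, and `ExtensionIdRegistry` here is a plain-Python stand-in for the Cython wrapper class.

```python
class ExtensionIdRegistry:
    """Hypothetical stand-in for the Cython wrapper class."""


def make_extension_id_registry():
    """
    Create a new extension-id registry.

    Returns
    -------
    ExtensionIdRegistry
        A newly created registry (description is an assumption).
    """
    # Stand-in body; the real function wraps a C++ registry.
    return ExtensionIdRegistry()
```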



##########
python/pyarrow/tests/test_substrait.py:
##########
@@ -22,6 +22,7 @@
 import pyarrow as pa
 from pyarrow.lib import tobytes
 from pyarrow.lib import ArrowInvalid
+from pyarrow.substrait import make_extension_id_registry

Review Comment:
   This import isn't useful at all, is it? The test can just call `substrait.make_extension_id_registry`.



##########
cpp/src/arrow/engine/substrait/plan_internal.h:
##########
@@ -49,7 +49,8 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)
 ARROW_ENGINE_EXPORT
 Result<ExtensionSet> GetExtensionSetFromPlan(
     const substrait::Plan& plan,
-    const ExtensionIdRegistry* registry = default_extension_id_registry());
+    const ExtensionIdRegistry* registry = default_extension_id_registry(),
+    bool exclude_functions = false);

Review Comment:
   Also, as a nit, double negatives are not terrific, so I would instead suggest `bool include_functions = true`.
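
   To illustrate the readability point in plain Python: with a positive flag, the check in the body reads `if include_functions` rather than `if not exclude_functions`. The function and the dict-based plan below are hypothetical and only loosely mirror the C++ signature.

```python
def get_extension_set(plan, include_functions=True):
    """Collect extension entries from a dict-based plan (hypothetical)."""
    ext_set = {"types": list(plan.get("types", []))}
    # Positive flag: no negation needed at the point of use.
    if include_functions:
        ext_set["functions"] = list(plan.get("functions", []))
    return ext_set
```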


