Posted to github@arrow.apache.org by "icexelloss (via GitHub)" <gi...@apache.org> on 2023/06/26 14:21:02 UTC

[GitHub] [arrow] icexelloss commented on a diff in pull request #36253: GH-36252: [Python] Add non decomposable hash aggregate UDF

icexelloss commented on code in PR #36253:
URL: https://github.com/apache/arrow/pull/36253#discussion_r1242281854


##########
python/pyarrow/src/arrow/python/udf.cc:
##########
@@ -215,9 +249,162 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
     return Status::OK();
   }
 
-  UdfWrapperCallback agg_cb;
+  std::shared_ptr<OwnedRefNoGIL> function;
+  UdfWrapperCallback cb;
+  std::vector<std::shared_ptr<RecordBatch>> values;
+  std::shared_ptr<Schema> input_schema;
+  std::shared_ptr<DataType> output_type;
+};
+
+struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
+  PythonUdfHashAggregatorImpl(std::shared_ptr<OwnedRefNoGIL> function,
+                              UdfWrapperCallback cb,
+                              std::vector<std::shared_ptr<DataType>> input_types,
+                              std::shared_ptr<DataType> output_type)
+      : function(function), cb(std::move(cb)), output_type(std::move(output_type)) {
+    Py_INCREF(function->obj());
+    std::vector<std::shared_ptr<Field>> fields;
+    for (size_t i = 0; i < input_types.size(); i++) {
+      fields.push_back(field("", input_types[i]));
+    }
+    input_schema = schema(std::move(fields));
+  }
+
+  ~PythonUdfHashAggregatorImpl() override {
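+    // The interpreter is shutting down: release our reference without
+    // decrementing it, since calling into Python is no longer safe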
+    if (_Py_IsFinalizing()) {
+      function->detach();
+    }
+  }
+
+  /// @brief Same as ApplyGroupings in partition.cc
+  /// The code is replicated here to avoid complicating the dependencies
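+  /// groupings is a ListArray where the i-th list holds the row indices of
+  /// group i; Take reorders the batch so each group's rows are contiguous,
+  /// and each output is a zero-copy slice covering one group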
+  static Result<RecordBatchVector> ApplyGroupings(
+      const ListArray& groupings, const std::shared_ptr<RecordBatch>& batch) {
+    ARROW_ASSIGN_OR_RAISE(Datum sorted,
+                          compute::Take(batch, groupings.data()->child_data[0]));
+
+    const auto& sorted_batch = *sorted.record_batch();
+
+    RecordBatchVector out(static_cast<size_t>(groupings.length()));
+    for (size_t i = 0; i < out.size(); ++i) {
+      out[i] = sorted_batch.Slice(groupings.value_offset(i), groupings.value_length(i));
+    }
+
+    return out;
+  }
+
+  Status Resize(KernelContext* ctx, int64_t new_num_groups) {
+    // We only need to update num_groups in Resize,
+    // just like the other hash aggregate kernels
+    num_groups = new_num_groups;
+    return Status::OK();
+  }
+
+  Status Consume(KernelContext* ctx, const ExecSpan& batch) {
+    ARROW_ASSIGN_OR_RAISE(
+        std::shared_ptr<RecordBatch> rb,
+        batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool()));
+
+    // This is similar to GroupedListImpl;
+    // the last array in the batch is the group id array
+    const ArraySpan& groups_array_data = batch[batch.num_values() - 1].array;
+    DCHECK_EQ(groups_array_data.offset, 0);
+    int64_t batch_num_values = groups_array_data.length;
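+    // Read the raw uint32 group ids from the array's data buffer (buffer 1);
+    // the offset is 0, as checked by the DCHECK above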
+    const auto* batch_groups = groups_array_data.GetValues<uint32_t>(1, 0);
+    RETURN_NOT_OK(groups.Append(batch_groups, batch_num_values));
+    values.push_back(std::move(rb));
+    num_values += batch_num_values;
+    return Status::OK();
+  }
+
+  Status Merge(KernelContext* ctx, KernelState&& other_state,
+               const ArrayData& group_id_mapping) {
+    // This is similar to GroupedListImpl
+    auto& other = checked_cast<PythonUdfHashAggregatorImpl&>(other_state);
+    auto& other_values = other.values;
+    const uint32_t* other_raw_groups = other.groups.data();
+    values.insert(values.end(), std::make_move_iterator(other_values.begin()),
+                  std::make_move_iterator(other_values.end()));
+
+    auto g = group_id_mapping.GetValues<uint32_t>(1);
+    for (uint32_t other_g = 0; static_cast<int64_t>(other_g) < other.num_values;
+         ++other_g) {
+      // Different states can have different group_id mappings, so we
+      // need to translate the ids
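+      // e.g. if group 3 in `other` became group 7 after merging, then
+      // g[3] == 7 and each of other's ids is rewritten before appending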
+      RETURN_NOT_OK(groups.Append(g[other_raw_groups[other_g]]));
+    }
+
+    num_values += other.num_values;
+    return Status::OK();

Review Comment:
   I don't think `num_groups` needs to be updated here. Reasoning:
   
   From the code in https://github.com/apache/arrow/blob/main/cpp/src/arrow/compute/kernels/hash_aggregate.cc#L233 and https://github.com/apache/arrow/blob/main/cpp/src/arrow/acero/groupby_aggregate_node.cc#L248:
   (1) Other hash aggregate kernel implementations update `num_groups` in `Resize`.
   (2) `Resize` is always called before `Consume` and `Merge`, so `num_groups` is already current by the time `Merge` runs (see the sketch below).
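   
   A minimal, self-contained sketch of the contract I'm assuming (toy types here; `ToyHashAggState` and `id_mapping` are illustrative names, not Arrow APIs):
   
   ```cpp
   #include <cstdint>
   #include <utility>
   #include <vector>
   
   // Toy model of the hash aggregate state lifecycle: the node always calls
   // Resize before Consume or Merge, so Merge never needs to touch num_groups.
   struct ToyHashAggState {
     int64_t num_groups = 0;
     std::vector<uint32_t> groups;
   
     void Resize(int64_t new_num_groups) { num_groups = new_num_groups; }
   
     void Consume(const std::vector<uint32_t>& batch_groups) {
       groups.insert(groups.end(), batch_groups.begin(), batch_groups.end());
     }
   
     // Merge only translates and appends ids; num_groups was already set by
     // the Resize call that precedes it.
     void Merge(ToyHashAggState&& other, const std::vector<uint32_t>& id_mapping) {
       for (uint32_t g : other.groups) groups.push_back(id_mapping[g]);
     }
   };
   
   int main() {
     ToyHashAggState a, b;
     a.Resize(2);                  // Resize precedes Consume
     a.Consume({0, 1, 0});
     b.Resize(2);
     b.Consume({0, 1});
     a.Resize(4);                  // Resize precedes Merge, with the merged total
     a.Merge(std::move(b), /*id_mapping=*/{2, 3});
     return 0;
   }
   ```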
   
   


