You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2023/05/08 21:35:02 UTC

[arrow] branch main updated: GH-35468: [C++] Fix Acero var/std for multiple batches (#35469)

This is an automated email from the ASF dual-hosted git repository.

westonpace pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5b8eb64780 GH-35468: [C++] Fix Acero var/std for multiple batches (#35469)
5b8eb64780 is described below

commit 5b8eb64780c9c017c4375fcf94bf009e978bf087
Author: rtpsw <rt...@hotmail.com>
AuthorDate: Tue May 9 00:34:54 2023 +0300

    GH-35468: [C++] Fix Acero var/std for multiple batches (#35469)
    
    See https://github.com/apache/arrow/issues/35468
    
    **This PR contains a "Critical Fix".**
    
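    The current result of Acero var/std for multiple batches is incorrect: each new
    batch overwrites the accumulated count/mean/m2 instead of being merged into it,
    so only the last batch affects the result.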
    * Closes: #35468
    
    Authored-by: Yaron Gvili <rt...@hotmail.com>
    Signed-off-by: Weston Pace <we...@gmail.com>
---
 cpp/src/arrow/acero/groupby_test.cc                | 27 ++++++++++++++++++++++
 cpp/src/arrow/compute/kernels/aggregate_var_std.cc |  8 ++++---
 2 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/cpp/src/arrow/acero/groupby_test.cc b/cpp/src/arrow/acero/groupby_test.cc
index 5710ad2598..7fe67a1f0d 100644
--- a/cpp/src/arrow/acero/groupby_test.cc
+++ b/cpp/src/arrow/acero/groupby_test.cc
@@ -21,6 +21,7 @@
 
 #include <memory>
 
+#include "arrow/table.h"
 #include "arrow/testing/gtest_util.h"
 
 namespace arrow {
@@ -123,5 +124,31 @@ TEST(GroupByConvenienceFunc, Invalid) {
       TableGroupBy(in_table, {{"add", {"value"}, "value_add"}}, {}));
 }
 
+void TestVarStdMultiBatch(const std::string& var_std_func_name) {
+  std::shared_ptr<Schema> in_schema = schema({field("value", float64())});
+  std::shared_ptr<Table> in_table = TableFromJSON(in_schema, {R"([
+    [1],
+    [2],
+    [3]
+  ])",
+                                                              R"([
+    [4],
+    [4],
+    [4]
+  ])"});
+
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<Table> actual,
+                       TableGroupBy(in_table, {{var_std_func_name, {"value"}, "x"}}, {},
+                                    /*use_threads=*/false));
+
+  ASSERT_OK_AND_ASSIGN(auto var_scalar, actual->column(0)->GetScalar(0));
+  // the second batch is constant, so a zero result means the first batch was dropped
+  ASSERT_NE(0, std::dynamic_pointer_cast<DoubleScalar>(var_scalar)->value);
+}
+
+TEST(GroupByConvenienceFunc, VarianceMultiBatch) { TestVarStdMultiBatch("variance"); }
+
+TEST(GroupByConvenienceFunc, StdDevMultiBatch) { TestVarStdMultiBatch("stddev"); }
+
 }  // namespace acero
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc
index e650285dfd..c2fab48dbe 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_var_std.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_var_std.cc
@@ -71,9 +71,11 @@ struct VarStdState {
           return (v - mean) * (v - mean);
         });
 
-    this->count = count;
-    this->mean = mean;
-    this->m2 = m2;
+    ThisType state(decimal_scale, options);
+    state.count = count;
+    state.mean = mean;
+    state.m2 = m2;
+    this->MergeFrom(state);
   }
 
   // int32/16/8: textbook one pass algorithm with integer arithmetic