You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2023/06/26 09:07:43 UTC

[arrow] branch main updated: GH-36128: [C++][Compute] Allow multiplication between duration and all integer types (#36231)

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5dd4cc08f1 GH-36128: [C++][Compute] Allow multiplication between duration and all integer types (#36231)
5dd4cc08f1 is described below

commit 5dd4cc08f1ed956b5e0f78f5782582f9b33c2661
Author: Jin Shang <sh...@gmail.com>
AuthorDate: Mon Jun 26 17:07:36 2023 +0800

    GH-36128: [C++][Compute] Allow multiplication between duration and all integer types (#36231)
    
    
    
    ### Rationale for this change
    
    Currently durations can only be multiplied by int64, not by any other integer type.
    
    ### What changes are included in this PR?
    
    Allow duration types to be multiplied with any integer type.
    1. For `multiply`, new kernels are added to support the new types.
    2. For `multiply_checked`, integers are cast to int64. New kernels cannot be added here because the `MultiplyChecked` op class requires both operands to have the same type.
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    No.
    * Closes: #36128
    
    Authored-by: Jin Shang <sh...@gmail.com>
    Signed-off-by: Antoine Pitrou <an...@python.org>
---
 cpp/src/arrow/compute/kernels/codegen_internal.cc  | 15 +++++
 cpp/src/arrow/compute/kernels/codegen_internal.h   |  8 ++-
 cpp/src/arrow/compute/kernels/scalar_arithmetic.cc |  6 ++
 .../arrow/compute/kernels/scalar_temporal_test.cc  | 71 +++++++++++++---------
 4 files changed, 69 insertions(+), 31 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc
index 2625520d60..e0156caecf 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -482,6 +482,21 @@ bool HasDecimal(const std::vector<TypeHolder>& types) {
   return false;
 }
 
+void PromoteIntegerForDurationArithmetic(std::vector<TypeHolder>* types) {
+  // Leave the argument types untouched unless a duration is involved.
+  auto is_duration = [](const TypeHolder& t) { return t.id() == Type::DURATION; };
+  if (std::none_of(types->begin(), types->end(), is_duration)) {
+    return;
+  }
+  // Widen each integer argument to int64, matching duration's 64-bit width,
+  // so the caller inserts implicit casts and reuses the duration/int64 kernels.
+  for (TypeHolder& holder : *types) {
+    if (!is_integer(holder.id())) continue;
+    holder = int64();
+  }
+  // Non-integer arguments, including the duration itself, are not modified.
+}
+
 }  // namespace internal
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index 6224a9fc2a..3c6c0d63fb 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -46,6 +46,7 @@
 #include "arrow/util/decimal.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
 #include "arrow/visit_data_inline.h"
 
 namespace arrow {
@@ -1337,7 +1338,7 @@ ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) {
 
 // END of kernel generator-dispatchers
 // ----------------------------------------------------------------------
-
+// BEGIN of DispatchBest helpers
 ARROW_EXPORT
 void EnsureDictionaryDecoded(std::vector<TypeHolder>* types);
 
@@ -1396,6 +1397,11 @@ Status CastDecimalArgs(TypeHolder* begin, size_t count);
 ARROW_EXPORT
 bool HasDecimal(const std::vector<TypeHolder>& types);
 
+ARROW_EXPORT
+void PromoteIntegerForDurationArithmetic(std::vector<TypeHolder>* types);
+
+// END of DispatchBest helpers
+// ----------------------------------------------------------------------
 }  // namespace internal
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
index 249da4758e..2c7363b3ca 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
@@ -26,6 +26,7 @@
 #include "arrow/compute/api_scalar.h"
 #include "arrow/compute/cast.h"
 #include "arrow/compute/kernels/base_arithmetic_internal.h"
+#include "arrow/compute/kernels/codegen_internal.h"
 #include "arrow/compute/kernels/common_internal.h"
 #include "arrow/compute/kernels/util_internal.h"
 #include "arrow/type.h"
@@ -640,6 +641,11 @@ struct ArithmeticFunction : ScalarFunction {
           ReplaceTypes(type, types);
         }
       }
+
+      if (name_ == "multiply" || name_ == "multiply_checked" || name_ == "divide" ||
+          name_ == "divide_checked") {
+        PromoteIntegerForDurationArithmetic(types);
+      }
     }
 
     if (auto kernel = DispatchExactImpl(this, *types)) return kernel;
diff --git a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
index 5cdf6f2bcf..cd8abf6e92 100644
--- a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
@@ -26,6 +26,7 @@
 #include "arrow/testing/matchers.h"
 #include "arrow/testing/util.h"
 #include "arrow/type.h"
+#include "arrow/type_traits.h"
 #include "arrow/util/checked_cast.h"
 #include "arrow/util/formatting.h"
 #include "arrow/util/logging.h"
@@ -1671,41 +1672,51 @@ TEST_F(ScalarTemporalTest, TestTemporalMultiplyDuration) {
   ArrayFromVector<Int64Type, int64_t>({max, max, max, max, max}, &max_array);
 
   for (auto u : TimeUnit::values()) {
-    auto unit = duration(u);
-    auto durations = ArrayFromJSON(unit, R"([0, -1, 2, 6, null])");
-    auto multipliers = ArrayFromJSON(int64(), R"([0, 3, 2, 7, null])");
-    auto durations_multiplied = ArrayFromJSON(unit, R"([0, -3, 4, 42, null])");
-
-    CheckScalarBinaryCommutative("multiply", durations, multipliers,
-                                 durations_multiplied);
-    CheckScalarBinaryCommutative("multiply_checked", durations, multipliers,
-                                 durations_multiplied);
-
-    EXPECT_RAISES_WITH_MESSAGE_THAT(
-        Invalid, ::testing::HasSubstr("Invalid: overflow"),
-        CallFunction("multiply_checked", {durations, max_array}));
-    EXPECT_RAISES_WITH_MESSAGE_THAT(
-        Invalid, ::testing::HasSubstr("Invalid: overflow"),
-        CallFunction("multiply_checked", {max_array, durations}));
+    for (auto numeric : NumericTypes()) {
+      if (!is_integer(numeric->id())) continue;
+      auto unit = duration(u);
+      auto durations = ArrayFromJSON(unit, R"([0, -1, 2, 6, null])");
+      auto multipliers = ArrayFromJSON(numeric, R"([0, 3, 2, 7, null])");
+      auto durations_multiplied = ArrayFromJSON(unit, R"([0, -3, 4, 42, null])");
+
+      CheckScalarBinaryCommutative("multiply", durations, multipliers,
+                                   durations_multiplied);
+      CheckScalarBinaryCommutative("multiply_checked", durations, multipliers,
+                                   durations_multiplied);
+
+      EXPECT_RAISES_WITH_MESSAGE_THAT(
+          Invalid, ::testing::HasSubstr("Invalid: overflow"),
+          CallFunction("multiply_checked", {durations, max_array}));
+      EXPECT_RAISES_WITH_MESSAGE_THAT(
+          Invalid, ::testing::HasSubstr("Invalid: overflow"),
+          CallFunction("multiply_checked", {max_array, durations}));
+    }
   }
 }
 
 TEST_F(ScalarTemporalTest, TestTemporalDivideDuration) {
   for (auto u : TimeUnit::values()) {
-    auto unit = duration(u);
-    auto divided_durations = ArrayFromJSON(unit, R"([0, -1, -2, 6, null])");
-    auto divisors = ArrayFromJSON(int64(), R"([3, 3, -2, 7, null])");
-    auto durations = ArrayFromJSON(unit, R"([1, -3, 4, 42, null])");
-    auto zeros = ArrayFromJSON(int64(), R"([0, 0, 0, 0, null])");
-    CheckScalarBinary("divide", durations, divisors, divided_durations);
-    CheckScalarBinary("divide_checked", durations, divisors, divided_durations);
-
-    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
-                                    ::testing::HasSubstr("Invalid: divide by zero"),
-                                    CallFunction("divide", {durations, zeros}));
-    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
-                                    ::testing::HasSubstr("Invalid: divide by zero"),
-                                    CallFunction("divide_checked", {durations, zeros}));
+    const auto unit = duration(u);
+    const auto durations = ArrayFromJSON(unit, R"([1, -3, 4, 42, null])");
+    // Every integer type should be accepted as a divisor of a duration.
+    for (const auto& int_type : NumericTypes()) {
+      if (!is_integer(int_type->id())) continue;
+      // Unsigned types cannot hold -2, so the third divisor (and its quotient)
+      // is positive for them.
+      const bool is_signed = is_signed_integer(int_type->id());
+      const auto divisors = ArrayFromJSON(
+          int_type, is_signed ? R"([3, 3, -2, 7, null])" : R"([3, 3, 2, 7, null])");
+      const auto quotients = ArrayFromJSON(
+          unit, is_signed ? R"([0, -1, -2, 6, null])" : R"([0, -1, 2, 6, null])");
+      const auto zeros = ArrayFromJSON(int_type, R"([0, 0, 0, 0, null])");
+
+      for (const char* func_name : {"divide", "divide_checked"}) {
+        CheckScalarBinary(func_name, durations, divisors, quotients);
+        EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+                                        ::testing::HasSubstr("Invalid: divide by zero"),
+                                        CallFunction(func_name, {durations, zeros}));
+      }
+    }
   }
 }