You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2023/06/26 09:07:43 UTC
[arrow] branch main updated: GH-36128: [C++][Compute] Allow multiplication between duration and all integer types (#36231)
This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 5dd4cc08f1 GH-36128: [C++][Compute] Allow multiplication between duration and all integer types (#36231)
5dd4cc08f1 is described below
commit 5dd4cc08f1ed956b5e0f78f5782582f9b33c2661
Author: Jin Shang <sh...@gmail.com>
AuthorDate: Mon Jun 26 17:07:36 2023 +0800
GH-36128: [C++][Compute] Allow multiplication between duration and all integer types (#36231)
### Rationale for this change
Currently, durations can only be multiplied by int64, not by any other integer type.
### What changes are included in this PR?
Allow duration types to be multiplied by any integer type.
1. For `multiply`, new kernels are added to support the new types.
2. For `multiply_checked`, integers are cast to int64. We can't add new kernels because the MultiplyChecked op class requires both operands to have the same type.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
No.
* Closes: #36128
Authored-by: Jin Shang <sh...@gmail.com>
Signed-off-by: Antoine Pitrou <an...@python.org>
---
cpp/src/arrow/compute/kernels/codegen_internal.cc | 15 +++++
cpp/src/arrow/compute/kernels/codegen_internal.h | 8 ++-
cpp/src/arrow/compute/kernels/scalar_arithmetic.cc | 6 ++
.../arrow/compute/kernels/scalar_temporal_test.cc | 71 +++++++++++++---------
4 files changed, 69 insertions(+), 31 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc
index 2625520d60..e0156caecf 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -482,6 +482,21 @@ bool HasDecimal(const std::vector<TypeHolder>& types) {
return false;
}
+void PromoteIntegerForDurationArithmetic(std::vector<TypeHolder>* types) {
+ bool has_duration = std::any_of(types->begin(), types->end(), [](const TypeHolder& t) {
+ return t.id() == Type::DURATION;
+ });
+
+ if (!has_duration) return;
+
+ // Require implicit casts to int64 to match duration's bit width
+ for (auto& type : *types) {
+ if (is_integer(type.id())) {
+ type = int64();
+ }
+ }
+}
+
} // namespace internal
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index 6224a9fc2a..3c6c0d63fb 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -46,6 +46,7 @@
#include "arrow/util/decimal.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
+#include "arrow/util/visibility.h"
#include "arrow/visit_data_inline.h"
namespace arrow {
@@ -1337,7 +1338,7 @@ ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) {
// END of kernel generator-dispatchers
// ----------------------------------------------------------------------
-
+// BEGIN of DispatchBest helpers
ARROW_EXPORT
void EnsureDictionaryDecoded(std::vector<TypeHolder>* types);
@@ -1396,6 +1397,11 @@ Status CastDecimalArgs(TypeHolder* begin, size_t count);
ARROW_EXPORT
bool HasDecimal(const std::vector<TypeHolder>& types);
+ARROW_EXPORT
+void PromoteIntegerForDurationArithmetic(std::vector<TypeHolder>* types);
+
+// END of DispatchBest helpers
+// ----------------------------------------------------------------------
} // namespace internal
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
index 249da4758e..2c7363b3ca 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
@@ -26,6 +26,7 @@
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/cast.h"
#include "arrow/compute/kernels/base_arithmetic_internal.h"
+#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/compute/kernels/util_internal.h"
#include "arrow/type.h"
@@ -640,6 +641,11 @@ struct ArithmeticFunction : ScalarFunction {
ReplaceTypes(type, types);
}
}
+
+ if (name_ == "multiply" || name_ == "multiply_checked" || name_ == "divide" ||
+ name_ == "divide_checked") {
+ PromoteIntegerForDurationArithmetic(types);
+ }
}
if (auto kernel = DispatchExactImpl(this, *types)) return kernel;
diff --git a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
index 5cdf6f2bcf..cd8abf6e92 100644
--- a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
@@ -26,6 +26,7 @@
#include "arrow/testing/matchers.h"
#include "arrow/testing/util.h"
#include "arrow/type.h"
+#include "arrow/type_traits.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/formatting.h"
#include "arrow/util/logging.h"
@@ -1671,41 +1672,51 @@ TEST_F(ScalarTemporalTest, TestTemporalMultiplyDuration) {
ArrayFromVector<Int64Type, int64_t>({max, max, max, max, max}, &max_array);
for (auto u : TimeUnit::values()) {
- auto unit = duration(u);
- auto durations = ArrayFromJSON(unit, R"([0, -1, 2, 6, null])");
- auto multipliers = ArrayFromJSON(int64(), R"([0, 3, 2, 7, null])");
- auto durations_multiplied = ArrayFromJSON(unit, R"([0, -3, 4, 42, null])");
-
- CheckScalarBinaryCommutative("multiply", durations, multipliers,
- durations_multiplied);
- CheckScalarBinaryCommutative("multiply_checked", durations, multipliers,
- durations_multiplied);
-
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, ::testing::HasSubstr("Invalid: overflow"),
- CallFunction("multiply_checked", {durations, max_array}));
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, ::testing::HasSubstr("Invalid: overflow"),
- CallFunction("multiply_checked", {max_array, durations}));
+ for (auto numeric : NumericTypes()) {
+ if (!is_integer(numeric->id())) continue;
+ auto unit = duration(u);
+ auto durations = ArrayFromJSON(unit, R"([0, -1, 2, 6, null])");
+ auto multipliers = ArrayFromJSON(numeric, R"([0, 3, 2, 7, null])");
+ auto durations_multiplied = ArrayFromJSON(unit, R"([0, -3, 4, 42, null])");
+
+ CheckScalarBinaryCommutative("multiply", durations, multipliers,
+ durations_multiplied);
+ CheckScalarBinaryCommutative("multiply_checked", durations, multipliers,
+ durations_multiplied);
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Invalid: overflow"),
+ CallFunction("multiply_checked", {durations, max_array}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, ::testing::HasSubstr("Invalid: overflow"),
+ CallFunction("multiply_checked", {max_array, durations}));
+ }
}
}
TEST_F(ScalarTemporalTest, TestTemporalDivideDuration) {
for (auto u : TimeUnit::values()) {
- auto unit = duration(u);
- auto divided_durations = ArrayFromJSON(unit, R"([0, -1, -2, 6, null])");
- auto divisors = ArrayFromJSON(int64(), R"([3, 3, -2, 7, null])");
- auto durations = ArrayFromJSON(unit, R"([1, -3, 4, 42, null])");
- auto zeros = ArrayFromJSON(int64(), R"([0, 0, 0, 0, null])");
- CheckScalarBinary("divide", durations, divisors, divided_durations);
- CheckScalarBinary("divide_checked", durations, divisors, divided_durations);
-
- EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
- ::testing::HasSubstr("Invalid: divide by zero"),
- CallFunction("divide", {durations, zeros}));
- EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
- ::testing::HasSubstr("Invalid: divide by zero"),
- CallFunction("divide_checked", {durations, zeros}));
+ for (auto numeric : NumericTypes()) {
+ if (!is_integer(numeric->id())) continue;
+ auto unit = duration(u);
+ auto divided_durations = is_signed_integer(numeric->id())
+ ? ArrayFromJSON(unit, R"([0, -1, -2, 6, null])")
+ : ArrayFromJSON(unit, R"([0, -1, 2, 6, null])");
+ auto divisors = is_signed_integer(numeric->id())
+ ? ArrayFromJSON(numeric, R"([3, 3, -2, 7, null])")
+ : ArrayFromJSON(numeric, R"([3, 3, 2, 7, null])");
+ auto durations = ArrayFromJSON(unit, R"([1, -3, 4, 42, null])");
+ auto zeros = ArrayFromJSON(numeric, R"([0, 0, 0, 0, null])");
+ CheckScalarBinary("divide", durations, divisors, divided_durations);
+ CheckScalarBinary("divide_checked", durations, divisors, divided_durations);
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("Invalid: divide by zero"),
+ CallFunction("divide", {durations, zeros}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("Invalid: divide by zero"),
+ CallFunction("divide_checked", {durations, zeros}));
+ }
}
}