You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2020/08/27 11:48:58 UTC

[arrow] branch master updated: ARROW-9811: [C++] Unchecked floating point division by 0 should succeed

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 6e04489  ARROW-9811: [C++] Unchecked floating point division by 0 should succeed
6e04489 is described below

commit 6e044895fc236699c8c37e1dfb016e3cad9ef8d1
Author: liyafan82 <fa...@foxmail.com>
AuthorDate: Thu Aug 27 13:48:23 2020 +0200

    ARROW-9811: [C++] Unchecked floating point division by 0 should succeed
    
    See https://issues.apache.org/jira/browse/ARROW-9811
    
    Closes #8036 from liyafan82/fly_0824_div0
    
    Lead-authored-by: liyafan82 <fa...@foxmail.com>
    Co-authored-by: Antoine Pitrou <an...@python.org>
    Signed-off-by: Antoine Pitrou <an...@python.org>
---
 cpp/cmake_modules/san-config.cmake                 |  7 ++++--
 cpp/src/arrow/compare.cc                           | 28 ++++++++++++++++++----
 cpp/src/arrow/compare.h                            |  4 +++-
 cpp/src/arrow/compute/kernels/scalar_arithmetic.cc |  4 ----
 .../compute/kernels/scalar_arithmetic_test.cc      | 28 +++++++++++++++++-----
 cpp/src/arrow/scalar.cc                            |  4 +++-
 cpp/src/arrow/scalar.h                             |  4 +++-
 cpp/src/arrow/testing/gtest_util.cc                | 12 ++++++----
 cpp/src/arrow/testing/gtest_util.h                 | 11 +++++----
 9 files changed, 73 insertions(+), 29 deletions(-)

diff --git a/cpp/cmake_modules/san-config.cmake b/cpp/cmake_modules/san-config.cmake
index 2e28078..5eee627 100644
--- a/cpp/cmake_modules/san-config.cmake
+++ b/cpp/cmake_modules/san-config.cmake
@@ -35,14 +35,17 @@ endif()
 # - disable 'vptr' because of RTTI issues across shared libraries (?)
 # - disable 'alignment' because unaligned access is really OK on Nehalem and we do it
 #   all over the place.
-# - disable 'function' because it appears to give a false positive https://github.com/google/sanitizers/issues/911
+# - disable 'function' because it appears to give a false positive
+#   (https://github.com/google/sanitizers/issues/911)
+# - disable 'float-divide-by-zero' on clang, which considers it UB
+#   (https://bugs.llvm.org/show_bug.cgi?id=17000#c1)
 #   Note: GCC does not support the 'function' flag.
 if(${ARROW_USE_UBSAN})
   if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
      OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
     set(
       CMAKE_CXX_FLAGS
-      "${CMAKE_CXX_FLAGS} -fsanitize=undefined -fno-sanitize=alignment,vptr,function -fno-sanitize-recover=all"
+      "${CMAKE_CXX_FLAGS} -fsanitize=undefined -fno-sanitize=alignment,vptr,function,float-divide-by-zero -fno-sanitize-recover=all"
       )
   elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU"
          AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "5.1")
diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc
index e0c23a3..421ec13 100644
--- a/cpp/src/arrow/compare.cc
+++ b/cpp/src/arrow/compare.cc
@@ -862,7 +862,9 @@ class TypeEqualsVisitor {
 
 class ScalarEqualsVisitor {
  public:
-  explicit ScalarEqualsVisitor(const Scalar& right) : right_(right), result_(false) {}
+  explicit ScalarEqualsVisitor(const Scalar& right,
+                               const EqualOptions& opts = EqualOptions::Defaults())
+      : right_(right), result_(false), options_(opts) {}
 
   Status Visit(const NullScalar& left) {
     result_ = true;
@@ -876,8 +878,25 @@ class ScalarEqualsVisitor {
   }
 
   template <typename T>
+  typename std::enable_if<std::is_base_of<FloatScalar, T>::value ||
+                              std::is_base_of<DoubleScalar, T>::value,
+                          Status>::type
+  Visit(const T& left_) {
+    const auto& right = checked_cast<const T&>(right_);
+    if (options_.nans_equal()) {
+      result_ = right.value == left_.value ||
+                (std::isnan(right.value) && std::isnan(left_.value));
+    } else {
+      result_ = right.value == left_.value;
+    }
+    return Status::OK();
+  }
+
+  template <typename T>
   typename std::enable_if<
-      std::is_base_of<internal::PrimitiveScalar<typename T::TypeClass>, T>::value ||
+      (std::is_base_of<internal::PrimitiveScalar<typename T::TypeClass>, T>::value &&
+       !std::is_base_of<FloatScalar, T>::value &&
+       !std::is_base_of<DoubleScalar, T>::value) ||
           std::is_base_of<TemporalScalar<typename T::TypeClass>, T>::value,
       Status>::type
   Visit(const T& left_) {
@@ -968,6 +987,7 @@ class ScalarEqualsVisitor {
  protected:
   const Scalar& right_;
   bool result_;
+  const EqualOptions options_;
 };
 
 Status PrintDiff(const Array& left, const Array& right, std::ostream* os) {
@@ -1386,7 +1406,7 @@ bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata
   }
 }
 
-bool ScalarEquals(const Scalar& left, const Scalar& right) {
+bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options) {
   bool are_equal = false;
   if (&left == &right) {
     are_equal = true;
@@ -1395,7 +1415,7 @@ bool ScalarEquals(const Scalar& left, const Scalar& right) {
   } else if (left.is_valid != right.is_valid) {
     are_equal = false;
   } else {
-    ScalarEqualsVisitor visitor(right);
+    ScalarEqualsVisitor visitor(right, options);
     auto error = VisitScalarInline(left, &visitor);
     DCHECK_OK(error);
     are_equal = visitor.result();
diff --git a/cpp/src/arrow/compare.h b/cpp/src/arrow/compare.h
index abcf39a..f7899b7 100644
--- a/cpp/src/arrow/compare.h
+++ b/cpp/src/arrow/compare.h
@@ -111,6 +111,8 @@ bool ARROW_EXPORT TypeEquals(const DataType& left, const DataType& right,
 /// Returns true if scalars are equal
 /// \param[in] left a Scalar
 /// \param[in] right a Scalar
-bool ARROW_EXPORT ScalarEquals(const Scalar& left, const Scalar& right);
+/// \param[in] options comparison options
+bool ARROW_EXPORT ScalarEquals(const Scalar& left, const Scalar& right,
+                               const EqualOptions& options = EqualOptions::Defaults());
 
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
index e56203b..ff6c6fa 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
@@ -190,10 +190,6 @@ struct MultiplyChecked {
 struct Divide {
   template <typename T, typename Arg0, typename Arg1>
   static enable_if_floating_point<T> Call(KernelContext* ctx, Arg0 left, Arg1 right) {
-    if (ARROW_PREDICT_FALSE(right == 0)) {
-      ctx->SetStatus(Status::Invalid("divide by zero"));
-      return 0;
-    }
     return left / right;
   }
 
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
index ea24089..9b3ed2a 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
@@ -144,7 +144,8 @@ class TestBinaryArithmetic : public TestBase {
       const auto expected_scalar = *expected->GetScalar(i);
       ASSERT_OK_AND_ASSIGN(
           actual, func(*left->GetScalar(i), *right->GetScalar(i), options_, nullptr));
-      AssertScalarsEqual(*expected_scalar, *actual.scalar(), /*verbose=*/true);
+      AssertScalarsEqual(*expected_scalar, *actual.scalar(), /*verbose=*/true,
+                         equal_options_);
     }
   }
 
@@ -165,12 +166,17 @@ class TestBinaryArithmetic : public TestBase {
   void ValidateAndAssertApproxEqual(const std::shared_ptr<Array>& actual,
                                     const std::shared_ptr<Array>& expected) {
     ASSERT_OK(actual->ValidateFull());
-    AssertArraysApproxEqual(*expected, *actual, /*verbose=*/true);
+    AssertArraysApproxEqual(*expected, *actual, /*verbose=*/true, equal_options_);
   }
 
   void SetOverflowCheck(bool value = true) { options_.check_overflow = value; }
 
+  void SetNansEqual(bool value = true) {
+    this->equal_options_ = equal_options_.nans_equal(value);
+  }
+
   ArithmeticOptions options_ = ArithmeticOptions();
+  EqualOptions equal_options_ = EqualOptions::Defaults();
 };
 
 template <typename... Elements>
@@ -510,6 +516,9 @@ TYPED_TEST(TestBinaryArithmeticFloating, Div) {
                       "[null, 0.1, 0.25, null, 0.2, 0.5]");
     // Array with infinity
     this->AssertBinop(Divide, "[3.4, Inf, -Inf]", "[1, 2, 3]", "[3.4, Inf, -Inf]");
+    // Array with NaN
+    this->SetNansEqual(true);
+    this->AssertBinop(Divide, "[3.4, NaN, 2.0]", "[1, 2, 2.0]", "[3.4, NaN, 1.0]");
     // Scalar divides by scalar
     this->AssertBinop(Divide, 21.0F, 3.0F, 7.0F);
   }
@@ -557,10 +566,17 @@ TYPED_TEST(TestBinaryArithmeticIntegral, DivideByZero) {
 }
 
 TYPED_TEST(TestBinaryArithmeticFloating, DivideByZero) {
-  for (auto check_overflow : {false, true}) {
-    this->SetOverflowCheck(check_overflow);
-    this->AssertBinopRaises(Divide, "[3.0, 2.0, 6.0]", "[1.0, 1.0, 0]", "divide by zero");
-  }
+  this->SetOverflowCheck(true);
+  this->AssertBinopRaises(Divide, "[3.0, 2.0, 6.0]", "[1.0, 1.0, 0.0]", "divide by zero");
+  this->AssertBinopRaises(Divide, "[3.0, 2.0, 0.0]", "[1.0, 1.0, 0.0]", "divide by zero");
+  this->AssertBinopRaises(Divide, "[3.0, 2.0, -6.0]", "[1.0, 1.0, 0.0]",
+                          "divide by zero");
+
+  this->SetOverflowCheck(false);
+  this->SetNansEqual(true);
+  this->AssertBinop(Divide, "[3.0, 2.0, 6.0]", "[1.0, 1.0, 0.0]", "[3.0, 2.0, Inf]");
+  this->AssertBinop(Divide, "[3.0, 2.0, 0.0]", "[1.0, 1.0, 0.0]", "[3.0, 2.0, NaN]");
+  this->AssertBinop(Divide, "[3.0, 2.0, -6.0]", "[1.0, 1.0, 0.0]", "[3.0, 2.0, -Inf]");
 }
 
 TYPED_TEST(TestBinaryArithmeticSigned, DivideOverflowRaises) {
diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc
index b953177..88e594e 100644
--- a/cpp/src/arrow/scalar.cc
+++ b/cpp/src/arrow/scalar.cc
@@ -40,7 +40,9 @@ namespace arrow {
 using internal::checked_cast;
 using internal::checked_pointer_cast;
 
-bool Scalar::Equals(const Scalar& other) const { return ScalarEquals(*this, other); }
+bool Scalar::Equals(const Scalar& other, const EqualOptions& options) const {
+  return ScalarEquals(*this, other, options);
+}
 
 struct ScalarHashImpl {
   static std::hash<std::string> string_hash;
diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h
index d15c44c..4a007dd 100644
--- a/cpp/src/arrow/scalar.h
+++ b/cpp/src/arrow/scalar.h
@@ -28,6 +28,7 @@
 #include <utility>
 #include <vector>
 
+#include "arrow/compare.h"
 #include "arrow/result.h"
 #include "arrow/status.h"
 #include "arrow/type.h"
@@ -61,7 +62,8 @@ struct ARROW_EXPORT Scalar : public util::EqualityComparable<Scalar> {
 
   using util::EqualityComparable<Scalar>::operator==;
   using util::EqualityComparable<Scalar>::Equals;
-  bool Equals(const Scalar& other) const;
+  bool Equals(const Scalar& other,
+              const EqualOptions& options = EqualOptions::Defaults()) const;
 
   struct ARROW_EXPORT Hash {
     size_t operator()(const Scalar& scalar) const { return hash(scalar); }
diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc
index b2f5566..75cd204 100644
--- a/cpp/src/arrow/testing/gtest_util.cc
+++ b/cpp/src/arrow/testing/gtest_util.cc
@@ -135,22 +135,24 @@ void AssertArraysEqual(const Array& expected, const Array& actual, bool verbose)
       });
 }
 
-void AssertArraysApproxEqual(const Array& expected, const Array& actual, bool verbose) {
+void AssertArraysApproxEqual(const Array& expected, const Array& actual, bool verbose,
+                             const EqualOptions& option) {
   return AssertArraysEqualWith(
       expected, actual, verbose,
-      [](const Array& expected, const Array& actual, std::stringstream* diff) {
-        return expected.ApproxEquals(actual, EqualOptions().diff_sink(diff));
+      [&option](const Array& expected, const Array& actual, std::stringstream* diff) {
+        return expected.ApproxEquals(actual, option.diff_sink(diff));
       });
 }
 
-void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, bool verbose) {
+void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, bool verbose,
+                        const EqualOptions& options) {
   std::stringstream diff;
   // ARROW-8956, ScalarEquals returns false when both are null
   if (!expected.is_valid && !actual.is_valid) {
     // We consider both being null to be equal in this function
     return;
   }
-  if (!expected.Equals(actual)) {
+  if (!expected.Equals(actual, options)) {
     if (verbose) {
       diff << "Expected:\n" << expected.ToString();
       diff << "\nActual:\n" << actual.ToString();
diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h
index 1411e70..fd72b5a 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -164,12 +164,13 @@ std::vector<Type::type> AllTypeIds();
 // If verbose is true, then the arrays will be pretty printed
 ARROW_TESTING_EXPORT void AssertArraysEqual(const Array& expected, const Array& actual,
                                             bool verbose = false);
-ARROW_TESTING_EXPORT void AssertArraysApproxEqual(const Array& expected,
-                                                  const Array& actual,
-                                                  bool verbose = false);
+ARROW_TESTING_EXPORT void AssertArraysApproxEqual(
+    const Array& expected, const Array& actual, bool verbose = false,
+    const EqualOptions& option = EqualOptions::Defaults());
 // Returns true when values are both null
-ARROW_TESTING_EXPORT void AssertScalarsEqual(const Scalar& expected, const Scalar& actual,
-                                             bool verbose = false);
+ARROW_TESTING_EXPORT void AssertScalarsEqual(
+    const Scalar& expected, const Scalar& actual, bool verbose = false,
+    const EqualOptions& options = EqualOptions::Defaults());
 ARROW_TESTING_EXPORT void AssertBatchesEqual(const RecordBatch& expected,
                                              const RecordBatch& actual,
                                              bool check_metadata = false);