You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ko...@apache.org on 2019/07/04 05:52:14 UTC

[arrow] 35/38: ARROW-5813: [C++] Fix TensorEquals for different contiguous tensors

This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit 9e1788003915faf061d1553283e0b5ee382b9038
Author: Kenta Murata <mr...@mrkn.jp>
AuthorDate: Wed Jul 3 11:26:52 2019 +0200

    ARROW-5813: [C++] Fix TensorEquals for different contiguous tensors
    
    This change makes TensorEquals correctly compare a row-major tensor with a column-major tensor.
    
    Author: Kenta Murata <mr...@mrkn.jp>
    
    Closes #4774 from mrkn/tensor_equals_for_different_contiguous and squashes the following commits:
    
    d40c3c774 <Kenta Murata> Add unequal expectations
    e7ef17d82 <Kenta Murata> Fix TensorEquals for different contiguous tensors
---
 cpp/src/arrow/compare.cc     |  8 ++++++-
 cpp/src/arrow/tensor-test.cc | 50 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc
index 4ae5d89..e1525a4 100644
--- a/cpp/src/arrow/compare.cc
+++ b/cpp/src/arrow/compare.cc
@@ -970,7 +970,13 @@ bool TensorEquals(const Tensor& left, const Tensor& right) {
   } else if (left.size() == 0) {
     are_equal = true;
   } else {
-    if (!left.is_contiguous() || !right.is_contiguous()) {
+    const bool left_row_major_p = left.is_row_major();
+    const bool left_column_major_p = left.is_column_major();
+    const bool right_row_major_p = right.is_row_major();
+    const bool right_column_major_p = right.is_column_major();
+
+    if (!(left_row_major_p && right_row_major_p) &&
+        !(left_column_major_p && right_column_major_p)) {
       const auto& shape = left.shape();
       if (shape != right.shape()) {
         are_equal = false;
diff --git a/cpp/src/arrow/tensor-test.cc b/cpp/src/arrow/tensor-test.cc
index 36e9743..4638cd7 100644
--- a/cpp/src/arrow/tensor-test.cc
+++ b/cpp/src/arrow/tensor-test.cc
@@ -155,6 +155,56 @@ TEST(TestTensor, CountNonZeroForNonContiguousTensor) {
   AssertCountNonZero(t, 8);
 }
 
+TEST(TestTensor, Equals) {
+  std::vector<int64_t> shape = {4, 4};  // all strides below are in bytes
+
+  std::vector<int64_t> c_values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+  std::vector<int64_t> c_strides = {32, 8};  // row-major for 4x4 int64
+  Tensor tc1(int64(), Buffer::Wrap(c_values), shape, c_strides);
+  Tensor tc2(int64(), Buffer::Wrap(c_values), shape, c_strides);  // equal to tc1
+
+  std::vector<int64_t> f_values = {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16};
+  Tensor tc3(int64(), Buffer::Wrap(f_values), shape, c_strides);  // transpose of tc1
+
+  std::vector<int64_t> f_strides = {8, 32};  // column-major for 4x4 int64
+  Tensor tf1(int64(), Buffer::Wrap(f_values), shape, f_strides);  // equal to tc1
+  Tensor tf2(int64(), Buffer::Wrap(c_values), shape, f_strides);  // transpose of tc1
+
+  std::vector<int64_t> nc_values = {1, 0, 5, 0, 9,  0, 13, 0, 2, 0, 6, 0, 10, 0, 14, 0,
+                                    3, 0, 7, 0, 11, 0, 15, 0, 4, 0, 8, 0, 12, 0, 16, 0};
+  std::vector<int64_t> nc_strides = {16, 64};  // (i, j) -> nc_values[2 * i + 8 * j]
+  Tensor tnc(int64(), Buffer::Wrap(nc_values), shape, nc_strides);  // equal to tc1
+
+  ASSERT_TRUE(tc1.is_contiguous());
+  ASSERT_TRUE(tc1.is_row_major());
+
+  ASSERT_TRUE(tf1.is_contiguous());
+  ASSERT_TRUE(tf1.is_column_major());
+
+  ASSERT_FALSE(tnc.is_contiguous());
+
+  // same object: each tensor equals itself
+  EXPECT_TRUE(tc1.Equals(tc1));
+  EXPECT_TRUE(tf1.Equals(tf1));
+  EXPECT_TRUE(tnc.Equals(tnc));
+
+  // different objects with the same layout
+  EXPECT_TRUE(tc1.Equals(tc2));
+  EXPECT_FALSE(tc1.Equals(tc3));
+
+  // row-major vs. column-major: equal elements despite different memory order
+  EXPECT_TRUE(tc1.Equals(tf1));
+  EXPECT_FALSE(tc3.Equals(tf1));
+
+  // row-major and non-contiguous
+  EXPECT_TRUE(tc1.Equals(tnc));
+  EXPECT_FALSE(tc3.Equals(tnc));
+
+  // column-major and non-contiguous
+  EXPECT_TRUE(tf1.Equals(tnc));
+  EXPECT_FALSE(tf2.Equals(tnc));
+}
+
 TEST(TestNumericTensor, ElementAccessWithRowMajorStrides) {
   std::vector<int64_t> shape = {3, 4};