You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2019/08/16 02:29:09 UTC
[arrow] branch master updated: ARROW-6038: [C++] Faster type equality
This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 91e33dc ARROW-6038: [C++] Faster type equality
91e33dc is described below
commit 91e33dcb6aa3c05eaf9d9d9f09579bb29e3fe175
Author: Antoine Pitrou <an...@python.org>
AuthorDate: Thu Aug 15 21:29:00 2019 -0500
ARROW-6038: [C++] Faster type equality
When checking for type equality, compute and cache a fingerprint of the type so as to avoid costly nested type walking and multiple comparisons.
Before:
```
----------------------------------------------------------------
Benchmark Time CPU Iterations
----------------------------------------------------------------
TypeEqualsSimple 13 ns 13 ns 55242976 150.558M items/s
TypeEqualsComplex 430 ns 430 ns 1637275 4.43634M items/s
TypeEqualsWithMetadata 595 ns 595 ns 1199216 3.20778M items/s
SchemaEquals 1465 ns 1465 ns 479512 1.30226M items/s
SchemaEqualsWithMetadata 922 ns 922 ns 763752 2.0683M items/s
```
After:
```
----------------------------------------------------------------
Benchmark Time CPU Iterations
----------------------------------------------------------------
TypeEqualsSimple 11 ns 11 ns 65531752 178.723M items/s
TypeEqualsComplex 20 ns 20 ns 33939830 95.1497M items/s
TypeEqualsWithMetadata 31 ns 31 ns 22979555 62.4052M items/s
SchemaEquals 40 ns 40 ns 17786532 48.1683M items/s
SchemaEqualsWithMetadata 46 ns 46 ns 15173158 41.3242M items/s
```
Closes #4983 from pitrou/ARROW-6038-faster-type-equality and squashes the following commits:
2fdaf4adb <Antoine Pitrou> ARROW-6038: Faster type equality
Authored-by: Antoine Pitrou <an...@python.org>
Signed-off-by: Wes McKinney <we...@apache.org>
---
cpp/src/arrow/CMakeLists.txt | 1 +
cpp/src/arrow/compare.cc | 24 +-
cpp/src/arrow/extension_type-test.cc | 11 +
cpp/src/arrow/type-benchmark.cc | 170 +++++++++++++
cpp/src/arrow/type-test.cc | 268 +++++++++++++++----
cpp/src/arrow/type.cc | 354 +++++++++++++++++++++++++-
cpp/src/arrow/type.h | 155 ++++++++++-
cpp/src/arrow/util/key-value-metadata-test.cc | 18 ++
cpp/src/arrow/util/key_value_metadata.cc | 11 +
cpp/src/arrow/util/key_value_metadata.h | 2 +
integration/integration_test.py | 61 ++---
11 files changed, 961 insertions(+), 114 deletions(-)
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 0085238..4839fb8 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -381,6 +381,7 @@ add_arrow_test(tensor-test)
add_arrow_test(sparse_tensor-test)
add_arrow_benchmark(builder-benchmark)
+add_arrow_benchmark(type-benchmark)
add_subdirectory(array)
add_subdirectory(csv)
diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc
index 05a1d1f..222d4f9 100644
--- a/cpp/src/arrow/compare.cc
+++ b/cpp/src/arrow/compare.cc
@@ -1163,21 +1163,35 @@ bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right) {
}
bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata) {
- bool are_equal;
// The types are the same object
if (&left == &right) {
- are_equal = true;
+ return true;
} else if (left.id() != right.id()) {
- are_equal = false;
+ return false;
} else {
+ // First try to compute fingerprints
+ if (check_metadata) {
+ const auto& left_metadata_fp = left.metadata_fingerprint();
+ const auto& right_metadata_fp = right.metadata_fingerprint();
+ if (left_metadata_fp != right_metadata_fp) {
+ return false;
+ }
+ }
+
+ const auto& left_fp = left.fingerprint();
+ const auto& right_fp = right.fingerprint();
+ if (!left_fp.empty() && !right_fp.empty()) {
+ return left_fp == right_fp;
+ }
+
+ // TODO remove check_metadata here?
internal::TypeEqualsVisitor visitor(right, check_metadata);
auto error = VisitTypeInline(left, &visitor);
if (!error.ok()) {
DCHECK(false) << "Types are not comparable: " << error.ToString();
}
- are_equal = visitor.result();
+ return visitor.result();
}
- return are_equal;
}
bool ScalarEquals(const Scalar& left, const Scalar& right) {
diff --git a/cpp/src/arrow/extension_type-test.cc b/cpp/src/arrow/extension_type-test.cc
index 2f680af..06fd6a9 100644
--- a/cpp/src/arrow/extension_type-test.cc
+++ b/cpp/src/arrow/extension_type-test.cc
@@ -329,4 +329,15 @@ TEST_F(TestExtensionType, ParametricTypes) {
CompareBatch(*batch, *read_batch, false /* compare_metadata */);
}
+TEST_F(TestExtensionType, ParametricEquals) {
+ auto p1_type = std::make_shared<Parametric1Type>(6);
+ auto p2_type = std::make_shared<Parametric1Type>(6);
+ auto p3_type = std::make_shared<Parametric1Type>(3);
+
+ ASSERT_TRUE(p1_type->Equals(p2_type));
+ ASSERT_FALSE(p1_type->Equals(p3_type));
+
+ ASSERT_EQ(p1_type->fingerprint(), "");
+}
+
} // namespace arrow
diff --git a/cpp/src/arrow/type-benchmark.cc b/cpp/src/arrow/type-benchmark.cc
new file mode 100644
index 0000000..713bfc5
--- /dev/null
+++ b/cpp/src/arrow/type-benchmark.cc
@@ -0,0 +1,170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "benchmark/benchmark.h"
+
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+
+static void TypeEqualsSimple(benchmark::State& state) { // NOLINT non-const reference
+ auto a = uint8();
+ auto b = uint8();
+ auto c = float64();
+
+ int64_t total = 0;
+ for (auto _ : state) {
+ total += a->Equals(*b);
+ total += a->Equals(*c);
+ }
+ benchmark::DoNotOptimize(total);
+ state.SetItemsProcessed(state.iterations() * 2);
+}
+
+static void TypeEqualsComplex(benchmark::State& state) { // NOLINT non-const reference
+ auto fa1 = field("as", list(float16()));
+ auto fa2 = field("as", list(float16()));
+ auto fb1 = field("bs", utf8());
+ auto fb2 = field("bs", utf8());
+ auto fc1 = field("cs", list(fixed_size_binary(10)));
+ auto fc2 = field("cs", list(fixed_size_binary(10)));
+ auto fc3 = field("cs", list(fixed_size_binary(11)));
+
+ auto a = struct_({fa1, fb1, fc1});
+ auto b = struct_({fa2, fb2, fc2});
+ auto c = struct_({fa2, fb2, fc3});
+
+ int64_t total = 0;
+ for (auto _ : state) {
+ total += a->Equals(*b);
+ total += a->Equals(*c);
+ }
+ benchmark::DoNotOptimize(total);
+ state.SetItemsProcessed(state.iterations() * 2);
+}
+
+static void TypeEqualsWithMetadata(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto md1 = key_value_metadata({"k1", "k2"}, {"some value1", "some value2"});
+ auto md2 = key_value_metadata({"k1", "k2"}, {"some value1", "some value2"});
+ auto md3 = key_value_metadata({"k2", "k1"}, {"some value2", "some value1"});
+
+ auto fa1 = field("as", list(float16()));
+ auto fa2 = field("as", list(float16()));
+ auto fb1 = field("bs", utf8(), /*nullable=*/true, md1);
+ auto fb2 = field("bs", utf8(), /*nullable=*/true, md2);
+ auto fb3 = field("bs", utf8(), /*nullable=*/true, md3);
+
+ auto a = struct_({fa1, fb1});
+ auto b = struct_({fa2, fb2});
+ auto c = struct_({fa2, fb3});
+
+ int64_t total = 0;
+ for (auto _ : state) {
+ total += a->Equals(*b);
+ total += a->Equals(*c);
+ }
+ benchmark::DoNotOptimize(total);
+ state.SetItemsProcessed(state.iterations() * 2);
+}
+
+static std::vector<std::shared_ptr<Schema>> SampleSchemas() {
+ auto fa1 = field("as", list(float16()));
+ auto fa2 = field("as", list(float16()));
+ auto fb1 = field("bs", utf8());
+ auto fb2 = field("bs", utf8());
+ auto fc1 = field("cs", list(fixed_size_binary(10)));
+ auto fc2 = field("cs", list(fixed_size_binary(10)));
+ auto fd1 = field("ds", decimal(19, 5));
+ auto fd2 = field("ds", decimal(19, 5));
+ auto fe1 = field("es", map(utf8(), int32()));
+ auto fe2 = field("es", map(utf8(), int32()));
+ auto ff1 = field("fs", dictionary(int8(), binary()));
+ auto ff2 = field("fs", dictionary(int8(), binary()));
+ auto fg1 = field(
+ "gs", struct_({field("A", int8()), field("B", int16()), field("C", float32())}));
+ auto fg2 = field(
+ "gs", struct_({field("A", int8()), field("B", int16()), field("C", float32())}));
+ auto fh1 = field("hs", large_binary());
+ auto fh2 = field("hs", large_binary());
+
+ auto fz1 = field("zs", duration(TimeUnit::MICRO));
+ auto fz2 = field("zs", duration(TimeUnit::MICRO));
+ auto fz3 = field("zs", duration(TimeUnit::NANO));
+
+ auto schema1 = ::arrow::schema({fa1, fb1, fc1, fd1, fe1, ff1, fg1, fh1, fz1});
+ auto schema2 = ::arrow::schema({fa2, fb2, fc2, fd2, fe2, ff2, fg2, fh2, fz2});
+ auto schema3 = ::arrow::schema({fa2, fb2, fc2, fd2, fe2, ff2, fg2, fh2, fz3});
+
+ return {schema1, schema2, schema3};
+}
+
+static void SchemaEquals(benchmark::State& state) { // NOLINT non-const reference
+ auto schemas = SampleSchemas();
+
+ auto schema1 = schemas[0];
+ auto schema2 = schemas[1];
+ auto schema3 = schemas[2];
+
+ int64_t total = 0;
+ for (auto _ : state) {
+ total += schema1->Equals(*schema2, /*check_metadata =*/false);
+ total += schema1->Equals(*schema3, /*check_metadata =*/false);
+ }
+ benchmark::DoNotOptimize(total);
+ state.SetItemsProcessed(state.iterations() * 2);
+}
+
+static void SchemaEqualsWithMetadata(
+ benchmark::State& state) { // NOLINT non-const reference
+ auto schemas = SampleSchemas();
+
+ auto schema1 = schemas[0];
+ auto schema2 = schemas[1];
+ auto schema3 = schemas[2];
+
+ auto md1 = key_value_metadata({"k1", "k2"}, {"some value1", "some value2"});
+ auto md2 = key_value_metadata({"k1", "k2"}, {"some value1", "some value2"});
+ auto md3 = key_value_metadata({"k2", "k1"}, {"some value2", "some value1"});
+
+ schema1 = schema1->AddMetadata(md1);
+ schema2 = schema1->AddMetadata(md2);  // NOTE(review): derived from schema1, not schemas[1] -- confirm intent
+ schema3 = schema1->AddMetadata(md3);  // NOTE(review): derived from schema1, so schemas[2] is discarded -- confirm intent
+
+ int64_t total = 0;
+ for (auto _ : state) {
+ total += schema1->Equals(*schema2);
+ total += schema1->Equals(*schema3);
+ }
+ benchmark::DoNotOptimize(total);
+ state.SetItemsProcessed(state.iterations() * 2);
+}
+
+BENCHMARK(TypeEqualsSimple);
+BENCHMARK(TypeEqualsComplex);
+BENCHMARK(TypeEqualsWithMetadata);
+BENCHMARK(SchemaEquals);
+BENCHMARK(SchemaEqualsWithMetadata);
+
+} // namespace arrow
diff --git a/cpp/src/arrow/type-test.cc b/cpp/src/arrow/type-test.cc
index eb49227..b3a2545 100644
--- a/cpp/src/arrow/type-test.cc
+++ b/cpp/src/arrow/type-test.cc
@@ -36,6 +36,67 @@ namespace arrow {
using internal::checked_cast;
+template <typename T>
+void AssertFingerprintablesEqual(const T& left, const T& right, bool check_metadata,
+ const char* types_plural) {
+ ASSERT_TRUE(left.Equals(right, check_metadata))
+ << types_plural << " '" << left.ToString() << "' and '" << right.ToString()
+ << "' should have compared equal";
+ auto lfp = left.fingerprint();
+ auto rfp = right.fingerprint();
+ // All types tested in this file should implement fingerprinting
+ ASSERT_NE(lfp, "") << "fingerprint for '" << left.ToString() << "' should not be empty";
+ ASSERT_NE(rfp, "") << "fingerprint for '" << right.ToString()
+ << "' should not be empty";
+ if (check_metadata) {
+ lfp += left.metadata_fingerprint();
+ rfp += right.metadata_fingerprint();
+ }
+ ASSERT_EQ(lfp, rfp) << "Fingerprints for " << types_plural << " '" << left.ToString()
+ << "' and '" << right.ToString() << "' should have compared equal";
+}
+
+template <typename T>
+void AssertFingerprintablesNotEqual(const T& left, const T& right, bool check_metadata,
+ const char* types_plural) {
+ ASSERT_FALSE(left.Equals(right, check_metadata))
+ << types_plural << " '" << left.ToString() << "' and '" << right.ToString()
+ << "' should have compared unequal";
+ auto lfp = left.fingerprint();
+ auto rfp = right.fingerprint();
+ // All types tested in this file should implement fingerprinting
+ ASSERT_NE(lfp, "") << "fingerprint for '" << left.ToString() << "' should not be empty";
+ ASSERT_NE(rfp, "") << "fingerprint for '" << right.ToString()
+ << "' should not be empty";
+ if (check_metadata) {
+ lfp += left.metadata_fingerprint();
+ rfp += right.metadata_fingerprint();
+ }
+ ASSERT_NE(lfp, rfp) << "Fingerprints for " << types_plural << " '" << left.ToString()
+ << "' and '" << right.ToString()
+ << "' should have compared unequal";
+}
+
+void AssertTypesEqual(const DataType& left, const DataType& right,
+ bool check_metadata = true) {
+ AssertFingerprintablesEqual(left, right, check_metadata, "types");
+}
+
+void AssertTypesNotEqual(const DataType& left, const DataType& right,
+ bool check_metadata = true) {
+ AssertFingerprintablesNotEqual(left, right, check_metadata, "types");
+}
+
+void AssertFieldsEqual(const Field& left, const Field& right,
+ bool check_metadata = true) {
+ AssertFingerprintablesEqual(left, right, check_metadata, "fields");
+}
+
+void AssertFieldsNotEqual(const Field& left, const Field& right,
+ bool check_metadata = true) {
+ AssertFingerprintablesNotEqual(left, right, check_metadata, "fields");
+}
+
TEST(TestField, Basics) {
Field f0("f0", int32());
Field f0_nn("f0", int32(), false);
@@ -48,17 +109,28 @@ TEST(TestField, Basics) {
}
TEST(TestField, Equals) {
- auto meta = key_value_metadata({{"a", "1"}, {"b", "2"}});
+ auto meta1 = key_value_metadata({{"a", "1"}, {"b", "2"}});
+ // Different from meta1
+ auto meta2 = key_value_metadata({{"a", "1"}, {"b", "3"}});
+ // Equal to meta1, though in different order
+ auto meta3 = key_value_metadata({{"b", "2"}, {"a", "1"}});
Field f0("f0", int32());
Field f0_nn("f0", int32(), false);
Field f0_other("f0", int32());
- Field f0_with_meta("f0", int32(), true, meta);
-
- ASSERT_TRUE(f0.Equals(f0_other));
- ASSERT_FALSE(f0.Equals(f0_nn));
- ASSERT_FALSE(f0.Equals(f0_with_meta));
- ASSERT_TRUE(f0.Equals(f0_with_meta, false));
+ Field f0_with_meta1("f0", int32(), true, meta1);
+ Field f0_with_meta2("f0", int32(), true, meta2);
+ Field f0_with_meta3("f0", int32(), true, meta3);
+
+ AssertFieldsEqual(f0, f0_other);
+ AssertFieldsNotEqual(f0, f0_nn);
+ AssertFieldsNotEqual(f0, f0_with_meta1);
+ AssertFieldsNotEqual(f0_with_meta1, f0_with_meta2);
+ AssertFieldsEqual(f0_with_meta1, f0_with_meta3);
+
+ AssertFieldsEqual(f0, f0_with_meta1, false);
+ AssertFieldsEqual(f0, f0_with_meta2, false);
+ AssertFieldsEqual(f0_with_meta1, f0_with_meta2, false);
}
TEST(TestField, TestMetadataConstruction) {
@@ -68,7 +140,7 @@ TEST(TestField, TestMetadataConstruction) {
auto f0 = field("f0", int32(), true, metadata);
auto f1 = field("f0", int32(), true, metadata2);
ASSERT_TRUE(metadata->Equals(*f0->metadata()));
- ASSERT_TRUE(f0->Equals(*f1));
+ AssertFieldsEqual(*f0, *f1);
}
TEST(TestField, TestAddMetadata) {
@@ -78,8 +150,10 @@ TEST(TestField, TestAddMetadata) {
auto f1 = field("f0", int32(), true, metadata);
std::shared_ptr<Field> f2 = f0->AddMetadata(metadata);
- ASSERT_FALSE(f2->Equals(*f0));
- ASSERT_TRUE(f2->Equals(*f1));
+ AssertFieldsEqual(*f1, *f2);
+ AssertFieldsNotEqual(*f0, *f2);
+ ASSERT_TRUE(f1->Equals(f2, /*check_metadata =*/false));
+ ASSERT_TRUE(f0->Equals(f2, /*check_metadata =*/false));
// Not copied
ASSERT_TRUE(metadata.get() == f1->metadata().get());
@@ -94,6 +168,21 @@ TEST(TestField, TestRemoveMetadata) {
ASSERT_TRUE(f2->metadata() == nullptr);
}
+TEST(TestField, TestEmptyMetadata) {
+ // Empty metadata should be equivalent to no metadata at all
+ auto metadata1 = key_value_metadata({});
+ auto metadata2 = key_value_metadata({"foo"}, {"foo value"});
+
+ auto f0 = field("f0", int32());
+ auto f1 = field("f0", int32(), true, metadata1);
+ auto f2 = field("f0", int32(), true, metadata2);
+
+ AssertFieldsEqual(*f0, *f1);
+ AssertFieldsNotEqual(*f0, *f2);
+ ASSERT_TRUE(f0->Equals(f1, /*check_metadata =*/false));
+ ASSERT_TRUE(f0->Equals(f2, /*check_metadata =*/false));
+}
+
TEST(TestField, TestFlatten) {
auto metadata = std::shared_ptr<KeyValueMetadata>(
new KeyValueMetadata({"foo", "bar"}, {"bizz", "buzz"}));
@@ -128,9 +217,9 @@ TEST(TestField, TestReplacement) {
auto fzero = f0->WithType(utf8());
auto f1 = f0->WithName("f1");
- ASSERT_FALSE(f0->Equals(fzero));
- ASSERT_FALSE(fzero->Equals(f1));
- ASSERT_FALSE(f1->Equals(f0));
+ AssertFieldsNotEqual(*f0, *fzero);
+ AssertFieldsNotEqual(*fzero, *f1);
+ AssertFieldsNotEqual(*f1, *f0);
ASSERT_EQ(fzero->name(), "f0");
ASSERT_TRUE(fzero->type()->Equals(utf8()));
@@ -166,6 +255,8 @@ TEST_F(TestSchema, Basics) {
auto schema3 = std::make_shared<Schema>(fields3);
ASSERT_TRUE(schema->Equals(*schema2));
ASSERT_FALSE(schema->Equals(*schema3));
+ ASSERT_EQ(schema->fingerprint(), schema2->fingerprint());
+ ASSERT_NE(schema->fingerprint(), schema3->fingerprint());
}
TEST_F(TestSchema, ToString) {
@@ -275,7 +366,14 @@ TEST_F(TestSchema, TestMetadataConstruction) {
ASSERT_TRUE(schema0->Equals(*schema2));
ASSERT_FALSE(schema0->Equals(*schema1));
ASSERT_FALSE(schema2->Equals(*schema1));
- ASSERT_FALSE(schema2->Equals(*schema3));
+ ASSERT_FALSE(schema2->Equals(*schema3)); // Field has different metadata
+
+ ASSERT_EQ(schema0->fingerprint(), schema1->fingerprint());
+ ASSERT_EQ(schema0->fingerprint(), schema2->fingerprint());
+ ASSERT_EQ(schema0->fingerprint(), schema3->fingerprint());
+ ASSERT_NE(schema0->metadata_fingerprint(), schema1->metadata_fingerprint());
+ ASSERT_EQ(schema0->metadata_fingerprint(), schema2->metadata_fingerprint());
+ ASSERT_NE(schema0->metadata_fingerprint(), schema3->metadata_fingerprint());
// don't check metadata
ASSERT_TRUE(schema0->Equals(*schema1, false));
@@ -283,6 +381,25 @@ TEST_F(TestSchema, TestMetadataConstruction) {
ASSERT_TRUE(schema2->Equals(*schema3, false));
}
+TEST_F(TestSchema, TestEmptyMetadata) {
+ // Empty metadata should be equivalent to no metadata at all
+ auto f1 = field("f1", int32());
+ auto metadata1 = key_value_metadata({});
+ auto metadata2 = key_value_metadata({"foo"}, {"foo value"});
+
+ auto schema1 = ::arrow::schema({f1});
+ auto schema2 = ::arrow::schema({f1}, metadata1);
+ auto schema3 = ::arrow::schema({f1}, metadata2);
+
+ ASSERT_TRUE(schema1->Equals(*schema2));
+ ASSERT_FALSE(schema1->Equals(*schema3));
+
+ ASSERT_EQ(schema1->fingerprint(), schema2->fingerprint());
+ ASSERT_EQ(schema1->fingerprint(), schema3->fingerprint());
+ ASSERT_EQ(schema1->metadata_fingerprint(), schema2->metadata_fingerprint());
+ ASSERT_NE(schema1->metadata_fingerprint(), schema3->metadata_fingerprint());
+}
+
TEST_F(TestSchema, TestAddMetadata) {
auto f0 = field("f0", int32());
auto f1 = field("f1", uint8(), false);
@@ -342,8 +459,8 @@ TEST(TestBinaryType, ToString) {
BinaryType t1;
BinaryType e1;
StringType t2;
- EXPECT_TRUE(t1.Equals(e1));
- EXPECT_FALSE(t1.Equals(t2));
+ AssertTypesEqual(t1, e1);
+ AssertTypesNotEqual(t1, t2);
ASSERT_EQ(t1.id(), Type::BINARY);
ASSERT_EQ(t1.ToString(), std::string("binary"));
}
@@ -359,9 +476,9 @@ TEST(TestLargeBinaryTypes, ToString) {
LargeBinaryType t1;
LargeBinaryType e1;
LargeStringType t2;
- EXPECT_TRUE(t1.Equals(e1));
- EXPECT_FALSE(t1.Equals(t2));
- EXPECT_FALSE(t1.Equals(bt1));
+ AssertTypesEqual(t1, e1);
+ AssertTypesNotEqual(t1, t2);
+ AssertTypesNotEqual(t1, bt1);
ASSERT_EQ(t1.id(), Type::LARGE_BINARY);
ASSERT_EQ(t1.ToString(), std::string("large_binary"));
ASSERT_EQ(t2.id(), Type::LARGE_STRING);
@@ -379,9 +496,8 @@ TEST(TestFixedSizeBinaryType, Equals) {
auto t2 = fixed_size_binary(10);
auto t3 = fixed_size_binary(3);
- ASSERT_TRUE(t1->Equals(t1));
- ASSERT_TRUE(t1->Equals(t2));
- ASSERT_FALSE(t1->Equals(t3));
+ AssertTypesEqual(*t1, *t2);
+ AssertTypesNotEqual(*t1, *t3);
}
TEST(TestListType, Basics) {
@@ -488,11 +604,11 @@ TEST(TestTimeType, Equals) {
ASSERT_EQ(32, t0.bit_width());
ASSERT_EQ(64, t3.bit_width());
- ASSERT_TRUE(t0.Equals(t2));
- ASSERT_TRUE(t1.Equals(t1));
- ASSERT_FALSE(t1.Equals(t3));
- ASSERT_FALSE(t3.Equals(t4));
- ASSERT_TRUE(t3.Equals(t5));
+ AssertTypesEqual(t0, t2);
+ AssertTypesEqual(t1, t1);
+ AssertTypesNotEqual(t1, t3);
+ AssertTypesNotEqual(t3, t4);
+ AssertTypesEqual(t3, t5);
}
TEST(TestTimeType, ToString) {
@@ -512,8 +628,8 @@ TEST(TestMonthIntervalType, Equals) {
MonthIntervalType t2;
DayTimeIntervalType t3;
- ASSERT_TRUE(t1.Equals(t2));
- ASSERT_FALSE(t1.Equals(t3));
+ AssertTypesEqual(t1, t2);
+ AssertTypesNotEqual(t1, t3);
}
TEST(TestMonthIntervalType, ToString) {
@@ -527,8 +643,8 @@ TEST(TestDayTimeIntervalType, Equals) {
DayTimeIntervalType t2;
MonthIntervalType t3;
- ASSERT_TRUE(t1.Equals(t2));
- ASSERT_FALSE(t1.Equals(t3));
+ AssertTypesEqual(t1, t2);
+ AssertTypesNotEqual(t1, t3);
}
TEST(TestDayTimeIntervalType, ToString) {
@@ -543,9 +659,9 @@ TEST(TestDurationType, Equals) {
DurationType t3(TimeUnit::NANO);
DurationType t4(TimeUnit::NANO);
- ASSERT_TRUE(t1.Equals(t2));
- ASSERT_FALSE(t1.Equals(t3));
- ASSERT_TRUE(t3.Equals(t4));
+ AssertTypesEqual(t1, t2);
+ AssertTypesNotEqual(t1, t3);
+ AssertTypesEqual(t3, t4);
}
TEST(TestDurationType, ToString) {
@@ -566,9 +682,15 @@ TEST(TestTimestampType, Equals) {
TimestampType t3(TimeUnit::NANO);
TimestampType t4(TimeUnit::NANO);
- ASSERT_TRUE(t1.Equals(t2));
- ASSERT_FALSE(t1.Equals(t3));
- ASSERT_TRUE(t3.Equals(t4));
+ DurationType dt1;
+ DurationType dt2(TimeUnit::NANO);
+
+ AssertTypesEqual(t1, t2);
+ AssertTypesNotEqual(t1, t3);
+ AssertTypesEqual(t3, t4);
+
+ AssertTypesNotEqual(t1, dt1);
+ AssertTypesNotEqual(t3, dt2);
}
TEST(TestTimestampType, ToString) {
@@ -591,11 +713,39 @@ TEST(TestListType, Equals) {
auto t5 = large_list(binary());
auto t6 = large_list(float64());
- ASSERT_TRUE(t1->Equals(t2));
- ASSERT_FALSE(t1->Equals(t3));
- ASSERT_FALSE(t3->Equals(t4));
- ASSERT_TRUE(t4->Equals(t5));
- ASSERT_FALSE(t5->Equals(t6));
+ AssertTypesEqual(*t1, *t2);
+ AssertTypesNotEqual(*t1, *t3);
+ AssertTypesNotEqual(*t3, *t4);
+ AssertTypesEqual(*t4, *t5);
+ AssertTypesNotEqual(*t5, *t6);
+}
+
+TEST(TestListType, Metadata) {
+ auto md1 = key_value_metadata({"foo", "bar"}, {"foo value", "bar value"});
+ auto md2 = key_value_metadata({"foo", "bar"}, {"foo value", "bar value"});
+ auto md3 = key_value_metadata({"foo"}, {"foo value"});
+
+ auto f1 = field("item", utf8(), /*nullable =*/true, md1);
+ auto f2 = field("item", utf8(), /*nullable =*/true, md2);
+ auto f3 = field("item", utf8(), /*nullable =*/true, md3);
+ auto f4 = field("item", utf8());
+ auto f5 = field("item", utf8(), /*nullable =*/false, md1);
+
+ auto t1 = list(f1);
+ auto t2 = list(f2);
+ auto t3 = list(f3);
+ auto t4 = list(f4);
+ auto t5 = list(f5);
+
+ AssertTypesEqual(*t1, *t2);
+ AssertTypesNotEqual(*t1, *t3);
+ AssertTypesNotEqual(*t1, *t4);
+ AssertTypesNotEqual(*t1, *t5);
+
+ AssertTypesEqual(*t1, *t2, /*check_metadata =*/false);
+ AssertTypesEqual(*t1, *t3, /*check_metadata =*/false);
+ AssertTypesEqual(*t1, *t4, /*check_metadata =*/false);
+ AssertTypesNotEqual(*t1, *t5, /*check_metadata =*/false);
}
TEST(TestNestedType, Equals) {
@@ -621,18 +771,18 @@ TEST(TestNestedType, Equals) {
auto s0_bad = create_struct("f1", "s0");
auto s1 = create_struct("f1", "s1");
- ASSERT_TRUE(s0->Equals(s0_other));
- ASSERT_FALSE(s0->Equals(s1));
- ASSERT_FALSE(s0->Equals(s0_bad));
+ AssertFieldsEqual(*s0, *s0_other);
+ AssertFieldsNotEqual(*s0, *s1);
+ AssertFieldsNotEqual(*s0, *s0_bad);
auto u0 = create_union("f0", "u0");
auto u0_other = create_union("f0", "u0");
auto u0_bad = create_union("f1", "u0");
auto u1 = create_union("f1", "u1");
- ASSERT_TRUE(u0->Equals(u0_other));
- ASSERT_FALSE(u0->Equals(u1));
- ASSERT_FALSE(u0->Equals(u0_bad));
+ AssertFieldsEqual(*u0, *u0_other);
+ AssertFieldsNotEqual(*u0, *u1);
+ AssertFieldsNotEqual(*u0, *u0_bad);
}
TEST(TestStructType, Basics) {
@@ -752,11 +902,9 @@ TEST(TestDictionaryType, Equals) {
auto t3 = dictionary(int16(), int32());
auto t4 = dictionary(int8(), int16());
- ASSERT_TRUE(t1->Equals(t2));
- // Different index type
- ASSERT_FALSE(t1->Equals(t3));
- // Different value type
- ASSERT_FALSE(t1->Equals(t4));
+ AssertTypesEqual(*t1, *t2);
+ AssertTypesNotEqual(*t1, *t3);
+ AssertTypesNotEqual(*t1, *t4);
}
TEST(TestDictionaryType, UnifyNumeric) {
@@ -939,4 +1087,18 @@ TEST(TypesTest, TestDecimal128Large) {
ASSERT_EQ(t1.bit_width(), 128);
}
+TEST(TypesTest, TestDecimalEquals) {
+ Decimal128Type t1(8, 4);
+ Decimal128Type t2(8, 4);
+ Decimal128Type t3(8, 5);
+ Decimal128Type t4(27, 5);
+
+ FixedSizeBinaryType t9(16);
+
+ AssertTypesEqual(t1, t2);
+ AssertTypesNotEqual(t1, t3);
+ AssertTypesNotEqual(t1, t4);
+ AssertTypesNotEqual(t1, t9);
+}
+
} // namespace arrow
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index d8ed7bb..7a56ba7 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -39,6 +39,8 @@ namespace arrow {
using internal::checked_cast;
+Field::~Field() {}
+
bool Field::HasMetadata() const {
return (metadata_ != nullptr) && (metadata_->size() > 0);
}
@@ -491,13 +493,14 @@ class Schema::Impl {
Schema::Schema(const std::vector<std::shared_ptr<Field>>& fields,
const std::shared_ptr<const KeyValueMetadata>& metadata)
- : impl_(new Impl(fields, metadata)) {}
+ : detail::Fingerprintable(), impl_(new Impl(fields, metadata)) {}
Schema::Schema(std::vector<std::shared_ptr<Field>>&& fields,
const std::shared_ptr<const KeyValueMetadata>& metadata)
- : impl_(new Impl(std::move(fields), metadata)) {}
+ : detail::Fingerprintable(), impl_(new Impl(std::move(fields), metadata)) {}
-Schema::Schema(const Schema& schema) : impl_(new Impl(*schema.impl_)) {}
+Schema::Schema(const Schema& schema)
+ : detail::Fingerprintable(), impl_(new Impl(*schema.impl_)) {}
Schema::~Schema() {}
@@ -522,22 +525,30 @@ bool Schema::Equals(const Schema& other, bool check_metadata) const {
if (num_fields() != other.num_fields()) {
return false;
}
+
+ if (check_metadata) {
+ const auto& metadata_fp = metadata_fingerprint();
+ const auto& other_metadata_fp = other.metadata_fingerprint();
+ if (metadata_fp != other_metadata_fp) {
+ return false;
+ }
+ }
+
+ // Fast path using fingerprints, if possible
+ const auto& fp = fingerprint();
+ const auto& other_fp = other.fingerprint();
+ if (!fp.empty() && !other_fp.empty()) {
+ return fp == other_fp;
+ }
+
+ // Fall back on field-by-field comparison
for (int i = 0; i < num_fields(); ++i) {
if (!field(i)->Equals(*other.field(i).get(), check_metadata)) {
return false;
}
}
- // check metadata equality
- if (!check_metadata) {
- return true;
- } else if (this->HasMetadata() && other.HasMetadata()) {
- return impl_->metadata_->Equals(*other.impl_->metadata_);
- } else if (!this->HasMetadata() && !other.HasMetadata()) {
- return true;
- } else {
- return false;
- }
+ return true;
}
std::shared_ptr<Field> Schema::GetFieldByName(const std::string& name) const {
@@ -629,7 +640,7 @@ std::string Schema::ToString() const {
++i;
}
- if (impl_->metadata_) {
+ if (HasMetadata()) {
buffer << impl_->metadata_->ToString();
}
@@ -655,6 +666,321 @@ std::shared_ptr<Schema> schema(std::vector<std::shared_ptr<Field>>&& fields,
}
// ----------------------------------------------------------------------
+// Fingerprint computations
+
+namespace detail {
+
+Fingerprintable::~Fingerprintable() {
+ delete fingerprint_.load();
+ delete metadata_fingerprint_.load();
+}
+
+template <typename ComputeFingerprint>
+static const std::string& LoadFingerprint(std::atomic<std::string*>* fingerprint,
+ ComputeFingerprint&& compute_fingerprint) {
+ auto new_p = new std::string(std::forward<ComputeFingerprint>(compute_fingerprint)());
+ // Since fingerprint() and metadata_fingerprint() return a *reference* to the
+ // allocated string, the first allocation ever should never be replaced by another
+ // one. Hence the compare_exchange_strong() against nullptr.
+ std::string* expected = nullptr;
+ if (fingerprint->compare_exchange_strong(expected, new_p)) {
+ return *new_p;
+ } else {
+ delete new_p;
+ DCHECK_NE(expected, nullptr);
+ return *expected;
+ }
+}
+
+const std::string& Fingerprintable::LoadFingerprintSlow() const {
+ return LoadFingerprint(&fingerprint_, [this]() { return ComputeFingerprint(); });
+}
+
+const std::string& Fingerprintable::LoadMetadataFingerprintSlow() const {
+ return LoadFingerprint(&metadata_fingerprint_,
+ [this]() { return ComputeMetadataFingerprint(); });
+}
+
+} // namespace detail
+
+static inline std::string TypeIdFingerprint(const DataType& type) {
+ auto c = static_cast<int>(type.id()) + 'A';
+ DCHECK_GE(c, 0);
+ DCHECK_LT(c, 128); // Unlikely to happen any time soon
+ // Prefix with an unusual character in order to disambiguate
+ std::string s{'@', static_cast<char>(c)};
+ return s;
+}
+
+static char TimeUnitFingerprint(TimeUnit::type unit) {
+ switch (unit) {
+ case TimeUnit::SECOND:
+ return 's';
+ case TimeUnit::MILLI:
+ return 'm';
+ case TimeUnit::MICRO:
+ return 'u';
+ case TimeUnit::NANO:
+ return 'n';
+ default:
+ DCHECK(false) << "Unexpected TimeUnit";
+ return '\0';
+ }
+}
+
+static char IntervalTypeFingerprint(IntervalType::type unit) {
+ switch (unit) {
+ case IntervalType::DAY_TIME:
+ return 'd';
+ case IntervalType::MONTHS:
+ return 'M';
+ default:
+ DCHECK(false) << "Unexpected IntervalType::type";
+ return '\0';
+ }
+}
+
+static void AppendMetadataFingerprint(const KeyValueMetadata& metadata,
+ std::stringstream* ss) {
+ // Compute metadata fingerprint. KeyValueMetadata is not immutable,
+ // so we don't cache the result on the metadata instance.
+ const auto pairs = metadata.sorted_pairs();
+ if (!pairs.empty()) {
+ *ss << "!{";
+ for (const auto& p : pairs) {
+ const auto& k = p.first;
+ const auto& v = p.second;
+ // Since metadata strings can contain arbitrary characters, prefix with
+ // string length to disambiguate.
+ *ss << k.length() << ':' << k << ':';
+ *ss << v.length() << ':' << v << ';';
+ }
+ *ss << '}';
+ }
+}
+
+static void AppendEmptyMetadataFingerprint(std::stringstream* ss) {}
+
+std::string Field::ComputeFingerprint() const {
+ const auto& type_fingerprint = type_->fingerprint();
+ if (type_fingerprint.empty()) {
+ // Underlying DataType doesn't support fingerprinting.
+ return "";
+ }
+ std::stringstream ss;
+ ss << 'F';
+ if (nullable_) {
+ ss << 'n';
+ } else {
+ ss << 'N';
+ }
+ ss << name_;
+ ss << '{' << type_fingerprint << '}';
+ return ss.str();
+}
+
+std::string Field::ComputeMetadataFingerprint() const {
+ std::stringstream ss;
+ if (metadata_) {
+ AppendMetadataFingerprint(*metadata_, &ss);
+ } else {
+ AppendEmptyMetadataFingerprint(&ss);
+ }
+ return ss.str();
+}
+
+std::string Schema::ComputeFingerprint() const {
+ std::stringstream ss;
+ ss << "S{";
+ for (const auto& field : fields()) {
+ const auto& field_fingerprint = field->fingerprint();
+ if (field_fingerprint.empty()) {
+ return "";
+ }
+ ss << field_fingerprint << ";";
+ }
+ ss << "}";
+ return ss.str();
+}
+
+std::string Schema::ComputeMetadataFingerprint() const {
+ std::stringstream ss;
+ if (HasMetadata()) {
+ AppendMetadataFingerprint(*metadata(), &ss);
+ } else {
+ AppendEmptyMetadataFingerprint(&ss);
+ }
+ ss << "S{";
+ for (const auto& field : fields()) {
+ const auto& field_fingerprint = field->metadata_fingerprint();
+ ss << field_fingerprint << ";";
+ }
+ ss << "}";
+ return ss.str();
+}
+
+std::string DataType::ComputeFingerprint() const {
+ // Default implementation returns an empty string, signalling that
+ // fingerprinting is not implemented for this type.
+ return "";
+}
+
+ // Compute the metadata fingerprint of a data type from its child fields.
+ //
+ // Returns an empty string when no child carries any metadata; otherwise
+ // returns the children's metadata fingerprints in order, each terminated
+ // by ';'.
+ std::string DataType::ComputeMetadataFingerprint() const {
+   // Whatever the data type, metadata can only be found on child fields
+   std::string s;
+   bool has_metadata = false;
+   for (const auto& child : children_) {
+     const auto& child_fingerprint = child->metadata_fingerprint();
+     if (!child_fingerprint.empty()) {
+       has_metadata = true;
+     }
+     // A child without metadata contributes an empty fingerprint. The
+     // terminator keeps positions distinguishable, so that the same
+     // metadata on a different child yields a different result.
+     s += child_fingerprint;
+     s += ';';
+   }
+   if (!has_metadata) {
+     // Keep the metadata-free case identical to an empty fingerprint.
+     s.clear();
+   }
+   return s;
+ }
+
+ // Types without parameters are fully identified by their type id, so their
+ // fingerprint is just the type id fingerprint. The helper macro below
+ // generates that ComputeFingerprint() definition for <KLASS>Type and is
+ // undefined again once every such type has been covered.
+ #define TYPE_ID_ONLY_FINGERPRINT(KLASS)                  \
+   std::string KLASS##Type::ComputeFingerprint() const { \
+     return TypeIdFingerprint(*this);                    \
+   }
+
+ // Null and boolean
+ TYPE_ID_ONLY_FINGERPRINT(Null)
+ TYPE_ID_ONLY_FINGERPRINT(Boolean)
+ // Signed integers
+ TYPE_ID_ONLY_FINGERPRINT(Int8)
+ TYPE_ID_ONLY_FINGERPRINT(Int16)
+ TYPE_ID_ONLY_FINGERPRINT(Int32)
+ TYPE_ID_ONLY_FINGERPRINT(Int64)
+ // Unsigned integers
+ TYPE_ID_ONLY_FINGERPRINT(UInt8)
+ TYPE_ID_ONLY_FINGERPRINT(UInt16)
+ TYPE_ID_ONLY_FINGERPRINT(UInt32)
+ TYPE_ID_ONLY_FINGERPRINT(UInt64)
+ // Floating point
+ TYPE_ID_ONLY_FINGERPRINT(HalfFloat)
+ TYPE_ID_ONLY_FINGERPRINT(Float)
+ TYPE_ID_ONLY_FINGERPRINT(Double)
+ // Variable-size binary and string
+ TYPE_ID_ONLY_FINGERPRINT(Binary)
+ TYPE_ID_ONLY_FINGERPRINT(LargeBinary)
+ TYPE_ID_ONLY_FINGERPRINT(String)
+ TYPE_ID_ONLY_FINGERPRINT(LargeString)
+ // Dates
+ TYPE_ID_ONLY_FINGERPRINT(Date32)
+ TYPE_ID_ONLY_FINGERPRINT(Date64)
+
+ #undef TYPE_ID_ONLY_FINGERPRINT
+
+ // Fingerprint of a dictionary type: the type id fingerprint followed by the
+ // fingerprints of the index type and of the value type. Returns an empty
+ // string when the value type cannot be fingerprinted.
+ //
+ // NOTE(review): only the index and value types are encoded here. If
+ // DictionaryType has further parameters that take part in equality, they
+ // are not reflected in this fingerprint -- confirm against the class.
+ std::string DictionaryType::ComputeFingerprint() const {
+   const auto& index_fingerprint = index_type_->fingerprint();
+   const auto& dict_value_fingerprint = value_type_->fingerprint();
+   DCHECK(!index_fingerprint.empty());  // it's an integer type
+   if (dict_value_fingerprint.empty()) {
+     return "";
+   }
+   return TypeIdFingerprint(*this) + index_fingerprint + dict_value_fingerprint;
+ }
+
+ // Fingerprint of a list type: the type id fingerprint followed by the
+ // fingerprint of the single child field wrapped in braces. Empty when the
+ // child cannot be fingerprinted.
+ std::string ListType::ComputeFingerprint() const {
+   const auto& value_field_fp = children_[0]->fingerprint();
+   if (value_field_fp.empty()) {
+     return "";
+   }
+   std::string result = TypeIdFingerprint(*this);
+   result += "{";
+   result += value_field_fp;
+   result += "}";
+   return result;
+ }
+
+ // Fingerprint of a large list type: the type id fingerprint followed by the
+ // fingerprint of the single child field wrapped in braces. Empty when the
+ // child cannot be fingerprinted.
+ std::string LargeListType::ComputeFingerprint() const {
+   const auto& value_field_fp = children_[0]->fingerprint();
+   if (value_field_fp.empty()) {
+     return "";
+   }
+   std::string result = TypeIdFingerprint(*this);
+   result += "{";
+   result += value_field_fp;
+   result += "}";
+   return result;
+ }
+
+ // Fingerprint of a map type: the type id fingerprint, an 's' marker when
+ // keys are sorted, then the fingerprint of the first child field wrapped in
+ // braces. Empty when that child cannot be fingerprinted.
+ std::string MapType::ComputeFingerprint() const {
+   const auto& first_child_fp = children_[0]->fingerprint();
+   if (first_child_fp.empty()) {
+     return "";
+   }
+   std::string result = TypeIdFingerprint(*this);
+   if (keys_sorted_) {
+     result += "s";
+   }
+   result += "{";
+   result += first_child_fp;
+   result += "}";
+   return result;
+ }
+
+ // Fingerprint of a fixed-size binary type: the type id fingerprint followed
+ // by the byte width in square brackets.
+ std::string FixedSizeBinaryType::ComputeFingerprint() const {
+   std::stringstream out;
+   out << TypeIdFingerprint(*this);
+   out << "[" << byte_width_ << "]";
+   return out.str();
+ }
+
+ // Fingerprint of a decimal type: the type id fingerprint followed by
+ // "[byte width,precision,scale]".
+ std::string DecimalType::ComputeFingerprint() const {
+   std::stringstream out;
+   out << TypeIdFingerprint(*this);
+   out << "[" << byte_width_;
+   out << "," << precision_;
+   out << "," << scale_ << "]";
+   return out.str();
+ }
+
+ // Fingerprint of a struct type: the type id fingerprint followed by the
+ // fingerprints of all child fields, each terminated by ';', wrapped in
+ // braces. Empty as soon as one child cannot be fingerprinted.
+ std::string StructType::ComputeFingerprint() const {
+   std::stringstream out;
+   out << TypeIdFingerprint(*this) << "{";
+   for (const auto& member : children_) {
+     const auto& member_fp = member->fingerprint();
+     if (member_fp.empty()) {
+       return "";
+     }
+     out << member_fp << ";";
+   }
+   out << "}";
+   return out.str();
+ }
+
+ // Fingerprint of a union type. Layout:
+ //   <type id>[<s|d>:<code>:<code>...]{<child>;<child>;...}
+ // where 's'/'d' stands for sparse/dense mode and each type code is written
+ // as a decimal number. Empty as soon as one child cannot be fingerprinted.
+ std::string UnionType::ComputeFingerprint() const {
+   std::stringstream out;
+   out << TypeIdFingerprint(*this);
+   if (mode_ == UnionMode::SPARSE) {
+     out << "[s";
+   } else if (mode_ == UnionMode::DENSE) {
+     out << "[d";
+   } else {
+     DCHECK(false) << "Unexpected UnionMode";
+   }
+   for (const auto type_code : type_codes_) {
+     // Represent code as integer, not raw character
+     out << ':' << static_cast<uint32_t>(type_code);
+   }
+   out << "]{";
+   for (const auto& member : children_) {
+     const auto& member_fp = member->fingerprint();
+     if (member_fp.empty()) {
+       return "";
+     }
+     out << member_fp << ";";
+   }
+   out << "}";
+   return out.str();
+ }
+
+ // Fingerprint of a time type: the type id fingerprint followed by the time
+ // unit fingerprint.
+ std::string TimeType::ComputeFingerprint() const {
+   std::stringstream out;
+   out << TypeIdFingerprint(*this);
+   out << TimeUnitFingerprint(unit_);
+   return out.str();
+ }
+
+ // Fingerprint of a timestamp type: the type id fingerprint, the time unit
+ // fingerprint, then the timezone string.
+ std::string TimestampType::ComputeFingerprint() const {
+   std::stringstream out;
+   out << TypeIdFingerprint(*this);
+   out << TimeUnitFingerprint(unit_);
+   // The timezone is written as "<length>:<text>".
+   out << timezone_.length() << ':' << timezone_;
+   return out.str();
+ }
+
+ // Fingerprint of an interval type: the type id fingerprint followed by the
+ // fingerprint of the concrete interval kind reported by interval_type().
+ std::string IntervalType::ComputeFingerprint() const {
+   std::stringstream out;
+   out << TypeIdFingerprint(*this);
+   out << IntervalTypeFingerprint(interval_type());
+   return out.str();
+ }
+
+ // Fingerprint of a duration type: the type id fingerprint followed by the
+ // time unit fingerprint.
+ std::string DurationType::ComputeFingerprint() const {
+   std::stringstream out;
+   out << TypeIdFingerprint(*this);
+   out << TimeUnitFingerprint(unit_);
+   return out.str();
+ }
+
+// ----------------------------------------------------------------------
// Visitors and factory functions
Status DataType::Accept(TypeVisitor* visitor) const {
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index 2e4e1b3..e153369 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -18,6 +18,7 @@
#ifndef ARROW_TYPE_H
#define ARROW_TYPE_H
+#include <atomic>
#include <climits>
#include <cstdint>
#include <iosfwd>
@@ -156,6 +157,41 @@ struct Type {
};
};
+namespace detail {
+
+// Base class for objects that expose lazily computed, cached fingerprint
+// strings.
+//
+// fingerprint() and metadata_fingerprint() return the cached string when
+// one is available and otherwise defer to the out-of-line Load*Slow()
+// helpers. Subclasses provide the actual computation by overriding
+// ComputeFingerprint() and ComputeMetadataFingerprint().
+class ARROW_EXPORT Fingerprintable {
+ public:
+  virtual ~Fingerprintable();
+
+  // Return the fingerprint, using the cached value when present.
+  const std::string& fingerprint() const {
+    auto p = fingerprint_.load();
+    if (ARROW_PREDICT_TRUE(p != NULLPTR)) {
+      return *p;
+    }
+    return LoadFingerprintSlow();
+  }
+
+  // Return the metadata fingerprint, using the cached value when present.
+  const std::string& metadata_fingerprint() const {
+    auto p = metadata_fingerprint_.load();
+    if (ARROW_PREDICT_TRUE(p != NULLPTR)) {
+      return *p;
+    }
+    return LoadMetadataFingerprintSlow();
+  }
+
+ protected:
+  // Slow paths taken when no cached value is available. Defined out of
+  // line; presumably they compute and cache the value -- confirm in the
+  // implementation file.
+  const std::string& LoadFingerprintSlow() const;
+  const std::string& LoadMetadataFingerprintSlow() const;
+
+  // Computation hooks implemented by subclasses.
+  virtual std::string ComputeFingerprint() const = 0;
+  virtual std::string ComputeMetadataFingerprint() const = 0;
+
+  // Cached results; null while nothing has been stored. They are
+  // initialized explicitly because a default-initialized std::atomic holds
+  // an indeterminate value before C++20. Without the initializers the null
+  // checks above would only be reliable for derived constructors that
+  // value-initialize this base class.
+  mutable std::atomic<std::string*> fingerprint_{NULLPTR};
+  mutable std::atomic<std::string*> metadata_fingerprint_{NULLPTR};
+};
+
+} // namespace detail
+
struct ARROW_EXPORT DataTypeLayout {
// The bit width for each buffer in this DataType's representation
// (kVariableSizeBuffer if the item size for a given buffer is unknown or variable,
@@ -177,10 +213,10 @@ struct ARROW_EXPORT DataTypeLayout {
///
/// Simple datatypes may be entirely described by their Type::type id, but
/// complex datatypes are usually parametric.
-class ARROW_EXPORT DataType {
+class ARROW_EXPORT DataType : public detail::Fingerprintable {
public:
- explicit DataType(Type::type id) : id_(id) {}
- virtual ~DataType();
+ explicit DataType(Type::type id) : detail::Fingerprintable(), id_(id) {}
+ ~DataType() override;
/// \brief Return whether the types are equal
///
@@ -217,6 +253,13 @@ class ARROW_EXPORT DataType {
Type::type id() const { return id_; }
protected:
+ // Dummy version that returns an empty string (indicating not implemented).
+ // Subclasses should override for fast equality checks.
+ std::string ComputeFingerprint() const override;
+
+ // Generic version that works for all regular types, nested or not.
+ std::string ComputeMetadataFingerprint() const override;
+
Type::type id_;
std::vector<std::shared_ptr<Field>> children_;
@@ -279,12 +322,18 @@ class NoExtraMeta {};
///
/// A field's metadata is represented by a KeyValueMetadata instance,
/// which holds arbitrary key-value pairs.
-class ARROW_EXPORT Field {
+class ARROW_EXPORT Field : public detail::Fingerprintable {
public:
Field(const std::string& name, const std::shared_ptr<DataType>& type,
bool nullable = true,
const std::shared_ptr<const KeyValueMetadata>& metadata = NULLPTR)
- : name_(name), type_(type), nullable_(nullable), metadata_(metadata) {}
+ : detail::Fingerprintable(),
+ name_(name),
+ type_(type),
+ nullable_(nullable),
+ metadata_(metadata) {}
+
+ ~Field() override;
/// \brief Return the field's attached metadata
std::shared_ptr<const KeyValueMetadata> metadata() const { return metadata_; }
@@ -322,6 +371,9 @@ class ARROW_EXPORT Field {
std::shared_ptr<Field> Copy() const;
private:
+ std::string ComputeFingerprint() const override;
+ std::string ComputeMetadataFingerprint() const override;
+
// Field name
std::string name_;
@@ -379,6 +431,9 @@ class ARROW_EXPORT NullType : public DataType, public NoExtraMeta {
}
std::string name() const override { return "null"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// Concrete type class for boolean data
@@ -397,6 +452,9 @@ class ARROW_EXPORT BooleanType : public FixedWidthType, public NoExtraMeta {
int bit_width() const override { return 1; }
std::string name() const override { return "bool"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// Concrete type class for unsigned 8-bit integer data
@@ -404,6 +462,9 @@ class ARROW_EXPORT UInt8Type
: public detail::IntegerTypeImpl<UInt8Type, Type::UINT8, uint8_t> {
public:
static constexpr const char* type_name() { return "uint8"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// Concrete type class for signed 8-bit integer data
@@ -411,6 +472,9 @@ class ARROW_EXPORT Int8Type
: public detail::IntegerTypeImpl<Int8Type, Type::INT8, int8_t> {
public:
static constexpr const char* type_name() { return "int8"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// Concrete type class for unsigned 16-bit integer data
@@ -418,6 +482,9 @@ class ARROW_EXPORT UInt16Type
: public detail::IntegerTypeImpl<UInt16Type, Type::UINT16, uint16_t> {
public:
static constexpr const char* type_name() { return "uint16"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// Concrete type class for signed 16-bit integer data
@@ -425,6 +492,9 @@ class ARROW_EXPORT Int16Type
: public detail::IntegerTypeImpl<Int16Type, Type::INT16, int16_t> {
public:
static constexpr const char* type_name() { return "int16"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// Concrete type class for unsigned 32-bit integer data
@@ -432,6 +502,9 @@ class ARROW_EXPORT UInt32Type
: public detail::IntegerTypeImpl<UInt32Type, Type::UINT32, uint32_t> {
public:
static constexpr const char* type_name() { return "uint32"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// Concrete type class for signed 32-bit integer data
@@ -439,6 +512,9 @@ class ARROW_EXPORT Int32Type
: public detail::IntegerTypeImpl<Int32Type, Type::INT32, int32_t> {
public:
static constexpr const char* type_name() { return "int32"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// Concrete type class for unsigned 64-bit integer data
@@ -446,6 +522,9 @@ class ARROW_EXPORT UInt64Type
: public detail::IntegerTypeImpl<UInt64Type, Type::UINT64, uint64_t> {
public:
static constexpr const char* type_name() { return "uint64"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// Concrete type class for signed 64-bit integer data
@@ -453,6 +532,9 @@ class ARROW_EXPORT Int64Type
: public detail::IntegerTypeImpl<Int64Type, Type::INT64, int64_t> {
public:
static constexpr const char* type_name() { return "int64"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// Concrete type class for 16-bit floating-point data
@@ -462,6 +544,9 @@ class ARROW_EXPORT HalfFloatType
public:
Precision precision() const override;
static constexpr const char* type_name() { return "halffloat"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// Concrete type class for 32-bit floating-point data (C "float")
@@ -470,6 +555,9 @@ class ARROW_EXPORT FloatType
public:
Precision precision() const override;
static constexpr const char* type_name() { return "float"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// Concrete type class for 64-bit floating-point data (C "double")
@@ -478,6 +566,9 @@ class ARROW_EXPORT DoubleType
public:
Precision precision() const override;
static constexpr const char* type_name() { return "double"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// \brief Base class for all variable-size list data types
@@ -517,6 +608,9 @@ class ARROW_EXPORT ListType : public BaseListType {
std::string ToString() const override;
std::string name() const override { return "list"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// \brief Concrete type class for large list data
@@ -549,6 +643,9 @@ class ARROW_EXPORT LargeListType : public BaseListType {
std::string ToString() const override;
std::string name() const override { return "large_list"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// \brief Concrete type class for map data
@@ -576,6 +673,8 @@ class ARROW_EXPORT MapType : public ListType {
bool keys_sorted() const { return keys_sorted_; }
private:
+ std::string ComputeFingerprint() const override;
+
bool keys_sorted_;
};
@@ -638,6 +737,8 @@ class ARROW_EXPORT BinaryType : public BaseBinaryType {
std::string name() const override { return "binary"; }
protected:
+ std::string ComputeFingerprint() const override;
+
// Allow subclasses like StringType to change the logical type.
explicit BinaryType(Type::type logical_type) : BaseBinaryType(logical_type) {}
};
@@ -662,6 +763,8 @@ class ARROW_EXPORT LargeBinaryType : public BaseBinaryType {
std::string name() const override { return "large_binary"; }
protected:
+ std::string ComputeFingerprint() const override;
+
// Allow subclasses like LargeStringType to change the logical type.
explicit LargeBinaryType(Type::type logical_type) : BaseBinaryType(logical_type) {}
};
@@ -679,6 +782,9 @@ class ARROW_EXPORT StringType : public BinaryType {
std::string ToString() const override;
std::string name() const override { return "utf8"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// \brief Concrete type class for large variable-size string data, utf8-encoded
@@ -694,6 +800,9 @@ class ARROW_EXPORT LargeStringType : public LargeBinaryType {
std::string ToString() const override;
std::string name() const override { return "large_utf8"; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// \brief Concrete type class for fixed-size binary data
@@ -717,6 +826,8 @@ class ARROW_EXPORT FixedSizeBinaryType : public FixedWidthType, public Parametri
int bit_width() const override;
protected:
+ std::string ComputeFingerprint() const override;
+
int32_t byte_width_;
};
@@ -756,6 +867,8 @@ class ARROW_EXPORT StructType : public NestedType {
int GetChildIndex(const std::string& name) const;
private:
+ std::string ComputeFingerprint() const override;
+
class Impl;
std::unique_ptr<Impl> impl_;
};
@@ -772,6 +885,8 @@ class ARROW_EXPORT DecimalType : public FixedSizeBinaryType {
int32_t scale() const { return scale_; }
protected:
+ std::string ComputeFingerprint() const override;
+
int32_t precision_;
int32_t scale_;
};
@@ -816,6 +931,8 @@ class ARROW_EXPORT UnionType : public NestedType {
UnionMode::type mode() const { return mode_; }
private:
+ std::string ComputeFingerprint() const override;
+
UnionMode::type mode_;
// The type id used in the data to indicate each data type in the union. For
@@ -864,6 +981,9 @@ class ARROW_EXPORT Date32Type : public DateType {
std::string name() const override { return "date32"; }
DateUnit unit() const override { return UNIT; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// Concrete type class for 64-bit date data (as number of milliseconds since UNIX epoch)
@@ -884,6 +1004,9 @@ class ARROW_EXPORT Date64Type : public DateType {
std::string name() const override { return "date64"; }
DateUnit unit() const override { return UNIT; }
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
struct TimeUnit {
@@ -901,6 +1024,8 @@ class ARROW_EXPORT TimeType : public TemporalType, public ParametricType {
protected:
TimeType(Type::type type_id, TimeUnit::type unit);
+ std::string ComputeFingerprint() const override;
+
TimeUnit::type unit_;
};
@@ -995,6 +1120,9 @@ class ARROW_EXPORT TimestampType : public TemporalType, public ParametricType {
TimeUnit::type unit() const { return unit_; }
const std::string& timezone() const { return timezone_; }
+ protected:
+ std::string ComputeFingerprint() const override;
+
private:
TimeUnit::type unit_;
std::string timezone_;
@@ -1007,7 +1135,9 @@ class ARROW_EXPORT IntervalType : public TemporalType, public ParametricType {
IntervalType() : TemporalType(Type::INTERVAL) {}
virtual type interval_type() const = 0;
- virtual ~IntervalType() = default;
+
+ protected:
+ std::string ComputeFingerprint() const override;
};
/// \brief Represents a some number of months.
@@ -1080,6 +1210,9 @@ class ARROW_EXPORT DurationType : public TemporalType, public ParametricType {
TimeUnit::type unit() const { return unit_; }
+ protected:
+ std::string ComputeFingerprint() const override;
+
private:
TimeUnit::type unit_;
};
@@ -1134,6 +1267,8 @@ class ARROW_EXPORT DictionaryType : public FixedWidthType {
std::vector<std::vector<int32_t>>* out_transpose_maps = NULLPTR);
protected:
+ std::string ComputeFingerprint() const override;
+
// Must be an integer type (not currently checked)
std::shared_ptr<DataType> index_type_;
std::shared_ptr<DataType> value_type_;
@@ -1146,7 +1281,7 @@ class ARROW_EXPORT DictionaryType : public FixedWidthType {
/// \class Schema
/// \brief Sequence of arrow::Field objects describing the columns of a record
/// batch or table data structure
-class ARROW_EXPORT Schema {
+class ARROW_EXPORT Schema : public detail::Fingerprintable {
public:
explicit Schema(const std::vector<std::shared_ptr<Field>>& fields,
const std::shared_ptr<const KeyValueMetadata>& metadata = NULLPTR);
@@ -1156,7 +1291,7 @@ class ARROW_EXPORT Schema {
Schema(const Schema&);
- virtual ~Schema();
+ ~Schema() override;
/// Returns true if all of the schema fields are equal
bool Equals(const Schema& other, bool check_metadata = true) const;
@@ -1210,6 +1345,10 @@ class ARROW_EXPORT Schema {
/// \brief Indicates that Schema has non-empty KevValueMetadata
bool HasMetadata() const;
+ protected:
+ std::string ComputeFingerprint() const override;
+ std::string ComputeMetadataFingerprint() const override;
+
private:
class Impl;
std::unique_ptr<Impl> impl_;
diff --git a/cpp/src/arrow/util/key-value-metadata-test.cc b/cpp/src/arrow/util/key-value-metadata-test.cc
index 71cd3ad..2a147f6 100644
--- a/cpp/src/arrow/util/key-value-metadata-test.cc
+++ b/cpp/src/arrow/util/key-value-metadata-test.cc
@@ -134,4 +134,22 @@ bar: buzz)";
ASSERT_EQ(expected, result);
}
+// sorted_pairs() must return the entries ordered by key, with each key
+// still paired with its own value, regardless of insertion order.
+TEST(KeyValueMetadataTest, SortedPairs) {
+  const std::vector<std::string> keys_forward = {"foo", "bar"};
+  const std::vector<std::string> keys_backward = {"bar", "foo"};
+  const std::vector<std::string> values_forward = {"bizz", "buzz"};
+  const std::vector<std::string> values_backward = {"buzz", "bizz"};
+
+  // foo -> bizz, bar -> buzz
+  KeyValueMetadata metadata1(keys_forward, values_forward);
+  // bar -> bizz, foo -> buzz
+  KeyValueMetadata metadata2(keys_backward, values_forward);
+  // bar -> buzz, foo -> bizz (same mapping as metadata1, other order)
+  KeyValueMetadata metadata3(keys_backward, values_backward);
+
+  std::vector<std::pair<std::string, std::string>> expected;
+  expected = {{"bar", "buzz"}, {"foo", "bizz"}};
+  ASSERT_EQ(metadata1.sorted_pairs(), expected);
+  ASSERT_EQ(metadata3.sorted_pairs(), expected);
+  expected = {{"bar", "bizz"}, {"foo", "buzz"}};
+  ASSERT_EQ(metadata2.sorted_pairs(), expected);
+}
+
} // namespace arrow
diff --git a/cpp/src/arrow/util/key_value_metadata.cc b/cpp/src/arrow/util/key_value_metadata.cc
index 2335dd2..ca597f1 100644
--- a/cpp/src/arrow/util/key_value_metadata.cc
+++ b/cpp/src/arrow/util/key_value_metadata.cc
@@ -106,6 +106,17 @@ const std::string& KeyValueMetadata::value(int64_t i) const {
return values_[i];
}
+// Return all (key, value) entries as a vector of pairs, visiting them in
+// the order given by internal::ArgSort over the keys.
+// NOTE(review): the relative order of entries with equal keys depends on
+// ArgSort, which is not shown here -- confirm if duplicates matter.
+std::vector<std::pair<std::string, std::string>> KeyValueMetadata::sorted_pairs() const {
+  const auto order = internal::ArgSort(keys_);
+
+  std::vector<std::pair<std::string, std::string>> result;
+  result.reserve(size());
+  for (const auto idx : order) {
+    result.emplace_back(keys_[idx], values_[idx]);
+  }
+  return result;
+}
+
int KeyValueMetadata::FindKey(const std::string& key) const {
for (size_t i = 0; i < keys_.size(); ++i) {
if (keys_[i] == key) {
diff --git a/cpp/src/arrow/util/key_value_metadata.h b/cpp/src/arrow/util/key_value_metadata.h
index 2820c98..d84e060 100644
--- a/cpp/src/arrow/util/key_value_metadata.h
+++ b/cpp/src/arrow/util/key_value_metadata.h
@@ -22,6 +22,7 @@
#include <memory>
#include <string>
#include <unordered_map>
+#include <utility>
#include <vector>
#include "arrow/util/macros.h"
@@ -47,6 +48,7 @@ class ARROW_EXPORT KeyValueMetadata {
const std::string& key(int64_t i) const;
const std::string& value(int64_t i) const;
+ std::vector<std::pair<std::string, std::string>> sorted_pairs() const;
/// \brief Perform linear search for key, returning -1 if not found
int FindKey(const std::string& key) const;
diff --git a/integration/integration_test.py b/integration/integration_test.py
index dbc03c7..b791af5 100644
--- a/integration/integration_test.py
+++ b/integration/integration_test.py
@@ -1177,11 +1177,12 @@ def get_generated_json_files(tempdir=None, flight=False):
class IntegrationRunner(object):
- def __init__(self, json_files, testers, tempdir=None, debug=False):
+ def __init__(self, json_files, testers, args):
self.json_files = json_files
self.testers = testers
- self.temp_dir = tempdir or tempfile.mkdtemp()
- self.debug = debug
+ self.temp_dir = args.tempdir or tempfile.mkdtemp()
+ self.debug = args.debug
+ self.stop_on_error = args.stop_on_error
def run(self):
failures = []
@@ -1291,7 +1292,10 @@ class IntegrationRunner(object):
except Exception:
traceback.print_exc()
yield (test_case, producer, consumer, sys.exc_info())
- continue
+ if self.stop_on_error:
+ break
+ else:
+ continue
def _compare_flight_implementations(self, producer, consumer):
print('##########################################################')
@@ -1339,6 +1343,12 @@ class Tester(object):
self.args = args
self.debug = args.debug
+ def run_shell_command(self, cmd):
+     """Join ``cmd`` into one command line and run it through the shell.
+
+     ``cmd`` is a sequence of strings joined with single spaces. The
+     command line is printed first when ``self.debug`` is set.
+     ``subprocess.check_call`` raises ``CalledProcessError`` if the
+     command exits with a non-zero status.
+     """
+     command_line = ' '.join(cmd)
+     if self.debug:
+         print(command_line)
+     subprocess.check_call(command_line, shell=True)
+
def json_to_file(self, json_path, arrow_path):
raise NotImplementedError
@@ -1502,17 +1512,11 @@ class CPPTester(Tester):
def stream_to_file(self, stream_path, file_path):
cmd = ['cat', stream_path, '|', self.STREAM_TO_FILE, '>', file_path]
- cmd = ' '.join(cmd)
- if self.debug:
- print(cmd)
- os.system(cmd)
+ self.run_shell_command(cmd)
def file_to_stream(self, file_path, stream_path):
cmd = [self.FILE_TO_STREAM, file_path, '>', stream_path]
- cmd = ' '.join(cmd)
- if self.debug:
- print(cmd)
- os.system(cmd)
+ self.run_shell_command(cmd)
@contextlib.contextmanager
def flight_server(self):
@@ -1578,28 +1582,19 @@ class JSTester(Tester):
'--no-warnings', self.JSON_TO_ARROW,
'-a', arrow_path,
'-j', json_path]
- cmd = ' '.join(cmd)
- if self.debug:
- print(cmd)
- os.system(cmd)
+ self.run_shell_command(cmd)
def stream_to_file(self, stream_path, file_path):
cmd = ['cat', stream_path, '|',
'node', '--no-warnings', self.STREAM_TO_FILE, '>',
file_path]
- cmd = ' '.join(cmd)
- if self.debug:
- print(cmd)
- os.system(cmd)
+ self.run_shell_command(cmd)
def file_to_stream(self, file_path, stream_path):
cmd = ['cat', file_path, '|',
'node', '--no-warnings', self.FILE_TO_STREAM, '>',
stream_path]
- cmd = ' '.join(cmd)
- if self.debug:
- print(cmd)
- os.system(cmd)
+ self.run_shell_command(cmd)
class GoTester(Tester):
@@ -1640,17 +1635,11 @@ class GoTester(Tester):
def stream_to_file(self, stream_path, file_path):
cmd = ['cat', stream_path, '|', self.STREAM_TO_FILE, '>', file_path]
- cmd = ' '.join(cmd)
- if self.debug:
- print(cmd)
- os.system(cmd)
+ self.run_shell_command(cmd)
def file_to_stream(self, file_path, stream_path):
cmd = [self.FILE_TO_STREAM, file_path, '>', stream_path]
- cmd = ' '.join(cmd)
- if self.debug:
- print(cmd)
- os.system(cmd)
+ self.run_shell_command(cmd)
def get_static_json_files():
@@ -1680,8 +1669,7 @@ def run_all_tests(args):
flight=args.run_flight)
json_files = static_json_files + generated_json_files
- runner = IntegrationRunner(json_files, testers,
- tempdir=args.tempdir, debug=args.debug)
+ runner = IntegrationRunner(json_files, testers, args)
failures = []
failures.extend(runner.run())
if args.run_flight:
@@ -1748,6 +1736,11 @@ if __name__ == '__main__':
default=tempfile.mkdtemp(),
help=('Directory to use for writing '
'integration test temporary files'))
+
+ parser.add_argument('-x', '--stop-on-error', dest='stop_on_error',
+ action='store_true', default=False,
+ help='Stop on first error')
+
args = parser.parse_args()
if args.generated_json_path:
try: