You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2017/02/26 23:25:09 UTC
arrow git commit: ARROW-451: [C++] Implement DataType::Equals as
TypeVisitor. Add default implementations for TypeVisitor, ArrayVisitor methods
Repository: arrow
Updated Branches:
refs/heads/master 8afe92c6c -> ef3b6b344
ARROW-451: [C++] Implement DataType::Equals as TypeVisitor. Add default implementations for TypeVisitor, ArrayVisitor methods
This patch also resolves ARROW-568. Added tests for TimeType and TimestampType, whose `unit` metadata was not being compared due to an oversight.
Author: Wes McKinney <we...@twosigma.com>
Closes #350 from wesm/ARROW-451 and squashes the following commits:
97e75d8 [Wes McKinney] Export ArrayVisitor, TypeVisitor symbols
a3332be [Wes McKinney] Typo
635e74d [Wes McKinney] Implement DataType::Equals as TypeVisitor, compare child metadata. Add default implementations for TypeVisitor, ArrayVisitor methods
Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/ef3b6b34
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/ef3b6b34
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/ef3b6b34
Branch: refs/heads/master
Commit: ef3b6b34482c36615af5064f474363126e755a18
Parents: 8afe92c
Author: Wes McKinney <we...@twosigma.com>
Authored: Sun Feb 26 18:25:03 2017 -0500
Committer: Wes McKinney <we...@twosigma.com>
Committed: Sun Feb 26 18:25:03 2017 -0500
----------------------------------------------------------------------
cpp/src/arrow/CMakeLists.txt | 2 +-
cpp/src/arrow/array.cc | 36 ++++++++
cpp/src/arrow/array.h | 50 +++++------
cpp/src/arrow/compare.cc | 108 +++++++++++++++++++++--
cpp/src/arrow/compare.h | 5 ++
cpp/src/arrow/ipc/adapter.cc | 20 -----
cpp/src/arrow/ipc/json-internal.cc | 30 -------
cpp/src/arrow/schema-test.cc | 122 --------------------------
cpp/src/arrow/type-test.cc | 146 ++++++++++++++++++++++++++++++++
cpp/src/arrow/type.cc | 69 +++++++++++----
cpp/src/arrow/type.h | 64 ++++++--------
11 files changed, 394 insertions(+), 258 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 824ced1..d1efa02 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -58,8 +58,8 @@ ADD_ARROW_TEST(buffer-test)
ADD_ARROW_TEST(column-test)
ADD_ARROW_TEST(memory_pool-test)
ADD_ARROW_TEST(pretty_print-test)
-ADD_ARROW_TEST(schema-test)
ADD_ARROW_TEST(status-test)
+ADD_ARROW_TEST(type-test)
ADD_ARROW_TEST(table-test)
ADD_ARROW_BENCHMARK(builder-benchmark)
http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/array.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc
index 81678e3..eb4c210 100644
--- a/cpp/src/arrow/array.cc
+++ b/cpp/src/arrow/array.cc
@@ -503,4 +503,40 @@ Status MakePrimitiveArray(const std::shared_ptr<DataType>& type, int32_t length,
#endif
}
+// ----------------------------------------------------------------------
+// Default implementations of ArrayVisitor methods
+//
+// Every default returns Status::NotImplemented, using the array's type string
+// (array.type()->ToString()) as the message. ArrayVisitor's Visit overloads are
+// plain virtual (not pure virtual) in array.h, so a subclass only needs to
+// override the overloads for the array classes it actually supports.
+
+#define ARRAY_VISITOR_DEFAULT(ARRAY_CLASS) \
+ Status ArrayVisitor::Visit(const ARRAY_CLASS& array) { \
+ return Status::NotImplemented(array.type()->ToString()); \
+ }
+
+ARRAY_VISITOR_DEFAULT(NullArray);
+ARRAY_VISITOR_DEFAULT(BooleanArray);
+ARRAY_VISITOR_DEFAULT(Int8Array);
+ARRAY_VISITOR_DEFAULT(Int16Array);
+ARRAY_VISITOR_DEFAULT(Int32Array);
+ARRAY_VISITOR_DEFAULT(Int64Array);
+ARRAY_VISITOR_DEFAULT(UInt8Array);
+ARRAY_VISITOR_DEFAULT(UInt16Array);
+ARRAY_VISITOR_DEFAULT(UInt32Array);
+ARRAY_VISITOR_DEFAULT(UInt64Array);
+ARRAY_VISITOR_DEFAULT(HalfFloatArray);
+ARRAY_VISITOR_DEFAULT(FloatArray);
+ARRAY_VISITOR_DEFAULT(DoubleArray);
+ARRAY_VISITOR_DEFAULT(StringArray);
+ARRAY_VISITOR_DEFAULT(BinaryArray);
+ARRAY_VISITOR_DEFAULT(DateArray);
+ARRAY_VISITOR_DEFAULT(TimeArray);
+ARRAY_VISITOR_DEFAULT(TimestampArray);
+ARRAY_VISITOR_DEFAULT(IntervalArray);
+ARRAY_VISITOR_DEFAULT(ListArray);
+ARRAY_VISITOR_DEFAULT(StructArray);
+ARRAY_VISITOR_DEFAULT(UnionArray);
+ARRAY_VISITOR_DEFAULT(DictionaryArray);
+
+// DecimalArray is written out by hand instead of using the macro above, and it
+// reports a fixed "decimal" message rather than array.type()->ToString().
+// NOTE(review): presumably because DecimalArray is not a complete type at this
+// point (so array.type() cannot be called) — confirm before folding it into
+// the macro.
+Status ArrayVisitor::Visit(const DecimalArray& array) {
+ return Status::NotImplemented("decimal");
+}
+
} // namespace arrow
http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/array.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h
index 9bb06af..8bb914e 100644
--- a/cpp/src/arrow/array.h
+++ b/cpp/src/arrow/array.h
@@ -38,34 +38,34 @@ class MemoryPool;
class MutableBuffer;
class Status;
-class ArrayVisitor {
+class ARROW_EXPORT ArrayVisitor {
public:
virtual ~ArrayVisitor() = default;
- virtual Status Visit(const NullArray& array) = 0;
- virtual Status Visit(const BooleanArray& array) = 0;
- virtual Status Visit(const Int8Array& array) = 0;
- virtual Status Visit(const Int16Array& array) = 0;
- virtual Status Visit(const Int32Array& array) = 0;
- virtual Status Visit(const Int64Array& array) = 0;
- virtual Status Visit(const UInt8Array& array) = 0;
- virtual Status Visit(const UInt16Array& array) = 0;
- virtual Status Visit(const UInt32Array& array) = 0;
- virtual Status Visit(const UInt64Array& array) = 0;
- virtual Status Visit(const HalfFloatArray& array) = 0;
- virtual Status Visit(const FloatArray& array) = 0;
- virtual Status Visit(const DoubleArray& array) = 0;
- virtual Status Visit(const StringArray& array) = 0;
- virtual Status Visit(const BinaryArray& array) = 0;
- virtual Status Visit(const DateArray& array) = 0;
- virtual Status Visit(const TimeArray& array) = 0;
- virtual Status Visit(const TimestampArray& array) = 0;
- virtual Status Visit(const IntervalArray& array) = 0;
- virtual Status Visit(const DecimalArray& array) = 0;
- virtual Status Visit(const ListArray& array) = 0;
- virtual Status Visit(const StructArray& array) = 0;
- virtual Status Visit(const UnionArray& array) = 0;
- virtual Status Visit(const DictionaryArray& type) = 0;
+ virtual Status Visit(const NullArray& array);
+ virtual Status Visit(const BooleanArray& array);
+ virtual Status Visit(const Int8Array& array);
+ virtual Status Visit(const Int16Array& array);
+ virtual Status Visit(const Int32Array& array);
+ virtual Status Visit(const Int64Array& array);
+ virtual Status Visit(const UInt8Array& array);
+ virtual Status Visit(const UInt16Array& array);
+ virtual Status Visit(const UInt32Array& array);
+ virtual Status Visit(const UInt64Array& array);
+ virtual Status Visit(const HalfFloatArray& array);
+ virtual Status Visit(const FloatArray& array);
+ virtual Status Visit(const DoubleArray& array);
+ virtual Status Visit(const StringArray& array);
+ virtual Status Visit(const BinaryArray& array);
+ virtual Status Visit(const DateArray& array);
+ virtual Status Visit(const TimeArray& array);
+ virtual Status Visit(const TimestampArray& array);
+ virtual Status Visit(const IntervalArray& array);
+ virtual Status Visit(const DecimalArray& array);
+ virtual Status Visit(const ListArray& array);
+ virtual Status Visit(const StructArray& array);
+ virtual Status Visit(const UnionArray& array);
+ virtual Status Visit(const DictionaryArray& type);
};
/// Immutable data array with some logical type and some length.
http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/compare.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc
index 21fdb66..ff3c59f 100644
--- a/cpp/src/arrow/compare.cc
+++ b/cpp/src/arrow/compare.cc
@@ -301,9 +301,9 @@ class RangeEqualsVisitor : public ArrayVisitor {
bool result_;
};
-class EqualsVisitor : public RangeEqualsVisitor {
+class ArrayEqualsVisitor : public RangeEqualsVisitor {
public:
- explicit EqualsVisitor(const Array& right)
+ explicit ArrayEqualsVisitor(const Array& right)
: RangeEqualsVisitor(right, 0, right.length(), 0) {}
Status Visit(const NullArray& left) override { return Status::OK(); }
@@ -511,9 +511,9 @@ inline bool FloatingApproxEquals(
return true;
}
-class ApproxEqualsVisitor : public EqualsVisitor {
+class ApproxEqualsVisitor : public ArrayEqualsVisitor {
public:
- using EqualsVisitor::EqualsVisitor;
+ using ArrayEqualsVisitor::ArrayEqualsVisitor;
Status Visit(const FloatArray& left) override {
result_ =
@@ -549,7 +549,7 @@ Status ArrayEquals(const Array& left, const Array& right, bool* are_equal) {
} else if (left.length() == 0) {
*are_equal = true;
} else {
- EqualsVisitor visitor(right);
+ ArrayEqualsVisitor visitor(right);
RETURN_NOT_OK(left.Accept(&visitor));
*are_equal = visitor.result();
}
@@ -588,4 +588,102 @@ Status ArrayApproxEquals(const Array& left, const Array& right, bool* are_equal)
return Status::OK();
}
+// ----------------------------------------------------------------------
+// Implement TypeEquals
+
+// Compares the type-specific metadata of a visited type ("left") against
+// right_. Only types that carry metadata beyond their type id override Visit;
+// every other type falls through to the TypeVisitor default (NotImplemented),
+// which TypeEquals interprets as "nothing further to compare".
+//
+// The static_casts below are safe only because TypeEquals verifies
+// left.type == right.type before visiting.
+class TypeEqualsVisitor : public TypeVisitor {
+ public:
+  explicit TypeEqualsVisitor(const DataType& right) : right_(right), result_(false) {}
+
+  // Compares child fields pairwise (count first, then Field::Equals on each).
+  // result_ is true only if every child matches.
+  Status VisitChildren(const DataType& left) {
+    if (left.num_children() != right_.num_children()) {
+      result_ = false;
+      return Status::OK();
+    }
+
+    for (int i = 0; i < left.num_children(); ++i) {
+      if (!left.child(i)->Equals(right_.child(i))) {
+        // Return right away: falling out of the loop would overwrite result_
+        // with true below.
+        result_ = false;
+        return Status::OK();
+      }
+    }
+    result_ = true;
+    return Status::OK();
+  }
+
+  Status Visit(const TimeType& left) override {
+    const auto& right = static_cast<const TimeType&>(right_);
+    result_ = left.unit == right.unit;
+    return Status::OK();
+  }
+
+  Status Visit(const TimestampType& left) override {
+    const auto& right = static_cast<const TimestampType&>(right_);
+    result_ = left.unit == right.unit;
+    return Status::OK();
+  }
+
+  Status Visit(const ListType& left) override { return VisitChildren(left); }
+
+  Status Visit(const StructType& left) override { return VisitChildren(left); }
+
+  Status Visit(const UnionType& left) override {
+    const auto& right = static_cast<const UnionType&>(right_);
+
+    // Mode and type codes must match exactly; vector comparison checks both
+    // size and elements without copying either vector.
+    if (left.mode != right.mode || left.type_codes != right.type_codes) {
+      result_ = false;
+      return Status::OK();
+    }
+
+    // The union's member types are its child fields; they must match as well.
+    return VisitChildren(left);
+  }
+
+  Status Visit(const DictionaryType& left) override {
+    const auto& right = static_cast<const DictionaryType&>(right_);
+    result_ = left.index_type()->Equals(right.index_type()) &&
+              left.dictionary()->Equals(right.dictionary());
+    return Status::OK();
+  }
+
+  bool result() const { return result_; }
+
+ protected:
+  const DataType& right_;
+  bool result_;
+};
+
+// Sets *are_equal to whether `left` and `right` are the same type, including
+// any type-specific metadata checked by TypeEqualsVisitor. Returns a non-OK
+// Status only if the visitor fails with an error other than NotImplemented.
+Status TypeEquals(const DataType& left, const DataType& right, bool* are_equal) {
+  if (&left == &right) {
+    // The types are the same object
+    *are_equal = true;
+  } else if (left.type != right.type) {
+    *are_equal = false;
+  } else {
+    // Same type id from here on, which is what makes the static_casts inside
+    // TypeEqualsVisitor safe.
+    TypeEqualsVisitor visitor(right);
+    Status s = left.Accept(&visitor);
+
+    if (s.IsNotImplemented()) {
+      // TypeEqualsVisitor only overrides Visit for types with metadata beyond
+      // the type id; the TypeVisitor defaults return NotImplemented for all
+      // other types, so there is nothing further to compare.
+      *are_equal = true;
+    } else if (!s.ok()) {
+      return s;
+    } else {
+      *are_equal = visitor.result();
+    }
+  }
+  return Status::OK();
+}
+
} // namespace arrow
http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/compare.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compare.h b/cpp/src/arrow/compare.h
index 2093b65..6a71f9f 100644
--- a/cpp/src/arrow/compare.h
+++ b/cpp/src/arrow/compare.h
@@ -27,6 +27,7 @@
namespace arrow {
class Array;
+struct DataType;
class Status;
/// Returns true if the arrays are exactly equal
@@ -41,6 +42,10 @@ Status ARROW_EXPORT ArrayApproxEquals(
Status ARROW_EXPORT ArrayRangeEquals(const Array& left, const Array& right,
int32_t start_idx, int32_t end_idx, int32_t other_start_idx, bool* are_equal);
+/// Compares two types for exact equality, including type-specific metadata
+/// (e.g. time units, child fields, dictionary index type and values).
+/// Writes the result to *are_equal; the returned Status is non-OK only if
+/// the comparison itself fails.
+Status ARROW_EXPORT TypeEquals(
+ const DataType& left, const DataType& right, bool* are_equal);
+
} // namespace arrow
#endif // ARROW_COMPARE_H
http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/ipc/adapter.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/ipc/adapter.cc b/cpp/src/arrow/ipc/adapter.cc
index 08ac983..2be87a3 100644
--- a/cpp/src/arrow/ipc/adapter.cc
+++ b/cpp/src/arrow/ipc/adapter.cc
@@ -227,8 +227,6 @@ class RecordBatchWriter : public ArrayVisitor {
}
protected:
- Status Visit(const NullArray& array) override { return Status::NotImplemented("null"); }
-
template <typename ArrayType>
Status VisitFixedWidth(const ArrayType& array) {
std::shared_ptr<Buffer> data_buffer = array.data();
@@ -360,14 +358,6 @@ class RecordBatchWriter : public ArrayVisitor {
return VisitFixedWidth<TimestampArray>(array);
}
- Status Visit(const IntervalArray& array) override {
- return Status::NotImplemented("interval");
- }
-
- Status Visit(const DecimalArray& array) override {
- return Status::NotImplemented("decimal");
- }
-
Status Visit(const ListArray& array) override {
std::shared_ptr<Buffer> value_offsets;
RETURN_NOT_OK(GetZeroBasedValueOffsets<ListArray>(array, &value_offsets));
@@ -653,8 +643,6 @@ class ArrayLoader : public TypeVisitor {
return Status::OK();
}
- Status Visit(const NullType& type) override { return Status::NotImplemented("null"); }
-
Status Visit(const BooleanType& type) override { return LoadPrimitive(type); }
Status Visit(const Int8Type& type) override { return LoadPrimitive(type); }
@@ -689,14 +677,6 @@ class ArrayLoader : public TypeVisitor {
Status Visit(const TimestampType& type) override { return LoadPrimitive(type); }
- Status Visit(const IntervalType& type) override {
- return Status::NotImplemented(type.ToString());
- }
-
- Status Visit(const DecimalType& type) override {
- return Status::NotImplemented(type.ToString());
- }
-
Status Visit(const ListType& type) override {
FieldMetadata field_meta;
std::shared_ptr<Buffer> null_bitmap, offsets;
http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/ipc/json-internal.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/ipc/json-internal.cc b/cpp/src/arrow/ipc/json-internal.cc
index b9f97dd..6253cd6 100644
--- a/cpp/src/arrow/ipc/json-internal.cc
+++ b/cpp/src/arrow/ipc/json-internal.cc
@@ -316,8 +316,6 @@ class JsonSchemaWriter : public TypeVisitor {
return WritePrimitive("interval", type);
}
- Status Visit(const DecimalType& type) override { return Status::NotImplemented("NYI"); }
-
Status Visit(const ListType& type) override {
WriteName("list", type);
RETURN_NOT_OK(WriteChildren(type.children()));
@@ -339,14 +337,6 @@ class JsonSchemaWriter : public TypeVisitor {
return Status::OK();
}
- Status Visit(const DictionaryType& type) override {
- // WriteName("dictionary", type);
- // WriteChildren(type.children());
- // WriteBufferLayout(type.GetBufferLayout());
- // return Status::OK();
- return Status::NotImplemented("dictionary type");
- }
-
private:
const Schema& schema_;
RjWriter* writer_;
@@ -531,22 +521,6 @@ class JsonArrayWriter : public ArrayVisitor {
Status Visit(const BinaryArray& array) override { return WriteVarBytes(array); }
- Status Visit(const DateArray& array) override { return Status::NotImplemented("date"); }
-
- Status Visit(const TimeArray& array) override { return Status::NotImplemented("time"); }
-
- Status Visit(const TimestampArray& array) override {
- return Status::NotImplemented("timestamp");
- }
-
- Status Visit(const IntervalArray& array) override {
- return Status::NotImplemented("interval");
- }
-
- Status Visit(const DecimalArray& array) override {
- return Status::NotImplemented("decimal");
- }
-
Status Visit(const ListArray& array) override {
WriteValidityField(array);
WriteIntegerField("OFFSET", array.raw_value_offsets(), array.length() + 1);
@@ -571,10 +545,6 @@ class JsonArrayWriter : public ArrayVisitor {
return WriteChildren(type->children(), array.children());
}
- Status Visit(const DictionaryArray& array) override {
- return Status::NotImplemented("dictionary");
- }
-
private:
const std::string& name_;
const Array& array_;
http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/schema-test.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/schema-test.cc b/cpp/src/arrow/schema-test.cc
deleted file mode 100644
index 4826199..0000000
--- a/cpp/src/arrow/schema-test.cc
+++ /dev/null
@@ -1,122 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "gtest/gtest.h"
-
-#include "arrow/schema.h"
-#include "arrow/type.h"
-
-using std::shared_ptr;
-using std::vector;
-
-namespace arrow {
-
-TEST(TestField, Basics) {
- Field f0("f0", int32());
- Field f0_nn("f0", int32(), false);
-
- ASSERT_EQ(f0.name, "f0");
- ASSERT_EQ(f0.type->ToString(), int32()->ToString());
-
- ASSERT_TRUE(f0.nullable);
- ASSERT_FALSE(f0_nn.nullable);
-}
-
-TEST(TestField, Equals) {
- Field f0("f0", int32());
- Field f0_nn("f0", int32(), false);
- Field f0_other("f0", int32());
-
- ASSERT_EQ(f0, f0_other);
- ASSERT_NE(f0, f0_nn);
-}
-
-class TestSchema : public ::testing::Test {
- public:
- void SetUp() {}
-};
-
-TEST_F(TestSchema, Basics) {
- auto f0 = field("f0", int32());
- auto f1 = field("f1", uint8(), false);
- auto f1_optional = field("f1", uint8());
-
- auto f2 = field("f2", utf8());
-
- vector<shared_ptr<Field>> fields = {f0, f1, f2};
- auto schema = std::make_shared<Schema>(fields);
-
- ASSERT_EQ(3, schema->num_fields());
- ASSERT_EQ(f0, schema->field(0));
- ASSERT_EQ(f1, schema->field(1));
- ASSERT_EQ(f2, schema->field(2));
-
- auto schema2 = std::make_shared<Schema>(fields);
-
- vector<shared_ptr<Field>> fields3 = {f0, f1_optional, f2};
- auto schema3 = std::make_shared<Schema>(fields3);
- ASSERT_TRUE(schema->Equals(schema2));
- ASSERT_FALSE(schema->Equals(schema3));
-
- ASSERT_TRUE(schema->Equals(*schema2.get()));
- ASSERT_FALSE(schema->Equals(*schema3.get()));
-}
-
-TEST_F(TestSchema, ToString) {
- auto f0 = field("f0", int32());
- auto f1 = field("f1", uint8(), false);
- auto f2 = field("f2", utf8());
- auto f3 = field("f3", list(int16()));
-
- vector<shared_ptr<Field>> fields = {f0, f1, f2, f3};
- auto schema = std::make_shared<Schema>(fields);
-
- std::string result = schema->ToString();
- std::string expected = R"(f0: int32
-f1: uint8 not null
-f2: string
-f3: list<item: int16>)";
-
- ASSERT_EQ(expected, result);
-}
-
-TEST_F(TestSchema, GetFieldByName) {
- auto f0 = field("f0", int32());
- auto f1 = field("f1", uint8(), false);
- auto f2 = field("f2", utf8());
- auto f3 = field("f3", list(int16()));
-
- vector<shared_ptr<Field>> fields = {f0, f1, f2, f3};
- auto schema = std::make_shared<Schema>(fields);
-
- std::shared_ptr<Field> result;
-
- result = schema->GetFieldByName("f1");
- ASSERT_TRUE(f1->Equals(result));
-
- result = schema->GetFieldByName("f3");
- ASSERT_TRUE(f3->Equals(result));
-
- result = schema->GetFieldByName("not-found");
- ASSERT_TRUE(result == nullptr);
-}
-
-} // namespace arrow
http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/type-test.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/type-test.cc b/cpp/src/arrow/type-test.cc
new file mode 100644
index 0000000..fe6c62a
--- /dev/null
+++ b/cpp/src/arrow/type-test.cc
@@ -0,0 +1,146 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Unit tests for DataType (and subclasses), Field, and Schema
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "arrow/schema.h"
+#include "arrow/type.h"
+
+using std::shared_ptr;
+using std::vector;
+
+namespace arrow {
+
+TEST(TestField, Basics) {
+ Field f0("f0", int32());
+ Field f0_nn("f0", int32(), false);
+
+ ASSERT_EQ(f0.name, "f0");
+ ASSERT_EQ(f0.type->ToString(), int32()->ToString());
+
+ ASSERT_TRUE(f0.nullable);
+ ASSERT_FALSE(f0_nn.nullable);
+}
+
+TEST(TestField, Equals) {
+ Field f0("f0", int32());
+ Field f0_nn("f0", int32(), false);
+ Field f0_other("f0", int32());
+
+ ASSERT_TRUE(f0.Equals(f0_other));
+ ASSERT_FALSE(f0.Equals(f0_nn));
+}
+
+class TestSchema : public ::testing::Test {
+ public:
+ void SetUp() {}
+};
+
+TEST_F(TestSchema, Basics) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f1_optional = field("f1", uint8());
+
+ auto f2 = field("f2", utf8());
+
+ vector<shared_ptr<Field>> fields = {f0, f1, f2};
+ auto schema = std::make_shared<Schema>(fields);
+
+ ASSERT_EQ(3, schema->num_fields());
+ ASSERT_TRUE(f0->Equals(schema->field(0)));
+ ASSERT_TRUE(f1->Equals(schema->field(1)));
+ ASSERT_TRUE(f2->Equals(schema->field(2)));
+
+ auto schema2 = std::make_shared<Schema>(fields);
+
+ vector<shared_ptr<Field>> fields3 = {f0, f1_optional, f2};
+ auto schema3 = std::make_shared<Schema>(fields3);
+ ASSERT_TRUE(schema->Equals(schema2));
+ ASSERT_FALSE(schema->Equals(schema3));
+
+ ASSERT_TRUE(schema->Equals(*schema2.get()));
+ ASSERT_FALSE(schema->Equals(*schema3.get()));
+}
+
+TEST_F(TestSchema, ToString) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f2 = field("f2", utf8());
+ auto f3 = field("f3", list(int16()));
+
+ vector<shared_ptr<Field>> fields = {f0, f1, f2, f3};
+ auto schema = std::make_shared<Schema>(fields);
+
+ std::string result = schema->ToString();
+ std::string expected = R"(f0: int32
+f1: uint8 not null
+f2: string
+f3: list<item: int16>)";
+
+ ASSERT_EQ(expected, result);
+}
+
+TEST_F(TestSchema, GetFieldByName) {
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8(), false);
+ auto f2 = field("f2", utf8());
+ auto f3 = field("f3", list(int16()));
+
+ vector<shared_ptr<Field>> fields = {f0, f1, f2, f3};
+ auto schema = std::make_shared<Schema>(fields);
+
+ std::shared_ptr<Field> result;
+
+ result = schema->GetFieldByName("f1");
+ ASSERT_TRUE(f1->Equals(result));
+
+ result = schema->GetFieldByName("f3");
+ ASSERT_TRUE(f3->Equals(result));
+
+ result = schema->GetFieldByName("not-found");
+ ASSERT_TRUE(result == nullptr);
+}
+
+// Regression tests for ARROW-568: Equals must compare the `unit` of TimeType
+// and TimestampType, not just the type id. t1/t2 use the default unit and
+// t3/t4 use TimeUnit::NANO.
+// NOTE(review): ASSERT_FALSE(t1.Equals(t3)) assumes the default unit is not
+// NANO — confirm against the constructors' defaults.
+TEST(TestTimeType, Equals) {
+ TimeType t1;
+ TimeType t2;
+ TimeType t3(TimeUnit::NANO);
+ TimeType t4(TimeUnit::NANO);
+
+ ASSERT_TRUE(t1.Equals(t2));
+ ASSERT_FALSE(t1.Equals(t3));
+ ASSERT_TRUE(t3.Equals(t4));
+}
+
+TEST(TestTimestampType, Equals) {
+ TimestampType t1;
+ TimestampType t2;
+ TimestampType t3(TimeUnit::NANO);
+ TimestampType t4(TimeUnit::NANO);
+
+ ASSERT_TRUE(t1.Equals(t2));
+ ASSERT_FALSE(t1.Equals(t3));
+ ASSERT_TRUE(t3.Equals(t4));
+}
+
+} // namespace arrow
http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/type.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index b97b465..23fa681 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -21,6 +21,7 @@
#include <string>
#include "arrow/array.h"
+#include "arrow/compare.h"
#include "arrow/status.h"
#include "arrow/util/logging.h"
@@ -46,16 +47,14 @@ std::string Field::ToString() const {
DataType::~DataType() {}
bool DataType::Equals(const DataType& other) const {
- bool equals =
- ((this == &other) || ((this->type == other.type) &&
- ((this->num_children() == other.num_children()))));
- if (equals) {
- for (int i = 0; i < num_children(); ++i) {
- // TODO(emkornfield) limit recursion
- if (!children_[i]->Equals(other.children_[i])) { return false; }
- }
- }
- return equals;
+ bool are_equal = false;
+ Status error = TypeEquals(*this, other, &are_equal);
+ if (!error.ok()) { DCHECK(false) << "Types not comparable: " << error.ToString(); }
+ return are_equal;
+}
+
+bool DataType::Equals(const std::shared_ptr<DataType>& other) const {
+ return Equals(*other.get());
}
std::string BooleanType::ToString() const {
@@ -104,6 +103,15 @@ std::string DateType::ToString() const {
return std::string("date");
}
+// ----------------------------------------------------------------------
+// Union type
+
+// Stores `mode` and `type_codes` as given and uses `fields` as the union's
+// child fields. children_ is a DataType member, so it is assigned in the body
+// rather than in UnionType's member initializer list.
+UnionType::UnionType(const std::vector<std::shared_ptr<Field>>& fields,
+ const std::vector<uint8_t>& type_codes, UnionMode mode)
+ : DataType(Type::UNION), mode(mode), type_codes(type_codes) {
+ children_ = fields;
+}
+
std::string UnionType::ToString() const {
std::stringstream s;
@@ -138,14 +146,6 @@ std::shared_ptr<Array> DictionaryType::dictionary() const {
return dictionary_;
}
-bool DictionaryType::Equals(const DataType& other) const {
- if (other.type != Type::DICTIONARY) { return false; }
- const auto& other_dict = static_cast<const DictionaryType&>(other);
-
- return index_type_->Equals(other_dict.index_type_) &&
- dictionary_->Equals(other_dict.dictionary_);
-}
-
std::string DictionaryType::ToString() const {
std::stringstream ss;
ss << "dictionary<values=" << dictionary_->type()->ToString()
@@ -286,4 +286,37 @@ std::vector<BufferDescr> DecimalType::GetBufferLayout() const {
return {};
}
+// ----------------------------------------------------------------------
+// Default implementations of TypeVisitor methods
+//
+// Every default returns Status::NotImplemented with type.ToString() as the
+// message, so a TypeVisitor subclass only overrides the types it handles.
+// TypeEquals in compare.cc relies on this: a NotImplemented result from
+// TypeEqualsVisitor is treated as "no additional metadata to compare".
+
+#define TYPE_VISITOR_DEFAULT(TYPE_CLASS) \
+ Status TypeVisitor::Visit(const TYPE_CLASS& type) { \
+ return Status::NotImplemented(type.ToString()); \
+ }
+
+TYPE_VISITOR_DEFAULT(NullType);
+TYPE_VISITOR_DEFAULT(BooleanType);
+TYPE_VISITOR_DEFAULT(Int8Type);
+TYPE_VISITOR_DEFAULT(Int16Type);
+TYPE_VISITOR_DEFAULT(Int32Type);
+TYPE_VISITOR_DEFAULT(Int64Type);
+TYPE_VISITOR_DEFAULT(UInt8Type);
+TYPE_VISITOR_DEFAULT(UInt16Type);
+TYPE_VISITOR_DEFAULT(UInt32Type);
+TYPE_VISITOR_DEFAULT(UInt64Type);
+TYPE_VISITOR_DEFAULT(HalfFloatType);
+TYPE_VISITOR_DEFAULT(FloatType);
+TYPE_VISITOR_DEFAULT(DoubleType);
+TYPE_VISITOR_DEFAULT(StringType);
+TYPE_VISITOR_DEFAULT(BinaryType);
+TYPE_VISITOR_DEFAULT(DateType);
+TYPE_VISITOR_DEFAULT(TimeType);
+TYPE_VISITOR_DEFAULT(TimestampType);
+TYPE_VISITOR_DEFAULT(IntervalType);
+TYPE_VISITOR_DEFAULT(DecimalType);
+TYPE_VISITOR_DEFAULT(ListType);
+TYPE_VISITOR_DEFAULT(StructType);
+TYPE_VISITOR_DEFAULT(UnionType);
+TYPE_VISITOR_DEFAULT(DictionaryType);
+
} // namespace arrow
http://git-wip-us.apache.org/repos/asf/arrow/blob/ef3b6b34/cpp/src/arrow/type.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index b15aa27..9a97fc3 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -112,34 +112,34 @@ class BufferDescr {
int bit_width_;
};
-class TypeVisitor {
+class ARROW_EXPORT TypeVisitor {
public:
virtual ~TypeVisitor() = default;
- virtual Status Visit(const NullType& type) = 0;
- virtual Status Visit(const BooleanType& type) = 0;
- virtual Status Visit(const Int8Type& type) = 0;
- virtual Status Visit(const Int16Type& type) = 0;
- virtual Status Visit(const Int32Type& type) = 0;
- virtual Status Visit(const Int64Type& type) = 0;
- virtual Status Visit(const UInt8Type& type) = 0;
- virtual Status Visit(const UInt16Type& type) = 0;
- virtual Status Visit(const UInt32Type& type) = 0;
- virtual Status Visit(const UInt64Type& type) = 0;
- virtual Status Visit(const HalfFloatType& type) = 0;
- virtual Status Visit(const FloatType& type) = 0;
- virtual Status Visit(const DoubleType& type) = 0;
- virtual Status Visit(const StringType& type) = 0;
- virtual Status Visit(const BinaryType& type) = 0;
- virtual Status Visit(const DateType& type) = 0;
- virtual Status Visit(const TimeType& type) = 0;
- virtual Status Visit(const TimestampType& type) = 0;
- virtual Status Visit(const IntervalType& type) = 0;
- virtual Status Visit(const DecimalType& type) = 0;
- virtual Status Visit(const ListType& type) = 0;
- virtual Status Visit(const StructType& type) = 0;
- virtual Status Visit(const UnionType& type) = 0;
- virtual Status Visit(const DictionaryType& type) = 0;
+ virtual Status Visit(const NullType& type);
+ virtual Status Visit(const BooleanType& type);
+ virtual Status Visit(const Int8Type& type);
+ virtual Status Visit(const Int16Type& type);
+ virtual Status Visit(const Int32Type& type);
+ virtual Status Visit(const Int64Type& type);
+ virtual Status Visit(const UInt8Type& type);
+ virtual Status Visit(const UInt16Type& type);
+ virtual Status Visit(const UInt32Type& type);
+ virtual Status Visit(const UInt64Type& type);
+ virtual Status Visit(const HalfFloatType& type);
+ virtual Status Visit(const FloatType& type);
+ virtual Status Visit(const DoubleType& type);
+ virtual Status Visit(const StringType& type);
+ virtual Status Visit(const BinaryType& type);
+ virtual Status Visit(const DateType& type);
+ virtual Status Visit(const TimeType& type);
+ virtual Status Visit(const TimestampType& type);
+ virtual Status Visit(const IntervalType& type);
+ virtual Status Visit(const DecimalType& type);
+ virtual Status Visit(const ListType& type);
+ virtual Status Visit(const StructType& type);
+ virtual Status Visit(const UnionType& type);
+ virtual Status Visit(const DictionaryType& type);
};
struct ARROW_EXPORT DataType {
@@ -156,10 +156,7 @@ struct ARROW_EXPORT DataType {
// Types that are logically convertable from one to another e.g. List<UInt8>
// and Binary are NOT equal).
virtual bool Equals(const DataType& other) const;
-
- bool Equals(const std::shared_ptr<DataType>& other) const {
- return Equals(*other.get());
- }
+ bool Equals(const std::shared_ptr<DataType>& other) const;
std::shared_ptr<Field> child(int i) const { return children_[i]; }
@@ -211,8 +208,6 @@ struct ARROW_EXPORT Field {
bool nullable = true)
: name(name), type(type), nullable(nullable) {}
- bool operator==(const Field& other) const { return this->Equals(other); }
- bool operator!=(const Field& other) const { return !this->Equals(other); }
bool Equals(const Field& other) const;
bool Equals(const std::shared_ptr<Field>& other) const;
@@ -411,10 +406,7 @@ struct ARROW_EXPORT UnionType : public DataType {
static constexpr Type::type type_id = Type::UNION;
UnionType(const std::vector<std::shared_ptr<Field>>& fields,
- const std::vector<uint8_t>& type_codes, UnionMode mode = UnionMode::SPARSE)
- : DataType(Type::UNION), mode(mode), type_codes(type_codes) {
- children_ = fields;
- }
+ const std::vector<uint8_t>& type_codes, UnionMode mode = UnionMode::SPARSE);
std::string ToString() const override;
static std::string name() { return "union"; }
@@ -523,8 +515,6 @@ class ARROW_EXPORT DictionaryType : public FixedWidthType {
std::shared_ptr<Array> dictionary() const;
- bool Equals(const DataType& other) const override;
-
Status Accept(TypeVisitor* visitor) const override;
std::string ToString() const override;