You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2017/03/30 22:03:31 UTC
arrow git commit: ARROW-717: [C++] Implement IPC zero-copy round trip
for tensors
Repository: arrow
Updated Branches:
refs/heads/master 15b874e47 -> 957a0e678
ARROW-717: [C++] Implement IPC zero-copy round trip for tensors
This patch provides:
```cpp
WriteTensor(tensor, file, &metadata_length, &body_length);
std::shared_ptr<Tensor> result;
ReadTensor(offset, file, &result);
```
Also implemented `Tensor::Equals` and did some refactoring / code simplification in compare.cc
Author: Wes McKinney <we...@twosigma.com>
Closes #454 from wesm/ARROW-717 and squashes the following commits:
6c15481 [Wes McKinney] Tensor IPC read/write, and refactoring / code scrubbing
Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/957a0e67
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/957a0e67
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/957a0e67
Branch: refs/heads/master
Commit: 957a0e67836b66f8ff4fc3fdae343553c589b53f
Parents: 15b874e
Author: Wes McKinney <we...@twosigma.com>
Authored: Thu Mar 30 18:03:26 2017 -0400
Committer: Wes McKinney <we...@twosigma.com>
Committed: Thu Mar 30 18:03:26 2017 -0400
----------------------------------------------------------------------
cpp/src/arrow/buffer.cc | 6 +-
cpp/src/arrow/compare.cc | 330 ++++++++++++--------------
cpp/src/arrow/compare.h | 4 +
cpp/src/arrow/ipc/ipc-read-write-test.cc | 54 ++++-
cpp/src/arrow/ipc/metadata.cc | 266 +++++++++++++++------
cpp/src/arrow/ipc/metadata.h | 67 +++---
cpp/src/arrow/ipc/reader.cc | 79 +++---
cpp/src/arrow/ipc/reader.h | 32 +--
cpp/src/arrow/ipc/writer.cc | 79 +++---
cpp/src/arrow/ipc/writer.h | 12 +-
cpp/src/arrow/tensor-test.cc | 25 +-
cpp/src/arrow/tensor.cc | 67 +++++-
cpp/src/arrow/tensor.h | 18 +-
cpp/src/arrow/type_traits.h | 11 +
cpp/src/arrow/visitor_inline.h | 26 ++
format/Tensor.fbs | 14 +-
16 files changed, 656 insertions(+), 434 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/buffer.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/buffer.cc b/cpp/src/arrow/buffer.cc
index be747e1..5962340 100644
--- a/cpp/src/arrow/buffer.cc
+++ b/cpp/src/arrow/buffer.cc
@@ -27,11 +27,9 @@
namespace arrow {
-Buffer::Buffer(const std::shared_ptr<Buffer>& parent, int64_t offset, int64_t size) {
- data_ = parent->data() + offset;
- size_ = size;
+Buffer::Buffer(const std::shared_ptr<Buffer>& parent, int64_t offset, int64_t size)
+ : Buffer(parent->data() + offset, size) {
parent_ = parent;
- capacity_ = size;
}
Buffer::~Buffer() {}
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/compare.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc
index f786222..c2580b4 100644
--- a/cpp/src/arrow/compare.cc
+++ b/cpp/src/arrow/compare.cc
@@ -25,6 +25,7 @@
#include "arrow/array.h"
#include "arrow/status.h"
+#include "arrow/tensor.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/bit-util.h"
@@ -36,7 +37,7 @@ namespace arrow {
// ----------------------------------------------------------------------
// Public method implementations
-class RangeEqualsVisitor : public ArrayVisitor {
+class RangeEqualsVisitor {
public:
RangeEqualsVisitor(const Array& right, int64_t left_start_idx, int64_t left_end_idx,
int64_t right_start_idx)
@@ -46,12 +47,6 @@ class RangeEqualsVisitor : public ArrayVisitor {
right_start_idx_(right_start_idx),
result_(false) {}
- Status Visit(const NullArray& left) override {
- UNUSED(left);
- result_ = true;
- return Status::OK();
- }
-
template <typename ArrayType>
inline Status CompareValues(const ArrayType& left) {
const auto& right = static_cast<const ArrayType&>(right_);
@@ -96,108 +91,6 @@ class RangeEqualsVisitor : public ArrayVisitor {
return true;
}
- Status Visit(const BooleanArray& left) override {
- return CompareValues<BooleanArray>(left);
- }
-
- Status Visit(const Int8Array& left) override { return CompareValues<Int8Array>(left); }
-
- Status Visit(const Int16Array& left) override {
- return CompareValues<Int16Array>(left);
- }
- Status Visit(const Int32Array& left) override {
- return CompareValues<Int32Array>(left);
- }
- Status Visit(const Int64Array& left) override {
- return CompareValues<Int64Array>(left);
- }
- Status Visit(const UInt8Array& left) override {
- return CompareValues<UInt8Array>(left);
- }
- Status Visit(const UInt16Array& left) override {
- return CompareValues<UInt16Array>(left);
- }
- Status Visit(const UInt32Array& left) override {
- return CompareValues<UInt32Array>(left);
- }
- Status Visit(const UInt64Array& left) override {
- return CompareValues<UInt64Array>(left);
- }
- Status Visit(const FloatArray& left) override {
- return CompareValues<FloatArray>(left);
- }
- Status Visit(const DoubleArray& left) override {
- return CompareValues<DoubleArray>(left);
- }
-
- Status Visit(const HalfFloatArray& left) override {
- return Status::NotImplemented("Half float type");
- }
-
- Status Visit(const StringArray& left) override {
- result_ = CompareBinaryRange(left);
- return Status::OK();
- }
-
- Status Visit(const BinaryArray& left) override {
- result_ = CompareBinaryRange(left);
- return Status::OK();
- }
-
- Status Visit(const FixedWidthBinaryArray& left) override {
- const auto& right = static_cast<const FixedWidthBinaryArray&>(right_);
-
- int32_t width = left.byte_width();
-
- const uint8_t* left_data = left.raw_data() + left.offset() * width;
- const uint8_t* right_data = right.raw_data() + right.offset() * width;
-
- for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
- ++i, ++o_i) {
- const bool is_null = left.IsNull(i);
- if (is_null != right.IsNull(o_i)) {
- result_ = false;
- return Status::OK();
- }
- if (is_null) continue;
-
- if (std::memcmp(left_data + width * i, right_data + width * o_i, width)) {
- result_ = false;
- return Status::OK();
- }
- }
- result_ = true;
- return Status::OK();
- }
-
- Status Visit(const Date32Array& left) override {
- return CompareValues<Date32Array>(left);
- }
-
- Status Visit(const Date64Array& left) override {
- return CompareValues<Date64Array>(left);
- }
-
- Status Visit(const Time32Array& left) override {
- return CompareValues<Time32Array>(left);
- }
-
- Status Visit(const Time64Array& left) override {
- return CompareValues<Time64Array>(left);
- }
-
- Status Visit(const TimestampArray& left) override {
- return CompareValues<TimestampArray>(left);
- }
-
- Status Visit(const IntervalArray& left) override {
- return CompareValues<IntervalArray>(left);
- }
-
- Status Visit(const DecimalArray& left) override {
- return Status::NotImplemented("Decimal type");
- }
-
bool CompareLists(const ListArray& left) {
const auto& right = static_cast<const ListArray&>(right_);
@@ -225,11 +118,6 @@ class RangeEqualsVisitor : public ArrayVisitor {
return true;
}
- Status Visit(const ListArray& left) override {
- result_ = CompareLists(left);
- return Status::OK();
- }
-
bool CompareStructs(const StructArray& left) {
const auto& right = static_cast<const StructArray&>(right_);
bool equal_fields = true;
@@ -251,11 +139,6 @@ class RangeEqualsVisitor : public ArrayVisitor {
return true;
}
- Status Visit(const StructArray& left) override {
- result_ = CompareStructs(left);
- return Status::OK();
- }
-
bool CompareUnions(const UnionArray& left) const {
const auto& right = static_cast<const UnionArray&>(right_);
@@ -314,12 +197,73 @@ class RangeEqualsVisitor : public ArrayVisitor {
return true;
}
- Status Visit(const UnionArray& left) override {
+ Status Visit(const BinaryArray& left) {
+ result_ = CompareBinaryRange(left);
+ return Status::OK();
+ }
+
+ Status Visit(const FixedWidthBinaryArray& left) {
+ const auto& right = static_cast<const FixedWidthBinaryArray&>(right_);
+
+ int32_t width = left.byte_width();
+
+ const uint8_t* left_data = nullptr;
+ const uint8_t* right_data = nullptr;
+
+ if (left.data()) { left_data = left.raw_data() + left.offset() * width; }
+
+ if (right.data()) { right_data = right.raw_data() + right.offset() * width; }
+
+ for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
+ ++i, ++o_i) {
+ const bool is_null = left.IsNull(i);
+ if (is_null != right.IsNull(o_i)) {
+ result_ = false;
+ return Status::OK();
+ }
+ if (is_null) continue;
+
+ if (std::memcmp(left_data + width * i, right_data + width * o_i, width)) {
+ result_ = false;
+ return Status::OK();
+ }
+ }
+ result_ = true;
+ return Status::OK();
+ }
+
+ Status Visit(const NullArray& left) {
+ UNUSED(left);
+ result_ = true;
+ return Status::OK();
+ }
+
+ template <typename T>
+ typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value, Status>::type Visit(
+ const T& left) {
+ return CompareValues<T>(left);
+ }
+
+ Status Visit(const DecimalArray& left) {
+ return Status::NotImplemented("Decimal type");
+ }
+
+ Status Visit(const ListArray& left) {
+ result_ = CompareLists(left);
+ return Status::OK();
+ }
+
+ Status Visit(const StructArray& left) {
+ result_ = CompareStructs(left);
+ return Status::OK();
+ }
+
+ Status Visit(const UnionArray& left) {
result_ = CompareUnions(left);
return Status::OK();
}
- Status Visit(const DictionaryArray& left) override {
+ Status Visit(const DictionaryArray& left) {
const auto& right = static_cast<const DictionaryArray&>(right_);
if (!left.dictionary()->Equals(right.dictionary())) {
result_ = false;
@@ -346,9 +290,9 @@ class ArrayEqualsVisitor : public RangeEqualsVisitor {
explicit ArrayEqualsVisitor(const Array& right)
: RangeEqualsVisitor(right, 0, right.length(), 0) {}
- Status Visit(const NullArray& left) override { return Status::OK(); }
+ Status Visit(const NullArray& left) { return Status::OK(); }
- Status Visit(const BooleanArray& left) override {
+ Status Visit(const BooleanArray& left) {
const auto& right = static_cast<const BooleanArray&>(right_);
if (left.null_count() > 0) {
const uint8_t* left_data = left.data()->data();
@@ -372,64 +316,39 @@ class ArrayEqualsVisitor : public RangeEqualsVisitor {
bool IsEqualPrimitive(const PrimitiveArray& left) {
const auto& right = static_cast<const PrimitiveArray&>(right_);
const auto& size_meta = dynamic_cast<const FixedWidthType&>(*left.type());
- const int value_byte_size = size_meta.bit_width() / 8;
- DCHECK_GT(value_byte_size, 0);
+ const int byte_width = size_meta.bit_width() / 8;
+
+ const uint8_t* left_data = nullptr;
+ const uint8_t* right_data = nullptr;
+
+ if (left.data()) { left_data = left.data()->data() + left.offset() * byte_width; }
- const uint8_t* left_data = left.data()->data() + left.offset() * value_byte_size;
- const uint8_t* right_data = right.data()->data() + right.offset() * value_byte_size;
+ if (right.data()) { right_data = right.data()->data() + right.offset() * byte_width; }
if (left.null_count() > 0) {
for (int64_t i = 0; i < left.length(); ++i) {
- if (!left.IsNull(i) && memcmp(left_data, right_data, value_byte_size)) {
+ if (!left.IsNull(i) && memcmp(left_data, right_data, byte_width)) {
return false;
}
- left_data += value_byte_size;
- right_data += value_byte_size;
+ left_data += byte_width;
+ right_data += byte_width;
}
return true;
} else {
return memcmp(left_data, right_data,
- static_cast<size_t>(value_byte_size * left.length())) == 0;
+ static_cast<size_t>(byte_width * left.length())) == 0;
}
}
- Status ComparePrimitive(const PrimitiveArray& left) {
+ template <typename T>
+ typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value &&
+ !std::is_base_of<BooleanArray, T>::value,
+ Status>::type
+ Visit(const T& left) {
result_ = IsEqualPrimitive(left);
return Status::OK();
}
- Status Visit(const Int8Array& left) override { return ComparePrimitive(left); }
-
- Status Visit(const Int16Array& left) override { return ComparePrimitive(left); }
-
- Status Visit(const Int32Array& left) override { return ComparePrimitive(left); }
-
- Status Visit(const Int64Array& left) override { return ComparePrimitive(left); }
-
- Status Visit(const UInt8Array& left) override { return ComparePrimitive(left); }
-
- Status Visit(const UInt16Array& left) override { return ComparePrimitive(left); }
-
- Status Visit(const UInt32Array& left) override { return ComparePrimitive(left); }
-
- Status Visit(const UInt64Array& left) override { return ComparePrimitive(left); }
-
- Status Visit(const FloatArray& left) override { return ComparePrimitive(left); }
-
- Status Visit(const DoubleArray& left) override { return ComparePrimitive(left); }
-
- Status Visit(const Date32Array& left) override { return ComparePrimitive(left); }
-
- Status Visit(const Date64Array& left) override { return ComparePrimitive(left); }
-
- Status Visit(const Time32Array& left) override { return ComparePrimitive(left); }
-
- Status Visit(const Time64Array& left) override { return ComparePrimitive(left); }
-
- Status Visit(const TimestampArray& left) override { return ComparePrimitive(left); }
-
- Status Visit(const IntervalArray& left) override { return ComparePrimitive(left); }
-
template <typename ArrayType>
bool ValueOffsetsEqual(const ArrayType& left) {
const auto& right = static_cast<const ArrayType&>(right_);
@@ -494,17 +413,12 @@ class ArrayEqualsVisitor : public RangeEqualsVisitor {
}
}
- Status Visit(const StringArray& left) override {
- result_ = CompareBinary(left);
- return Status::OK();
- }
-
- Status Visit(const BinaryArray& left) override {
+ Status Visit(const BinaryArray& left) {
result_ = CompareBinary(left);
return Status::OK();
}
- Status Visit(const ListArray& left) override {
+ Status Visit(const ListArray& left) {
const auto& right = static_cast<const ListArray&>(right_);
bool equal_offsets = ValueOffsetsEqual<ListArray>(left);
if (!equal_offsets) {
@@ -523,7 +437,7 @@ class ArrayEqualsVisitor : public RangeEqualsVisitor {
return Status::OK();
}
- Status Visit(const DictionaryArray& left) override {
+ Status Visit(const DictionaryArray& left) {
const auto& right = static_cast<const DictionaryArray&>(right_);
if (!left.dictionary()->Equals(right.dictionary())) {
result_ = false;
@@ -532,6 +446,13 @@ class ArrayEqualsVisitor : public RangeEqualsVisitor {
}
return Status::OK();
}
+
+ template <typename T>
+ typename std::enable_if<std::is_base_of<NestedType, typename T::TypeClass>::value,
+ Status>::type
+ Visit(const T& left) {
+ return RangeEqualsVisitor::Visit(left);
+ }
};
template <typename TYPE>
@@ -560,14 +481,15 @@ inline bool FloatingApproxEquals(
class ApproxEqualsVisitor : public ArrayEqualsVisitor {
public:
using ArrayEqualsVisitor::ArrayEqualsVisitor;
+ using ArrayEqualsVisitor::Visit;
- Status Visit(const FloatArray& left) override {
+ Status Visit(const FloatArray& left) {
result_ =
FloatingApproxEquals<FloatType>(left, static_cast<const FloatArray&>(right_));
return Status::OK();
}
- Status Visit(const DoubleArray& left) override {
+ Status Visit(const DoubleArray& left) {
result_ =
FloatingApproxEquals<DoubleType>(left, static_cast<const DoubleArray&>(right_));
return Status::OK();
@@ -586,7 +508,8 @@ static bool BaseDataEquals(const Array& left, const Array& right) {
return true;
}
-Status ArrayEquals(const Array& left, const Array& right, bool* are_equal) {
+template <typename VISITOR>
+inline Status ArrayEqualsImpl(const Array& left, const Array& right, bool* are_equal) {
// The arrays are the same object
if (&left == &right) {
*are_equal = true;
@@ -595,13 +518,21 @@ Status ArrayEquals(const Array& left, const Array& right, bool* are_equal) {
} else if (left.length() == 0) {
*are_equal = true;
} else {
- ArrayEqualsVisitor visitor(right);
- RETURN_NOT_OK(left.Accept(&visitor));
+ VISITOR visitor(right);
+ RETURN_NOT_OK(VisitArrayInline(left, &visitor));
*are_equal = visitor.result();
}
return Status::OK();
}
+Status ArrayEquals(const Array& left, const Array& right, bool* are_equal) {
+ return ArrayEqualsImpl<ArrayEqualsVisitor>(left, right, are_equal);
+}
+
+Status ArrayApproxEquals(const Array& left, const Array& right, bool* are_equal) {
+ return ArrayEqualsImpl<ApproxEqualsVisitor>(left, right, are_equal);
+}
+
Status ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx,
int64_t left_end_idx, int64_t right_start_idx, bool* are_equal) {
if (&left == &right) {
@@ -612,23 +543,56 @@ Status ArrayRangeEquals(const Array& left, const Array& right, int64_t left_star
*are_equal = true;
} else {
RangeEqualsVisitor visitor(right, left_start_idx, left_end_idx, right_start_idx);
- RETURN_NOT_OK(left.Accept(&visitor));
+ RETURN_NOT_OK(VisitArrayInline(left, &visitor));
*are_equal = visitor.result();
}
return Status::OK();
}
-Status ArrayApproxEquals(const Array& left, const Array& right, bool* are_equal) {
+// ----------------------------------------------------------------------
+// Implement TensorEquals
+
+class TensorEqualsVisitor {
+ public:
+ explicit TensorEqualsVisitor(const Tensor& right) : right_(right) {}
+
+ template <typename TensorType>
+ Status Visit(const TensorType& left) {
+ const auto& size_meta = dynamic_cast<const FixedWidthType&>(*left.type());
+ const int byte_width = size_meta.bit_width() / 8;
+ DCHECK_GT(byte_width, 0);
+
+ const uint8_t* left_data = left.data()->data();
+ const uint8_t* right_data = right_.data()->data();
+
+ result_ =
+ memcmp(left_data, right_data, static_cast<size_t>(byte_width * left.size())) == 0;
+ return Status::OK();
+ }
+
+ bool result() const { return result_; }
+
+ protected:
+ const Tensor& right_;
+ bool result_;
+};
+
+Status TensorEquals(const Tensor& left, const Tensor& right, bool* are_equal) {
// The arrays are the same object
if (&left == &right) {
*are_equal = true;
- } else if (!BaseDataEquals(left, right)) {
+ } else if (left.type_enum() != right.type_enum()) {
*are_equal = false;
- } else if (left.length() == 0) {
+ } else if (left.size() == 0) {
*are_equal = true;
} else {
- ApproxEqualsVisitor visitor(right);
- RETURN_NOT_OK(left.Accept(&visitor));
+ if (!left.is_contiguous() || !right.is_contiguous()) {
+ return Status::NotImplemented(
+ "Comparison not implemented for non-contiguous tensors");
+ }
+
+ TensorEqualsVisitor visitor(right);
+ RETURN_NOT_OK(VisitTensorInline(left, &visitor));
*are_equal = visitor.result();
}
return Status::OK();
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/compare.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compare.h b/cpp/src/arrow/compare.h
index 1ddf049..522b11d 100644
--- a/cpp/src/arrow/compare.h
+++ b/cpp/src/arrow/compare.h
@@ -29,10 +29,14 @@ namespace arrow {
class Array;
struct DataType;
class Status;
+class Tensor;
/// Returns true if the arrays are exactly equal
Status ARROW_EXPORT ArrayEquals(const Array& left, const Array& right, bool* are_equal);
+Status ARROW_EXPORT TensorEquals(
+ const Tensor& left, const Tensor& right, bool* are_equal);
+
/// Returns true if the arrays are approximately equal. For non-floating point
/// types, this is equivalent to ArrayEquals(left, right)
Status ARROW_EXPORT ArrayApproxEquals(
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/ipc/ipc-read-write-test.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/ipc/ipc-read-write-test.cc b/cpp/src/arrow/ipc/ipc-read-write-test.cc
index 6ddda3f..74ca017 100644
--- a/cpp/src/arrow/ipc/ipc-read-write-test.cc
+++ b/cpp/src/arrow/ipc/ipc-read-write-test.cc
@@ -25,16 +25,16 @@
#include "gtest/gtest.h"
#include "arrow/array.h"
+#include "arrow/buffer.h"
#include "arrow/io/memory.h"
#include "arrow/io/test-common.h"
#include "arrow/ipc/api.h"
#include "arrow/ipc/test-common.h"
#include "arrow/ipc/util.h"
-
-#include "arrow/buffer.h"
#include "arrow/memory_pool.h"
#include "arrow/pretty_print.h"
#include "arrow/status.h"
+#include "arrow/tensor.h"
#include "arrow/test-util.h"
#include "arrow/util/bit-util.h"
@@ -56,13 +56,10 @@ class TestSchemaMetadata : public ::testing::Test {
ASSERT_EQ(Message::SCHEMA, message->type());
- auto schema_msg = std::make_shared<SchemaMetadata>(message);
- ASSERT_EQ(schema.num_fields(), schema_msg->num_fields());
-
DictionaryMemo empty_memo;
std::shared_ptr<Schema> schema2;
- ASSERT_OK(schema_msg->GetSchema(empty_memo, &schema2));
+ ASSERT_OK(GetSchema(message->header(), empty_memo, &schema2));
AssertSchemaEqual(schema, *schema2);
}
@@ -90,7 +87,7 @@ TEST_F(TestSchemaMetadata, PrimitiveFields) {
}
TEST_F(TestSchemaMetadata, NestedFields) {
- auto type = std::make_shared<ListType>(std::make_shared<Int32Type>());
+ auto type = list(int32());
auto f0 = field("f0", type);
std::shared_ptr<StructType> type2(
@@ -532,7 +529,6 @@ TEST_F(TestIpcRoundTrip, LargeRecordBatch) {
// 512 MB
constexpr int64_t kBufferSize = 1 << 29;
-
ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(kBufferSize, path, &mmap_));
std::shared_ptr<RecordBatch> result;
@@ -580,5 +576,47 @@ TEST_F(TestFileFormat, DictionaryRoundTrip) {
CheckBatchDictionaries(*out_batches[0]);
}
+class TestTensorRoundTrip : public ::testing::Test, public IpcTestFixture {
+ public:
+ void SetUp() { pool_ = default_memory_pool(); }
+ void TearDown() { io::MemoryMapFixture::TearDown(); }
+
+ void CheckTensorRoundTrip(const Tensor& tensor) {
+ int32_t metadata_length;
+ int64_t body_length;
+
+ ASSERT_OK(mmap_->Seek(0));
+
+ ASSERT_OK(WriteTensor(tensor, mmap_.get(), &metadata_length, &body_length));
+
+ std::shared_ptr<Tensor> result;
+ ASSERT_OK(ReadTensor(0, mmap_.get(), &result));
+
+ ASSERT_TRUE(tensor.Equals(*result));
+ }
+};
+
+TEST_F(TestTensorRoundTrip, BasicRoundtrip) {
+ std::string path = "test-write-tensor";
+ constexpr int64_t kBufferSize = 1 << 20;
+ ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(kBufferSize, path, &mmap_));
+
+ std::vector<int64_t> shape = {4, 6};
+ std::vector<int64_t> strides = {48, 8};
+ std::vector<std::string> dim_names = {"foo", "bar"};
+ int64_t size = 24;
+
+ std::vector<int64_t> values;
+ test::randint<int64_t>(size, 0, 100, &values);
+
+ auto data = test::GetBufferFromVector(values);
+
+ Int64Tensor t0(data, shape, strides, dim_names);
+ Int64Tensor tzero(data, {}, {}, {});
+
+ CheckTensorRoundTrip(t0);
+ CheckTensorRoundTrip(tzero);
+}
+
} // namespace ipc
} // namespace arrow
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/ipc/metadata.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/ipc/metadata.cc b/cpp/src/arrow/ipc/metadata.cc
index 6d9fabd..076a6e7 100644
--- a/cpp/src/arrow/ipc/metadata.cc
+++ b/cpp/src/arrow/ipc/metadata.cc
@@ -20,6 +20,7 @@
#include <cstdint>
#include <memory>
#include <sstream>
+#include <string>
#include <vector>
#include "flatbuffers/flatbuffers.h"
@@ -29,7 +30,10 @@
#include "arrow/io/interfaces.h"
#include "arrow/ipc/File_generated.h"
#include "arrow/ipc/Message_generated.h"
+#include "arrow/ipc/Tensor_generated.h"
+#include "arrow/ipc/util.h"
#include "arrow/status.h"
+#include "arrow/tensor.h"
#include "arrow/type.h"
namespace arrow {
@@ -418,6 +422,46 @@ static Status TypeToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type,
return Status::OK();
}
+static Status TensorTypeToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type,
+ flatbuf::Type* out_type, Offset* offset) {
+ switch (type->type) {
+ case Type::UINT8:
+ INT_TO_FB_CASE(8, false);
+ case Type::INT8:
+ INT_TO_FB_CASE(8, true);
+ case Type::UINT16:
+ INT_TO_FB_CASE(16, false);
+ case Type::INT16:
+ INT_TO_FB_CASE(16, true);
+ case Type::UINT32:
+ INT_TO_FB_CASE(32, false);
+ case Type::INT32:
+ INT_TO_FB_CASE(32, true);
+ case Type::UINT64:
+ INT_TO_FB_CASE(64, false);
+ case Type::INT64:
+ INT_TO_FB_CASE(64, true);
+ case Type::HALF_FLOAT:
+ *out_type = flatbuf::Type_FloatingPoint;
+ *offset = FloatToFlatbuffer(fbb, flatbuf::Precision_HALF);
+ break;
+ case Type::FLOAT:
+ *out_type = flatbuf::Type_FloatingPoint;
+ *offset = FloatToFlatbuffer(fbb, flatbuf::Precision_SINGLE);
+ break;
+ case Type::DOUBLE:
+ *out_type = flatbuf::Type_FloatingPoint;
+ *offset = FloatToFlatbuffer(fbb, flatbuf::Precision_DOUBLE);
+ break;
+ default:
+ *out_type = flatbuf::Type_NONE; // Make clang-tidy happy
+ std::stringstream ss;
+ ss << "Unable to convert type: " << type->ToString() << std::endl;
+ return Status::NotImplemented(ss.str());
+ }
+ return Status::OK();
+}
+
static DictionaryOffset GetDictionaryEncoding(
FBB& fbb, const DictionaryType& type, DictionaryMemo* memo) {
int64_t dictionary_id = memo->GetId(type.dictionary());
@@ -552,7 +596,7 @@ static Status WriteFlatbufferBuilder(FBB& fbb, std::shared_ptr<Buffer>* out) {
return Status::OK();
}
-static Status WriteMessage(FBB& fbb, flatbuf::MessageHeader header_type,
+static Status WriteFBMessage(FBB& fbb, flatbuf::MessageHeader header_type,
flatbuffers::Offset<void> header, int64_t body_length, std::shared_ptr<Buffer>* out) {
auto message =
flatbuf::CreateMessage(fbb, kMetadataVersion, header_type, header, body_length);
@@ -565,7 +609,7 @@ Status WriteSchemaMessage(
FBB fbb;
flatbuffers::Offset<flatbuf::Schema> fb_schema;
RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, dictionary_memo, &fb_schema));
- return WriteMessage(fbb, flatbuf::MessageHeader_Schema, fb_schema.Union(), 0, out);
+ return WriteFBMessage(fbb, flatbuf::MessageHeader_Schema, fb_schema.Union(), 0, out);
}
using FieldNodeVector =
@@ -620,10 +664,39 @@ Status WriteRecordBatchMessage(int64_t length, int64_t body_length,
FBB fbb;
RecordBatchOffset record_batch;
RETURN_NOT_OK(MakeRecordBatch(fbb, length, body_length, nodes, buffers, &record_batch));
- return WriteMessage(
+ return WriteFBMessage(
fbb, flatbuf::MessageHeader_RecordBatch, record_batch.Union(), body_length, out);
}
+Status WriteTensorMessage(
+ const Tensor& tensor, int64_t buffer_start_offset, std::shared_ptr<Buffer>* out) {
+ using TensorDimOffset = flatbuffers::Offset<flatbuf::TensorDim>;
+ using TensorOffset = flatbuffers::Offset<flatbuf::Tensor>;
+
+ FBB fbb;
+
+ flatbuf::Type fb_type_type;
+ Offset fb_type;
+ RETURN_NOT_OK(TensorTypeToFlatbuffer(fbb, tensor.type(), &fb_type_type, &fb_type));
+
+ std::vector<TensorDimOffset> dims;
+ for (int i = 0; i < tensor.ndim(); ++i) {
+ FBString name = fbb.CreateString(tensor.dim_name(i));
+ dims.push_back(flatbuf::CreateTensorDim(fbb, tensor.shape()[i], name));
+ }
+
+ auto fb_shape = fbb.CreateVector(dims);
+ auto fb_strides = fbb.CreateVector(tensor.strides());
+ int64_t body_length = tensor.data()->size();
+ flatbuf::Buffer buffer(-1, buffer_start_offset, body_length);
+
+ TensorOffset fb_tensor =
+ flatbuf::CreateTensor(fbb, fb_type_type, fb_type, fb_shape, fb_strides, &buffer);
+
+ return WriteFBMessage(
+ fbb, flatbuf::MessageHeader_Tensor, fb_tensor.Union(), body_length, out);
+}
+
Status WriteDictionaryMessage(int64_t id, int64_t length, int64_t body_length,
const std::vector<FieldMetadata>& nodes, const std::vector<BufferMetadata>& buffers,
std::shared_ptr<Buffer>* out) {
@@ -631,7 +704,7 @@ Status WriteDictionaryMessage(int64_t id, int64_t length, int64_t body_length,
RecordBatchOffset record_batch;
RETURN_NOT_OK(MakeRecordBatch(fbb, length, body_length, nodes, buffers, &record_batch));
auto dictionary_batch = flatbuf::CreateDictionaryBatch(fbb, id, record_batch).Union();
- return WriteMessage(
+ return WriteFBMessage(
fbb, flatbuf::MessageHeader_DictionaryBatch, dictionary_batch, body_length, out);
}
@@ -746,6 +819,8 @@ class Message::MessageImpl {
return Message::DICTIONARY_BATCH;
case flatbuf::MessageHeader_RecordBatch:
return Message::RECORD_BATCH;
+ case flatbuf::MessageHeader_Tensor:
+ return Message::TENSOR;
default:
return Message::NONE;
}
@@ -790,95 +865,78 @@ const void* Message::header() const {
}
// ----------------------------------------------------------------------
-// SchemaMetadata
-
-class MessageHolder {
- public:
- void set_message(const std::shared_ptr<Message>& message) { message_ = message; }
- void set_buffer(const std::shared_ptr<Buffer>& buffer) { buffer_ = buffer; }
-
- protected:
- // Possible parents, owns the flatbuffer data
- std::shared_ptr<Message> message_;
- std::shared_ptr<Buffer> buffer_;
-};
-
-class SchemaMetadata::SchemaMetadataImpl : public MessageHolder {
- public:
- explicit SchemaMetadataImpl(const void* schema)
- : schema_(static_cast<const flatbuf::Schema*>(schema)) {}
-
- const flatbuf::Field* get_field(int i) const { return schema_->fields()->Get(i); }
- int num_fields() const { return schema_->fields()->size(); }
-
- Status VisitField(const flatbuf::Field* field, DictionaryTypeMap* id_to_field) const {
- const flatbuf::DictionaryEncoding* dict_metadata = field->dictionary();
- if (dict_metadata == nullptr) {
- // Field is not dictionary encoded. Visit children
- auto children = field->children();
- for (flatbuffers::uoffset_t i = 0; i < children->size(); ++i) {
- RETURN_NOT_OK(VisitField(children->Get(i), id_to_field));
- }
- } else {
- // Field is dictionary encoded. Construct the data type for the
- // dictionary (no descendents can be dictionary encoded)
- std::shared_ptr<Field> dictionary_field;
- RETURN_NOT_OK(FieldFromFlatbufferDictionary(field, &dictionary_field));
- (*id_to_field)[dict_metadata->id()] = dictionary_field;
+static Status VisitField(const flatbuf::Field* field, DictionaryTypeMap* id_to_field) {
+ const flatbuf::DictionaryEncoding* dict_metadata = field->dictionary();
+ if (dict_metadata == nullptr) {
+ // Field is not dictionary encoded. Visit children
+ auto children = field->children();
+ for (flatbuffers::uoffset_t i = 0; i < children->size(); ++i) {
+ RETURN_NOT_OK(VisitField(children->Get(i), id_to_field));
}
- return Status::OK();
+ } else {
+ // Field is dictionary encoded. Construct the data type for the
+ // dictionary (no descendents can be dictionary encoded)
+ std::shared_ptr<Field> dictionary_field;
+ RETURN_NOT_OK(FieldFromFlatbufferDictionary(field, &dictionary_field));
+ (*id_to_field)[dict_metadata->id()] = dictionary_field;
}
+ return Status::OK();
+}
- Status GetDictionaryTypes(DictionaryTypeMap* id_to_field) const {
- for (int i = 0; i < num_fields(); ++i) {
- RETURN_NOT_OK(VisitField(get_field(i), id_to_field));
- }
- return Status::OK();
+Status GetDictionaryTypes(const void* opaque_schema, DictionaryTypeMap* id_to_field) {
+ auto schema = static_cast<const flatbuf::Schema*>(opaque_schema);
+ int num_fields = static_cast<int>(schema->fields()->size());
+ for (int i = 0; i < num_fields; ++i) {
+ RETURN_NOT_OK(VisitField(schema->fields()->Get(i), id_to_field));
}
-
- private:
- const flatbuf::Schema* schema_;
-};
-
-SchemaMetadata::SchemaMetadata(const std::shared_ptr<Message>& message)
- : SchemaMetadata(message->impl_->header()) {
- impl_->set_message(message);
+ return Status::OK();
}
-SchemaMetadata::SchemaMetadata(const void* header) {
- impl_.reset(new SchemaMetadataImpl(header));
-}
+Status GetSchema(const void* opaque_schema, const DictionaryMemo& dictionary_memo,
+ std::shared_ptr<Schema>* out) {
+ auto schema = static_cast<const flatbuf::Schema*>(opaque_schema);
+ int num_fields = static_cast<int>(schema->fields()->size());
-SchemaMetadata::SchemaMetadata(const std::shared_ptr<Buffer>& buffer, int64_t offset)
- : SchemaMetadata(buffer->data() + offset) {
- // Preserve ownership
- impl_->set_buffer(buffer);
+ std::vector<std::shared_ptr<Field>> fields(num_fields);
+ for (int i = 0; i < num_fields; ++i) {
+ const flatbuf::Field* field = schema->fields()->Get(i);
+ RETURN_NOT_OK(FieldFromFlatbuffer(field, dictionary_memo, &fields[i]));
+ }
+ *out = std::make_shared<Schema>(fields);
+ return Status::OK();
}
-SchemaMetadata::~SchemaMetadata() {}
+Status GetTensorMetadata(const void* opaque_tensor, std::shared_ptr<DataType>* type,
+ std::vector<int64_t>* shape, std::vector<int64_t>* strides,
+ std::vector<std::string>* dim_names) {
+ auto tensor = static_cast<const flatbuf::Tensor*>(opaque_tensor);
-int SchemaMetadata::num_fields() const {
- return impl_->num_fields();
-}
+ int ndim = static_cast<int>(tensor->shape()->size());
-Status SchemaMetadata::GetDictionaryTypes(DictionaryTypeMap* id_to_field) const {
- return impl_->GetDictionaryTypes(id_to_field);
-}
+ for (int i = 0; i < ndim; ++i) {
+ auto dim = tensor->shape()->Get(i);
-Status SchemaMetadata::GetSchema(
- const DictionaryMemo& dictionary_memo, std::shared_ptr<Schema>* out) const {
- std::vector<std::shared_ptr<Field>> fields(num_fields());
- for (int i = 0; i < this->num_fields(); ++i) {
- const flatbuf::Field* field = impl_->get_field(i);
- RETURN_NOT_OK(FieldFromFlatbuffer(field, dictionary_memo, &fields[i]));
+ shape->push_back(dim->size());
+ auto fb_name = dim->name();
+ if (fb_name == 0) {
+ dim_names->push_back("");
+ } else {
+ dim_names->push_back(fb_name->str());
+ }
}
- *out = std::make_shared<Schema>(fields);
- return Status::OK();
+
+ if (tensor->strides()->size() > 0) {
+ for (int i = 0; i < ndim; ++i) {
+ strides->push_back(tensor->strides()->Get(i));
+ }
+ }
+
+ return TypeFromFlatbuffer(tensor->type_type(), tensor->type(), {}, type);
}
// ----------------------------------------------------------------------
-// Conveniences
+// Read and write messages
Status ReadMessage(int64_t offset, int32_t metadata_length, io::RandomAccessFile* file,
std::shared_ptr<Message>* message) {
@@ -896,5 +954,61 @@ Status ReadMessage(int64_t offset, int32_t metadata_length, io::RandomAccessFile
return Message::Open(buffer, 4, message);
}
+Status ReadMessage(io::InputStream* file, std::shared_ptr<Message>* message) {
+ std::shared_ptr<Buffer> buffer;
+ RETURN_NOT_OK(file->Read(sizeof(int32_t), &buffer));
+
+ if (buffer->size() != sizeof(int32_t)) {
+ *message = nullptr;
+ return Status::OK();
+ }
+
+ int32_t message_length = *reinterpret_cast<const int32_t*>(buffer->data());
+
+ if (message_length == 0) {
+ // Optional 0 EOS control message
+ *message = nullptr;
+ return Status::OK();
+ }
+
+ RETURN_NOT_OK(file->Read(message_length, &buffer));
+ if (buffer->size() != message_length) {
+ return Status::IOError("Unexpected end of stream trying to read message");
+ }
+
+ return Message::Open(buffer, 0, message);
+}
+
+Status WriteMessage(
+ const Buffer& message, io::OutputStream* file, int32_t* message_length) {
+ // Need to write 4 bytes (message size), the message, plus padding to
+ // end on an 8-byte offset
+ int64_t start_offset;
+ RETURN_NOT_OK(file->Tell(&start_offset));
+
+ int32_t padded_message_length = static_cast<int32_t>(message.size()) + 4;
+ const int32_t remainder =
+ (padded_message_length + static_cast<int32_t>(start_offset)) % 8;
+ if (remainder != 0) { padded_message_length += 8 - remainder; }
+
+ // The returned message size includes the length prefix, the flatbuffer,
+ // plus padding
+ *message_length = padded_message_length;
+
+ // Write the flatbuffer size prefix including padding
+ int32_t flatbuffer_size = padded_message_length - 4;
+ RETURN_NOT_OK(
+ file->Write(reinterpret_cast<const uint8_t*>(&flatbuffer_size), sizeof(int32_t)));
+
+ // Write the flatbuffer
+ RETURN_NOT_OK(file->Write(message.data(), message.size()));
+
+ // Write any padding
+ int32_t padding = padded_message_length - static_cast<int32_t>(message.size()) - 4;
+ if (padding > 0) { RETURN_NOT_OK(file->Write(kPaddingBytes, padding)); }
+
+ return Status::OK();
+}
+
} // namespace ipc
} // namespace arrow
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/ipc/metadata.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/ipc/metadata.h b/cpp/src/arrow/ipc/metadata.h
index 798abdc..fac4a70 100644
--- a/cpp/src/arrow/ipc/metadata.h
+++ b/cpp/src/arrow/ipc/metadata.h
@@ -22,6 +22,7 @@
#include <cstdint>
#include <memory>
+#include <string>
#include <unordered_map>
#include <vector>
@@ -37,9 +38,11 @@ struct DataType;
struct Field;
class Schema;
class Status;
+class Tensor;
namespace io {
+class InputStream;
class OutputStream;
class RandomAccessFile;
@@ -53,7 +56,7 @@ struct MetadataVersion {
static constexpr const char* kArrowMagicBytes = "ARROW1";
-struct ARROW_EXPORT FileBlock {
+struct FileBlock {
FileBlock() {}
FileBlock(int64_t offset, int32_t metadata_length, int64_t body_length)
: offset(offset), metadata_length(metadata_length), body_length(body_length) {}
@@ -104,44 +107,25 @@ class DictionaryMemo {
class Message;
-// Container for serialized Schema metadata contained in an IPC message
-class ARROW_EXPORT SchemaMetadata {
- public:
- explicit SchemaMetadata(const void* header);
- explicit SchemaMetadata(const std::shared_ptr<Message>& message);
- SchemaMetadata(const std::shared_ptr<Buffer>& message, int64_t offset);
-
- ~SchemaMetadata();
-
- int num_fields() const;
-
- // Retrieve a list of all the dictionary ids and types required by the schema for
- // reconstruction. The presumption is that these will be loaded either from
- // the stream or file (or they may already be somewhere else in memory)
- Status GetDictionaryTypes(DictionaryTypeMap* id_to_field) const;
+// Retrieve a list of all the dictionary ids and types required by the schema for
+// reconstruction. The presumption is that these will be loaded either from
+// the stream or file (or they may already be somewhere else in memory)
+Status GetDictionaryTypes(const void* opaque_schema, DictionaryTypeMap* id_to_field);
- // Construct a complete Schema from the message. May be expensive for very
- // large schemas if you are only interested in a few fields
- Status GetSchema(
- const DictionaryMemo& dictionary_memo, std::shared_ptr<Schema>* out) const;
-
- private:
- class SchemaMetadataImpl;
- std::unique_ptr<SchemaMetadataImpl> impl_;
-
- DISALLOW_COPY_AND_ASSIGN(SchemaMetadata);
-};
+// Construct a complete Schema from the message. May be expensive for very
+// large schemas if you are only interested in a few fields
+Status GetSchema(const void* opaque_schema, const DictionaryMemo& dictionary_memo,
+ std::shared_ptr<Schema>* out);
-struct ARROW_EXPORT BufferMetadata {
- int32_t page;
- int64_t offset;
- int64_t length;
-};
+Status GetTensorMetadata(const void* opaque_tensor, std::shared_ptr<DataType>* type,
+ std::vector<int64_t>* shape, std::vector<int64_t>* strides,
+ std::vector<std::string>* dim_names);
class ARROW_EXPORT Message {
public:
+ enum Type { NONE, SCHEMA, DICTIONARY_BATCH, RECORD_BATCH, TENSOR };
+
~Message();
- enum Type { NONE, SCHEMA, DICTIONARY_BATCH, RECORD_BATCH };
static Status Open(const std::shared_ptr<Buffer>& buffer, int64_t offset,
std::shared_ptr<Message>* out);
@@ -155,9 +139,6 @@ class ARROW_EXPORT Message {
private:
Message(const std::shared_ptr<Buffer>& buffer, int64_t offset);
- friend class DictionaryBatchMetadata;
- friend class SchemaMetadata;
-
// Hide serialization details from user API
class MessageImpl;
std::unique_ptr<MessageImpl> impl_;
@@ -179,6 +160,17 @@ class ARROW_EXPORT Message {
Status ReadMessage(int64_t offset, int32_t metadata_length, io::RandomAccessFile* file,
std::shared_ptr<Message>* message);
+/// Read length-prefixed message with as-yet unknown length. Returns nullptr if
+/// there are not enough bytes available or the message length is 0 (e.g. EOS
+/// in a stream)
+Status ReadMessage(io::InputStream* stream, std::shared_ptr<Message>* message);
+
+/// Write a serialized message with a length-prefix and padding to an 8-byte offset
+///
+/// <message_size: int32><message: const void*><padding>
+Status WriteMessage(
+ const Buffer& message, io::OutputStream* file, int32_t* message_length);
+
// Serialize arrow::Schema as a Flatbuffer
//
// \param[in] schema a Schema instance
@@ -193,6 +185,9 @@ Status WriteRecordBatchMessage(int64_t length, int64_t body_length,
const std::vector<FieldMetadata>& nodes, const std::vector<BufferMetadata>& buffers,
std::shared_ptr<Buffer>* out);
+Status WriteTensorMessage(
+ const Tensor& tensor, int64_t buffer_start_offset, std::shared_ptr<Buffer>* out);
+
Status WriteDictionaryMessage(int64_t id, int64_t length, int64_t body_length,
const std::vector<FieldMetadata>& nodes, const std::vector<BufferMetadata>& buffers,
std::shared_ptr<Buffer>* out);
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/ipc/reader.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc
index 28320d9..b47b773 100644
--- a/cpp/src/arrow/ipc/reader.cc
+++ b/cpp/src/arrow/ipc/reader.cc
@@ -33,6 +33,7 @@
#include "arrow/status.h"
#include "arrow/table.h"
#include "arrow/type.h"
+#include "arrow/tensor.h"
#include "arrow/util/logging.h"
namespace arrow {
@@ -186,28 +187,9 @@ class StreamReader::StreamReaderImpl {
}
Status ReadNextMessage(Message::Type expected_type, std::shared_ptr<Message>* message) {
- std::shared_ptr<Buffer> buffer;
- RETURN_NOT_OK(stream_->Read(sizeof(int32_t), &buffer));
-
- if (buffer->size() != sizeof(int32_t)) {
- *message = nullptr;
- return Status::OK();
- }
-
- int32_t message_length = *reinterpret_cast<const int32_t*>(buffer->data());
-
- if (message_length == 0) {
- // Optional 0 EOS control message
- *message = nullptr;
- return Status::OK();
- }
-
- RETURN_NOT_OK(stream_->Read(message_length, &buffer));
- if (buffer->size() != message_length) {
- return Status::IOError("Unexpected end of stream trying to read message");
- }
+ RETURN_NOT_OK(ReadMessage(stream_.get(), message));
- RETURN_NOT_OK(Message::Open(buffer, 0, message));
+ if ((*message) == nullptr) { return Status::OK(); }
if ((*message)->type() != expected_type) {
std::stringstream ss;
@@ -245,8 +227,7 @@ class StreamReader::StreamReaderImpl {
std::shared_ptr<Message> message;
RETURN_NOT_OK(ReadNextMessage(Message::SCHEMA, &message));
- SchemaMetadata schema_meta(message);
- RETURN_NOT_OK(schema_meta.GetDictionaryTypes(&dictionary_types_));
+ RETURN_NOT_OK(GetDictionaryTypes(message->header(), &dictionary_types_));
// TODO(wesm): In future, we may want to reconcile the ids in the stream with
// those found in the schema
@@ -255,7 +236,7 @@ class StreamReader::StreamReaderImpl {
RETURN_NOT_OK(ReadNextDictionary());
}
- return schema_meta.GetSchema(dictionary_memo_, &schema_);
+ return GetSchema(message->header(), dictionary_memo_, &schema_);
}
Status GetNextRecordBatch(std::shared_ptr<RecordBatch>* batch) {
@@ -343,7 +324,6 @@ class FileReader::FileReaderImpl {
// TODO(wesm): Verify the footer
footer_ = flatbuf::GetFooter(footer_buffer_->data());
- schema_metadata_.reset(new SchemaMetadata(footer_->schema()));
return Status::OK();
}
@@ -372,8 +352,6 @@ class FileReader::FileReaderImpl {
return FileBlockFromFlatbuffer(footer_->dictionaries()->Get(i));
}
- const SchemaMetadata& schema_metadata() const { return *schema_metadata_; }
-
Status GetRecordBatch(int i, std::shared_ptr<RecordBatch>* batch) {
DCHECK_GE(i, 0);
DCHECK_LT(i, num_record_batches());
@@ -393,7 +371,7 @@ class FileReader::FileReaderImpl {
}
Status ReadSchema() {
- RETURN_NOT_OK(schema_metadata_->GetDictionaryTypes(&dictionary_fields_));
+ RETURN_NOT_OK(GetDictionaryTypes(footer_->schema(), &dictionary_fields_));
// Read all the dictionaries
for (int i = 0; i < num_dictionaries(); ++i) {
@@ -419,7 +397,7 @@ class FileReader::FileReaderImpl {
}
// Get the schema
- return schema_metadata_->GetSchema(*dictionary_memo_, &schema_);
+ return GetSchema(footer_->schema(), *dictionary_memo_, &schema_);
}
Status Open(const std::shared_ptr<io::RandomAccessFile>& file, int64_t footer_offset) {
@@ -441,7 +419,6 @@ class FileReader::FileReaderImpl {
// Footer metadata
std::shared_ptr<Buffer> footer_buffer_;
const flatbuf::Footer* footer_;
- std::unique_ptr<SchemaMetadata> schema_metadata_;
DictionaryTypeMap dictionary_fields_;
std::shared_ptr<DictionaryMemo> dictionary_memo_;
@@ -485,26 +462,46 @@ Status FileReader::GetRecordBatch(int i, std::shared_ptr<RecordBatch>* batch) {
return impl_->GetRecordBatch(i, batch);
}
-Status ReadRecordBatch(const std::shared_ptr<Schema>& schema, int64_t offset,
- io::RandomAccessFile* file, std::shared_ptr<RecordBatch>* out) {
+static Status ReadContiguousPayload(int64_t offset, io::RandomAccessFile* file,
+ std::shared_ptr<Message>* message, std::shared_ptr<Buffer>* payload) {
std::shared_ptr<Buffer> buffer;
RETURN_NOT_OK(file->Seek(offset));
+ RETURN_NOT_OK(ReadMessage(file, message));
- RETURN_NOT_OK(file->Read(sizeof(int32_t), &buffer));
- int32_t flatbuffer_size = *reinterpret_cast<const int32_t*>(buffer->data());
-
- std::shared_ptr<Message> message;
- RETURN_NOT_OK(file->Read(flatbuffer_size, &buffer));
- RETURN_NOT_OK(Message::Open(buffer, 0, &message));
+ if (*message == nullptr) {
+ return Status::Invalid("Unable to read metadata at offset");
+ }
// TODO(ARROW-388): The buffer offsets start at 0, so we must construct a
// RandomAccessFile according to that frame of reference
- std::shared_ptr<Buffer> buffer_payload;
- RETURN_NOT_OK(file->Read(message->body_length(), &buffer_payload));
- io::BufferReader buffer_reader(buffer_payload);
+ RETURN_NOT_OK(file->Read((*message)->body_length(), payload));
+ return Status::OK();
+}
+Status ReadRecordBatch(const std::shared_ptr<Schema>& schema, int64_t offset,
+ io::RandomAccessFile* file, std::shared_ptr<RecordBatch>* out) {
+ std::shared_ptr<Buffer> payload;
+ std::shared_ptr<Message> message;
+
+ RETURN_NOT_OK(ReadContiguousPayload(offset, file, &message, &payload));
+ io::BufferReader buffer_reader(payload);
return ReadRecordBatch(*message, schema, kMaxNestingDepth, &buffer_reader, out);
}
+Status ReadTensor(
+ int64_t offset, io::RandomAccessFile* file, std::shared_ptr<Tensor>* out) {
+ std::shared_ptr<Message> message;
+ std::shared_ptr<Buffer> data;
+ RETURN_NOT_OK(ReadContiguousPayload(offset, file, &message, &data));
+
+ std::shared_ptr<DataType> type;
+ std::vector<int64_t> shape;
+ std::vector<int64_t> strides;
+ std::vector<std::string> dim_names;
+ RETURN_NOT_OK(
+ GetTensorMetadata(message->header(), &type, &shape, &strides, &dim_names));
+ return MakeTensor(type, data, shape, strides, dim_names, out);
+}
+
} // namespace ipc
} // namespace arrow
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/ipc/reader.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/ipc/reader.h b/cpp/src/arrow/ipc/reader.h
index 6d9e6ca..b62f052 100644
--- a/cpp/src/arrow/ipc/reader.h
+++ b/cpp/src/arrow/ipc/reader.h
@@ -17,8 +17,8 @@
// Implement Arrow file layout for IPC/RPC purposes and short-lived storage
-#ifndef ARROW_IPC_FILE_H
-#define ARROW_IPC_FILE_H
+#ifndef ARROW_IPC_READER_H
+#define ARROW_IPC_READER_H
#include <cstdint>
#include <memory>
@@ -33,6 +33,7 @@ class Buffer;
class RecordBatch;
class Schema;
class Status;
+class Tensor;
namespace io {
@@ -43,18 +44,6 @@ class RandomAccessFile;
namespace ipc {
-// Generic read functionsh; does not copy data if the input supports zero copy reads
-
-Status ReadRecordBatch(const Message& metadata, const std::shared_ptr<Schema>& schema,
- io::RandomAccessFile* file, std::shared_ptr<RecordBatch>* out);
-
-Status ReadRecordBatch(const Message& metadata, const std::shared_ptr<Schema>& schema,
- int max_recursion_depth, io::RandomAccessFile* file,
- std::shared_ptr<RecordBatch>* out);
-
-Status ReadDictionary(const Message& metadata, const DictionaryTypeMap& dictionary_types,
- io::RandomAccessFile* file, std::shared_ptr<Array>* out);
-
class ARROW_EXPORT StreamReader {
public:
~StreamReader();
@@ -118,11 +107,24 @@ class ARROW_EXPORT FileReader {
std::unique_ptr<FileReaderImpl> impl_;
};
+// Generic read functions; does not copy data if the input supports zero copy reads
+Status ARROW_EXPORT ReadRecordBatch(const Message& metadata,
+ const std::shared_ptr<Schema>& schema, io::RandomAccessFile* file,
+ std::shared_ptr<RecordBatch>* out);
+
+Status ARROW_EXPORT ReadRecordBatch(const Message& metadata,
+ const std::shared_ptr<Schema>& schema, int max_recursion_depth,
+ io::RandomAccessFile* file, std::shared_ptr<RecordBatch>* out);
+
/// Read encapsulated message and RecordBatch
Status ARROW_EXPORT ReadRecordBatch(const std::shared_ptr<Schema>& schema, int64_t offset,
io::RandomAccessFile* file, std::shared_ptr<RecordBatch>* out);
+/// EXPERIMENTAL: Read arrow::Tensor from a contiguous message
+Status ARROW_EXPORT ReadTensor(
+ int64_t offset, io::RandomAccessFile* file, std::shared_ptr<Tensor>* out);
+
} // namespace ipc
} // namespace arrow
-#endif // ARROW_IPC_FILE_H
+#endif // ARROW_IPC_READER_H
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/ipc/writer.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc
index 0a19f69..249ef20 100644
--- a/cpp/src/arrow/ipc/writer.cc
+++ b/cpp/src/arrow/ipc/writer.cc
@@ -34,6 +34,7 @@
#include "arrow/memory_pool.h"
#include "arrow/status.h"
#include "arrow/table.h"
+#include "arrow/tensor.h"
#include "arrow/type.h"
#include "arrow/util/bit-util.h"
#include "arrow/util/logging.h"
@@ -143,46 +144,6 @@ class RecordBatchWriter : public ArrayVisitor {
num_rows, body_length, field_nodes_, buffer_meta_, out);
}
- Status WriteMetadata(int64_t num_rows, int64_t body_length, io::OutputStream* dst,
- int32_t* metadata_length) {
- // Now that we have computed the locations of all of the buffers in shared
- // memory, the data header can be converted to a flatbuffer and written out
- //
- // Note: The memory written here is prefixed by the size of the flatbuffer
- // itself as an int32_t.
- std::shared_ptr<Buffer> metadata_fb;
- RETURN_NOT_OK(WriteMetadataMessage(num_rows, body_length, &metadata_fb));
-
- // Need to write 4 bytes (metadata size), the metadata, plus padding to
- // end on an 8-byte offset
- int64_t start_offset;
- RETURN_NOT_OK(dst->Tell(&start_offset));
-
- int32_t padded_metadata_length = static_cast<int32_t>(metadata_fb->size()) + 4;
- const int32_t remainder =
- (padded_metadata_length + static_cast<int32_t>(start_offset)) % 8;
- if (remainder != 0) { padded_metadata_length += 8 - remainder; }
-
- // The returned metadata size includes the length prefix, the flatbuffer,
- // plus padding
- *metadata_length = padded_metadata_length;
-
- // Write the flatbuffer size prefix including padding
- int32_t flatbuffer_size = padded_metadata_length - 4;
- RETURN_NOT_OK(
- dst->Write(reinterpret_cast<const uint8_t*>(&flatbuffer_size), sizeof(int32_t)));
-
- // Write the flatbuffer
- RETURN_NOT_OK(dst->Write(metadata_fb->data(), metadata_fb->size()));
-
- // Write any padding
- int32_t padding =
- padded_metadata_length - static_cast<int32_t>(metadata_fb->size()) - 4;
- if (padding > 0) { RETURN_NOT_OK(dst->Write(kPaddingBytes, padding)); }
-
- return Status::OK();
- }
-
Status Write(const RecordBatch& batch, io::OutputStream* dst, int32_t* metadata_length,
int64_t* body_length) {
RETURN_NOT_OK(Assemble(batch, body_length));
@@ -192,7 +153,14 @@ class RecordBatchWriter : public ArrayVisitor {
RETURN_NOT_OK(dst->Tell(&start_position));
#endif
- RETURN_NOT_OK(WriteMetadata(batch.num_rows(), *body_length, dst, metadata_length));
+ // Now that we have computed the locations of all of the buffers in shared
+ // memory, the data header can be converted to a flatbuffer and written out
+ //
+ // Note: The memory written here is prefixed by the size of the flatbuffer
+ // itself as an int32_t.
+ std::shared_ptr<Buffer> metadata_fb;
+ RETURN_NOT_OK(WriteMetadataMessage(batch.num_rows(), *body_length, &metadata_fb));
+ RETURN_NOT_OK(WriteMessage(*metadata_fb, dst, metadata_length));
#ifndef NDEBUG
RETURN_NOT_OK(dst->Tell(¤t_position));
@@ -504,6 +472,28 @@ Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset,
return writer.Write(batch, dst, metadata_length, body_length);
}
+Status WriteLargeRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset,
+ io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length,
+ MemoryPool* pool) {
+ return WriteRecordBatch(batch, buffer_start_offset, dst, metadata_length, body_length,
+ pool, kMaxNestingDepth, true);
+}
+
+Status WriteTensor(const Tensor& tensor, io::OutputStream* dst, int32_t* metadata_length,
+ int64_t* body_length) {
+ std::shared_ptr<Buffer> metadata;
+ RETURN_NOT_OK(WriteTensorMessage(tensor, 0, &metadata));
+ RETURN_NOT_OK(WriteMessage(*metadata, dst, metadata_length));
+ auto data = tensor.data();
+ if (data) {
+ *body_length = data->size();
+ return dst->Write(data->data(), *body_length);
+ } else {
+ *body_length = 0;
+ return Status::OK();
+ }
+}
+
Status WriteDictionary(int64_t dictionary_id, const std::shared_ptr<Array>& dictionary,
int64_t buffer_start_offset, io::OutputStream* dst, int32_t* metadata_length,
int64_t* body_length, MemoryPool* pool) {
@@ -736,12 +726,5 @@ Status FileWriter::Close() {
return impl_->Close();
}
-Status WriteLargeRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset,
- io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length,
- MemoryPool* pool) {
- return WriteRecordBatch(batch, buffer_start_offset, dst, metadata_length, body_length,
- pool, kMaxNestingDepth, true);
-}
-
} // namespace ipc
} // namespace arrow
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/ipc/writer.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h
index c572157..8b2dc9c 100644
--- a/cpp/src/arrow/ipc/writer.h
+++ b/cpp/src/arrow/ipc/writer.h
@@ -17,8 +17,8 @@
// Implement Arrow streaming binary format
-#ifndef ARROW_IPC_STREAM_H
-#define ARROW_IPC_STREAM_H
+#ifndef ARROW_IPC_WRITER_H
+#define ARROW_IPC_WRITER_H
#include <cstdint>
#include <memory>
@@ -36,6 +36,7 @@ class MemoryPool;
class RecordBatch;
class Schema;
class Status;
+class Tensor;
namespace io {
@@ -125,7 +126,12 @@ Status WriteLargeRecordBatch(const RecordBatch& batch, int64_t buffer_start_offs
io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length,
MemoryPool* pool);
+/// EXPERIMENTAL: Write arrow::Tensor as a contiguous message
+/// <metadata size><metadata><tensor data>
+Status WriteTensor(const Tensor& tensor, io::OutputStream* dst, int32_t* metadata_length,
+ int64_t* body_length);
+
} // namespace ipc
} // namespace arrow
-#endif // ARROW_IPC_STREAM_H
+#endif // ARROW_IPC_WRITER_H
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/tensor-test.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/tensor-test.cc b/cpp/src/arrow/tensor-test.cc
index 99a9493..336905c 100644
--- a/cpp/src/arrow/tensor-test.cc
+++ b/cpp/src/arrow/tensor-test.cc
@@ -61,13 +61,36 @@ TEST(TestTensor, BasicCtors) {
ASSERT_EQ(24, t1.size());
ASSERT_TRUE(t1.is_mutable());
- ASSERT_FALSE(t1.has_dim_names());
ASSERT_EQ(strides, t1.strides());
ASSERT_EQ(strides, t2.strides());
ASSERT_EQ("foo", t3.dim_name(0));
ASSERT_EQ("bar", t3.dim_name(1));
+ ASSERT_EQ("", t1.dim_name(0));
+ ASSERT_EQ("", t1.dim_name(1));
+}
+
+TEST(TestTensor, IsContiguous) {
+ const int64_t values = 24;
+ std::vector<int64_t> shape = {4, 6};
+ std::vector<int64_t> strides = {48, 8};
+
+ using T = int64_t;
+
+ std::shared_ptr<MutableBuffer> buffer;
+ ASSERT_OK(AllocateBuffer(default_memory_pool(), values * sizeof(T), &buffer));
+
+ std::vector<int64_t> c_strides = {48, 8};
+ std::vector<int64_t> f_strides = {8, 32};
+ std::vector<int64_t> noncontig_strides = {8, 8};
+ Int64Tensor t1(buffer, shape, c_strides);
+ Int64Tensor t2(buffer, shape, f_strides);
+ Int64Tensor t3(buffer, shape, noncontig_strides);
+
+ ASSERT_TRUE(t1.is_contiguous());
+ ASSERT_TRUE(t2.is_contiguous());
+ ASSERT_FALSE(t3.is_contiguous());
}
} // namespace arrow
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/tensor.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/tensor.cc b/cpp/src/arrow/tensor.cc
index 7c4593f..9a8de51 100644
--- a/cpp/src/arrow/tensor.cc
+++ b/cpp/src/arrow/tensor.cc
@@ -27,14 +27,15 @@
#include "arrow/array.h"
#include "arrow/buffer.h"
+#include "arrow/compare.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/logging.h"
namespace arrow {
-void ComputeRowMajorStrides(const FixedWidthType& type, const std::vector<int64_t>& shape,
- std::vector<int64_t>* strides) {
+static void ComputeRowMajorStrides(const FixedWidthType& type,
+ const std::vector<int64_t>& shape, std::vector<int64_t>* strides) {
int64_t remaining = type.bit_width() / 8;
for (int64_t dimsize : shape) {
remaining *= dimsize;
@@ -46,6 +47,15 @@ void ComputeRowMajorStrides(const FixedWidthType& type, const std::vector<int64_
}
}
+static void ComputeColumnMajorStrides(const FixedWidthType& type,
+ const std::vector<int64_t>& shape, std::vector<int64_t>* strides) {
+ int64_t total = type.bit_width() / 8;
+ for (int64_t dimsize : shape) {
+ strides->push_back(total);
+ total *= dimsize;
+ }
+}
+
/// Constructor with strides and dimension names
Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
const std::vector<int64_t>& shape, const std::vector<int64_t>& strides,
@@ -66,14 +76,36 @@ Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buff
: Tensor(type, data, shape, {}, {}) {}
const std::string& Tensor::dim_name(int i) const {
- DCHECK_LT(i, static_cast<int>(dim_names_.size()));
- return dim_names_[i];
+ static const std::string kEmpty = "";
+ if (dim_names_.size() == 0) {
+ return kEmpty;
+ } else {
+ DCHECK_LT(i, static_cast<int>(dim_names_.size()));
+ return dim_names_[i];
+ }
}
int64_t Tensor::size() const {
return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int64_t>());
}
+bool Tensor::is_contiguous() const {
+ std::vector<int64_t> c_strides;
+ std::vector<int64_t> f_strides;
+
+ const auto& fw_type = static_cast<const FixedWidthType&>(*type_);
+ ComputeRowMajorStrides(fw_type, shape_, &c_strides);
+ ComputeColumnMajorStrides(fw_type, shape_, &f_strides);
+ return strides_ == c_strides || strides_ == f_strides;
+}
+
+bool Tensor::Equals(const Tensor& other) const {
+ bool are_equal = false;
+ Status error = TensorEquals(*this, other, &are_equal);
+ if (!error.ok()) { DCHECK(false) << "Tensors not comparable: " << error.ToString(); }
+ return are_equal;
+}
+
template <typename T>
NumericTensor<T>::NumericTensor(const std::shared_ptr<Buffer>& data,
const std::vector<int64_t>& shape, const std::vector<int64_t>& strides,
@@ -112,4 +144,31 @@ template class ARROW_TEMPLATE_EXPORT NumericTensor<HalfFloatType>;
template class ARROW_TEMPLATE_EXPORT NumericTensor<FloatType>;
template class ARROW_TEMPLATE_EXPORT NumericTensor<DoubleType>;
+#define TENSOR_CASE(TYPE, TENSOR_TYPE) \
+ case Type::TYPE: \
+ *tensor = std::make_shared<TENSOR_TYPE>(data, shape, strides, dim_names); \
+ break;
+
+Status ARROW_EXPORT MakeTensor(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& strides, const std::vector<std::string>& dim_names,
+ std::shared_ptr<Tensor>* tensor) {
+ switch (type->type) {
+ TENSOR_CASE(INT8, Int8Tensor);
+ TENSOR_CASE(INT16, Int16Tensor);
+ TENSOR_CASE(INT32, Int32Tensor);
+ TENSOR_CASE(INT64, Int64Tensor);
+ TENSOR_CASE(UINT8, UInt8Tensor);
+ TENSOR_CASE(UINT16, UInt16Tensor);
+ TENSOR_CASE(UINT32, UInt32Tensor);
+ TENSOR_CASE(UINT64, UInt64Tensor);
+ TENSOR_CASE(HALF_FLOAT, HalfFloatTensor);
+ TENSOR_CASE(FLOAT, FloatTensor);
+ TENSOR_CASE(DOUBLE, DoubleTensor);
+ default:
+ return Status::NotImplemented(type->ToString());
+ }
+ return Status::OK();
+}
+
} // namespace arrow
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/tensor.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/tensor.h b/cpp/src/arrow/tensor.h
index 7bee867..eeb5c3e 100644
--- a/cpp/src/arrow/tensor.h
+++ b/cpp/src/arrow/tensor.h
@@ -73,12 +73,15 @@ class ARROW_EXPORT Tensor {
const std::vector<int64_t>& shape, const std::vector<int64_t>& strides,
const std::vector<std::string>& dim_names);
+ std::shared_ptr<DataType> type() const { return type_; }
std::shared_ptr<Buffer> data() const { return data_; }
+
const std::vector<int64_t>& shape() const { return shape_; }
const std::vector<int64_t>& strides() const { return strides_; }
+ int ndim() const { return static_cast<int>(shape_.size()); }
+
const std::string& dim_name(int i) const;
- bool has_dim_names() const { return shape_.size() > 0 && dim_names_.size() > 0; }
/// Total number of value cells in the tensor
int64_t size() const;
@@ -86,13 +89,17 @@ class ARROW_EXPORT Tensor {
/// Return true if the underlying data buffer is mutable
bool is_mutable() const { return data_->is_mutable(); }
+ bool is_contiguous() const;
+
+ Type::type type_enum() const { return type_->type; }
+
+ bool Equals(const Tensor& other) const;
+
protected:
Tensor() {}
std::shared_ptr<DataType> type_;
-
std::shared_ptr<Buffer> data_;
-
std::vector<int64_t> shape_;
std::vector<int64_t> strides_;
@@ -126,6 +133,11 @@ class ARROW_EXPORT NumericTensor : public Tensor {
value_type* mutable_raw_data_;
};
+Status ARROW_EXPORT MakeTensor(const std::shared_ptr<DataType>& type,
+ const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape,
+ const std::vector<int64_t>& strides, const std::vector<std::string>& dim_names,
+ std::shared_ptr<Tensor>* tensor);
+
// ----------------------------------------------------------------------
// extern templates and other details
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/type_traits.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h
index 1270aee..b73d5a6 100644
--- a/cpp/src/arrow/type_traits.h
+++ b/cpp/src/arrow/type_traits.h
@@ -38,6 +38,7 @@ template <>
struct TypeTraits<UInt8Type> {
using ArrayType = UInt8Array;
using BuilderType = UInt8Builder;
+ using TensorType = UInt8Tensor;
static inline int64_t bytes_required(int64_t elements) { return elements; }
constexpr static bool is_parameter_free = true;
static inline std::shared_ptr<DataType> type_singleton() { return uint8(); }
@@ -47,6 +48,7 @@ template <>
struct TypeTraits<Int8Type> {
using ArrayType = Int8Array;
using BuilderType = Int8Builder;
+ using TensorType = Int8Tensor;
static inline int64_t bytes_required(int64_t elements) { return elements; }
constexpr static bool is_parameter_free = true;
static inline std::shared_ptr<DataType> type_singleton() { return int8(); }
@@ -56,6 +58,7 @@ template <>
struct TypeTraits<UInt16Type> {
using ArrayType = UInt16Array;
using BuilderType = UInt16Builder;
+ using TensorType = UInt16Tensor;
static inline int64_t bytes_required(int64_t elements) {
return elements * sizeof(uint16_t);
@@ -68,6 +71,7 @@ template <>
struct TypeTraits<Int16Type> {
using ArrayType = Int16Array;
using BuilderType = Int16Builder;
+ using TensorType = Int16Tensor;
static inline int64_t bytes_required(int64_t elements) {
return elements * sizeof(int16_t);
@@ -80,6 +84,7 @@ template <>
struct TypeTraits<UInt32Type> {
using ArrayType = UInt32Array;
using BuilderType = UInt32Builder;
+ using TensorType = UInt32Tensor;
static inline int64_t bytes_required(int64_t elements) {
return elements * sizeof(uint32_t);
@@ -92,6 +97,7 @@ template <>
struct TypeTraits<Int32Type> {
using ArrayType = Int32Array;
using BuilderType = Int32Builder;
+ using TensorType = Int32Tensor;
static inline int64_t bytes_required(int64_t elements) {
return elements * sizeof(int32_t);
@@ -104,6 +110,7 @@ template <>
struct TypeTraits<UInt64Type> {
using ArrayType = UInt64Array;
using BuilderType = UInt64Builder;
+ using TensorType = UInt64Tensor;
static inline int64_t bytes_required(int64_t elements) {
return elements * sizeof(uint64_t);
@@ -116,6 +123,7 @@ template <>
struct TypeTraits<Int64Type> {
using ArrayType = Int64Array;
using BuilderType = Int64Builder;
+ using TensorType = Int64Tensor;
static inline int64_t bytes_required(int64_t elements) {
return elements * sizeof(int64_t);
@@ -185,6 +193,7 @@ template <>
struct TypeTraits<HalfFloatType> {
using ArrayType = HalfFloatArray;
using BuilderType = HalfFloatBuilder;
+ using TensorType = HalfFloatTensor;
static inline int64_t bytes_required(int64_t elements) {
return elements * sizeof(uint16_t);
@@ -197,6 +206,7 @@ template <>
struct TypeTraits<FloatType> {
using ArrayType = FloatArray;
using BuilderType = FloatBuilder;
+ using TensorType = FloatTensor;
static inline int64_t bytes_required(int64_t elements) {
return static_cast<int64_t>(elements * sizeof(float));
@@ -209,6 +219,7 @@ template <>
struct TypeTraits<DoubleType> {
using ArrayType = DoubleArray;
using BuilderType = DoubleBuilder;
+ using TensorType = DoubleTensor;
static inline int64_t bytes_required(int64_t elements) {
return static_cast<int64_t>(elements * sizeof(double));
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/visitor_inline.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/visitor_inline.h b/cpp/src/arrow/visitor_inline.h
index 586b123..cbc4d5a 100644
--- a/cpp/src/arrow/visitor_inline.h
+++ b/cpp/src/arrow/visitor_inline.h
@@ -22,6 +22,7 @@
#include "arrow/array.h"
#include "arrow/status.h"
+#include "arrow/tensor.h"
#include "arrow/type.h"
namespace arrow {
@@ -103,6 +104,31 @@ inline Status VisitArrayInline(const Array& array, VISITOR* visitor) {
return Status::NotImplemented("Type not implemented");
}
+#define TENSOR_VISIT_INLINE(TYPE_CLASS) \
+ case TYPE_CLASS::type_id: \
+ return visitor->Visit( \
+ static_cast<const typename TypeTraits<TYPE_CLASS>::TensorType&>(array));
+
+template <typename VISITOR>
+inline Status VisitTensorInline(const Tensor& array, VISITOR* visitor) {
+ switch (array.type_enum()) {
+ TENSOR_VISIT_INLINE(Int8Type);
+ TENSOR_VISIT_INLINE(UInt8Type);
+ TENSOR_VISIT_INLINE(Int16Type);
+ TENSOR_VISIT_INLINE(UInt16Type);
+ TENSOR_VISIT_INLINE(Int32Type);
+ TENSOR_VISIT_INLINE(UInt32Type);
+ TENSOR_VISIT_INLINE(Int64Type);
+ TENSOR_VISIT_INLINE(UInt64Type);
+ TENSOR_VISIT_INLINE(HalfFloatType);
+ TENSOR_VISIT_INLINE(FloatType);
+ TENSOR_VISIT_INLINE(DoubleType);
+ default:
+ break;
+ }
+ return Status::NotImplemented("Type not implemented");
+}
+
} // namespace arrow
#endif // ARROW_VISITOR_INLINE_H
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/format/Tensor.fbs
----------------------------------------------------------------------
diff --git a/format/Tensor.fbs b/format/Tensor.fbs
index bc5b6d1..18b614c 100644
--- a/format/Tensor.fbs
+++ b/format/Tensor.fbs
@@ -32,16 +32,6 @@ table TensorDim {
name: string;
}
-enum TensorOrder : byte {
- /// Higher dimensions vary first when traversing data in byte-contiguous
- /// order, aka "C order"
- ROW_MAJOR,
-
- /// Lower dimensions vary first when traversing data in byte-contiguous
- /// order, aka "Fortran order"
- COLUMN_MAJOR
-}
-
table Tensor {
/// The type of data contained in a value cell. Currently only fixed-width
/// value types are supported, no strings or nested types
@@ -50,8 +40,8 @@ table Tensor {
/// The dimensions of the tensor, optionally named
shape: [TensorDim];
- /// The memory order of the tensor's data
- order: TensorOrder;
+ /// Non-negative byte offsets to advance one value cell along each dimension
+ strides: [long];
/// The location and size of the tensor's data
data: Buffer;