You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ze...@apache.org on 2021/11/09 18:59:41 UTC
[arrow] 01/04: add pqarrow module
This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch temp-parquet-pqarrow
in repository https://gitbox.apache.org/repos/asf/arrow.git
commit 846a4efd4505eaf634a6a0a9574fbdd51ad32d6d
Author: Matthew Topol <mt...@factset.com>
AuthorDate: Tue Sep 14 16:53:14 2021 -0400
add pqarrow module
---
go/parquet/internal/testutils/random.go | 47 +-
go/parquet/internal/testutils/random_arrow.go | 44 +-
go/parquet/pqarrow/column_readers.go | 750 ++++++++++++++
go/parquet/pqarrow/doc.go | 21 +
go/parquet/pqarrow/encode_arrow.go | 586 +++++++++++
go/parquet/pqarrow/encode_arrow_test.go | 1379 +++++++++++++++++++++++++
go/parquet/pqarrow/file_reader.go | 686 ++++++++++++
go/parquet/pqarrow/file_reader_test.go | 177 ++++
go/parquet/pqarrow/file_writer.go | 291 ++++++
go/parquet/pqarrow/path_builder.go | 738 +++++++++++++
go/parquet/pqarrow/path_builder_test.go | 628 +++++++++++
go/parquet/pqarrow/properties.go | 171 +++
go/parquet/pqarrow/reader_writer_test.go | 335 ++++++
go/parquet/pqarrow/schema.go | 1072 +++++++++++++++++++
go/parquet/pqarrow/schema_test.go | 245 +++++
15 files changed, 7125 insertions(+), 45 deletions(-)
diff --git a/go/parquet/internal/testutils/random.go b/go/parquet/internal/testutils/random.go
index 0ed0943..08c2e70 100644
--- a/go/parquet/internal/testutils/random.go
+++ b/go/parquet/internal/testutils/random.go
@@ -28,6 +28,7 @@ import (
"github.com/apache/arrow/go/arrow/bitutil"
"github.com/apache/arrow/go/arrow/memory"
"github.com/apache/arrow/go/parquet"
+ "github.com/apache/arrow/go/parquet/pqarrow"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/stat/distuv"
@@ -427,26 +428,26 @@ func RandomByteArray(seed uint64, out []parquet.ByteArray, heap *memory.Buffer,
}
}
-// // RandomDecimals generates n random decimal values with precision determining the byte width
-// // for the values and seed as the random generator seed to allow consistency for testing. The
-// // resulting values will be either 32 bytes or 16 bytes each depending on the precision.
-// func RandomDecimals(n int64, seed uint64, precision int32) []byte {
-// r := rand.New(rand.NewSource(seed))
-// nreqBytes := pqarrow.DecimalSize(precision)
-// byteWidth := 32
-// if precision <= 38 {
-// byteWidth = 16
-// }
-
-// out := make([]byte, int(int64(byteWidth)*n))
-// for i := int64(0); i < n; i++ {
-// start := int(i) * byteWidth
-// r.Read(out[start : start+int(nreqBytes)])
-// // sign extend if the sign bit is set for the last generated byte
-// // 0b10000000 == 0x80 == 128
-// if out[start+int(nreqBytes)-1]&byte(0x80) != 0 {
-// memory.Set(out[start+int(nreqBytes):start+byteWidth], 0xFF)
-// }
-// }
-// return out
-// }
+// RandomDecimals returns n pseudo-random decimal values packed back to back in a
+// single byte slice. The generator is seeded with seed so the output is
+// reproducible across test runs. Each value occupies a 16 byte slot when
+// precision <= 38 and a 32 byte slot otherwise; only the first
+// pqarrow.DecimalSize(precision) bytes of each slot are randomized.
+func RandomDecimals(n int64, seed uint64, precision int32) []byte {
+	gen := rand.New(rand.NewSource(seed))
+	valueBytes := int(pqarrow.DecimalSize(precision))
+
+	slotWidth := 16
+	if precision > 38 {
+		slotWidth = 32
+	}
+
+	out := make([]byte, int(int64(slotWidth)*n))
+	for idx := int64(0); idx < n; idx++ {
+		begin := int(idx) * slotWidth
+		valueEnd := begin + valueBytes
+		gen.Read(out[begin:valueEnd])
+		// sign extend: when the top bit (0x80) of the last generated byte is
+		// set, fill the remainder of the slot with 0xFF
+		if out[valueEnd-1]&0x80 != 0 {
+			memory.Set(out[valueEnd:begin+slotWidth], 0xFF)
+		}
+	}
+	return out
+}
diff --git a/go/parquet/internal/testutils/random_arrow.go b/go/parquet/internal/testutils/random_arrow.go
index c3edf6b..39f250f 100644
--- a/go/parquet/internal/testutils/random_arrow.go
+++ b/go/parquet/internal/testutils/random_arrow.go
@@ -159,14 +159,14 @@ func RandomNonNull(dt arrow.DataType, size int) array.Interface {
bldr.Append(buf)
}
return bldr.NewArray()
- // case arrow.DECIMAL:
- // dectype := dt.(*arrow.Decimal128Type)
- // bldr := array.NewDecimal128Builder(memory.DefaultAllocator, dectype)
- // defer bldr.Release()
-
- // data := RandomDecimals(int64(size), 0, dectype.Precision)
- // bldr.AppendValues(arrow.Decimal128Traits.CastFromBytes(data), nil)
- // return bldr.NewArray()
+ case arrow.DECIMAL:
+ dectype := dt.(*arrow.Decimal128Type)
+ bldr := array.NewDecimal128Builder(memory.DefaultAllocator, dectype)
+ defer bldr.Release()
+
+ data := RandomDecimals(int64(size), 0, dectype.Precision)
+ bldr.AppendValues(arrow.Decimal128Traits.CastFromBytes(data), nil)
+ return bldr.NewArray()
case arrow.BOOL:
bldr := array.NewBooleanBuilder(memory.DefaultAllocator)
defer bldr.Release()
@@ -451,22 +451,22 @@ func RandomNullable(dt arrow.DataType, size int, numNulls int) array.Interface {
bldr.Append(buf)
}
return bldr.NewArray()
- // case arrow.DECIMAL:
- // dectype := dt.(*arrow.Decimal128Type)
- // bldr := array.NewDecimal128Builder(memory.DefaultAllocator, dectype)
- // defer bldr.Release()
+ case arrow.DECIMAL:
+ dectype := dt.(*arrow.Decimal128Type)
+ bldr := array.NewDecimal128Builder(memory.DefaultAllocator, dectype)
+ defer bldr.Release()
- // valid := make([]bool, size)
- // for idx := range valid {
- // valid[idx] = true
- // }
- // for i := 0; i < numNulls; i++ {
- // valid[i*2] = false
- // }
+ valid := make([]bool, size)
+ for idx := range valid {
+ valid[idx] = true
+ }
+ for i := 0; i < numNulls; i++ {
+ valid[i*2] = false
+ }
- // data := RandomDecimals(int64(size), 0, dectype.Precision)
- // bldr.AppendValues(arrow.Decimal128Traits.CastFromBytes(data), valid)
- // return bldr.NewArray()
+ data := RandomDecimals(int64(size), 0, dectype.Precision)
+ bldr.AppendValues(arrow.Decimal128Traits.CastFromBytes(data), valid)
+ return bldr.NewArray()
case arrow.BOOL:
bldr := array.NewBooleanBuilder(memory.DefaultAllocator)
defer bldr.Release()
diff --git a/go/parquet/pqarrow/column_readers.go b/go/parquet/pqarrow/column_readers.go
new file mode 100644
index 0000000..cc79e02
--- /dev/null
+++ b/go/parquet/pqarrow/column_readers.go
@@ -0,0 +1,750 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow
+
+import (
+ "encoding/binary"
+ "reflect"
+ "sync/atomic"
+ "time"
+ "unsafe"
+
+ "github.com/apache/arrow/go/arrow"
+ "github.com/apache/arrow/go/arrow/array"
+ "github.com/apache/arrow/go/arrow/bitutil"
+ "github.com/apache/arrow/go/arrow/decimal128"
+ "github.com/apache/arrow/go/arrow/memory"
+ "github.com/apache/arrow/go/parquet"
+ "github.com/apache/arrow/go/parquet/file"
+ "github.com/apache/arrow/go/parquet/internal/utils"
+ "github.com/apache/arrow/go/parquet/schema"
+ "golang.org/x/xerrors"
+)
+
+// leafReader reads the values of a single leaf column through a
+// file.RecordReader and exposes the most recently loaded batch as a chunked
+// arrow array.
+type leafReader struct {
+ // out holds the result of the last LoadBatch call; it is released and
+ // replaced on the next LoadBatch and released when the reader is freed.
+ out *array.Chunked
+ rctx *readerCtx
+ field *arrow.Field
+ // input supplies the page reader for each successive chunk of the column.
+ input *columnIterator
+ descr *schema.Column
+ recordRdr file.RecordReader
+
+ // refCount is updated with sync/atomic by Retain and Release.
+ refCount int64
+}
+
+// newLeafReader builds a leaf column reader for field on top of input and
+// positions it on the first chunk supplied by input. The record reader is
+// asked to read dictionaries only when the arrow type of field is a dictionary.
+//
+// The returned ColumnReader is non-nil even when advancing to the first chunk
+// fails; in that case the error from the column iterator is returned with it.
+func newLeafReader(rctx *readerCtx, field *arrow.Field, input *columnIterator, leafInfo file.LevelInfo) (*ColumnReader, error) {
+	leaf := &leafReader{
+		rctx:     rctx,
+		field:    field,
+		input:    input,
+		refCount: 1,
+	}
+	leaf.descr = input.Descr()
+	leaf.recordRdr = file.NewRecordReader(input.Descr(), leafInfo, field.Type.ID() == arrow.DICTIONARY, rctx.mem)
+
+	if err := leaf.nextRowGroup(); err != nil {
+		return &ColumnReader{leaf}, err
+	}
+	return &ColumnReader{leaf}, nil
+}
+
+// Retain atomically increments the reference count of the reader.
+func (lr *leafReader) Retain() {
+	atomic.AddInt64(&lr.refCount, 1)
+}
+
+// Release atomically decrements the reference count. When the count reaches
+// zero the cached output array and the record reader are released and cleared.
+func (lr *leafReader) Release() {
+	if atomic.AddInt64(&lr.refCount, -1) != 0 {
+		return
+	}
+
+	if lr.out != nil {
+		lr.out.Release()
+		lr.out = nil
+	}
+	if lr.recordRdr != nil {
+		lr.recordRdr.Release()
+		lr.recordRdr = nil
+	}
+}
+
+// GetDefLevels returns the definition levels held by the record reader,
+// truncated to the record reader's current level position. The error is
+// always nil.
+func (lr *leafReader) GetDefLevels() ([]int16, error) {
+	levels := lr.recordRdr.DefLevels()
+	return levels[:int(lr.recordRdr.LevelsPos())], nil
+}
+
+// GetRepLevels returns the repetition levels held by the record reader,
+// truncated to the record reader's current level position. The error is
+// always nil.
+func (lr *leafReader) GetRepLevels() ([]int16, error) {
+	levels := lr.recordRdr.RepLevels()
+	return levels[:int(lr.recordRdr.LevelsPos())], nil
+}
+
+// IsOrHasRepeatedChild always reports false for a leaf reader.
+func (lr *leafReader) IsOrHasRepeatedChild() bool { return false }
+
+func (lr *leafReader) LoadBatch(nrecords int64) (err error) {
+ if lr.out != nil {
+ lr.out.Release()
+ lr.out = nil
+ }
+ lr.recordRdr.Reset()
+
+ if err := lr.recordRdr.Reserve(nrecords); err != nil {
+ return err
+ }
+ for nrecords > 0 {
+ if !lr.recordRdr.HasMore() {
+ break
+ }
+ numRead, err := lr.recordRdr.ReadRecords(nrecords)
+ if err != nil {
+ return err
+ }
+ nrecords -= numRead
+ if numRead == 0 {
+ if err = lr.nextRowGroup(); err != nil {
+ return err
+ }
+ }
+ }
+ lr.out, err = transferColumnData(lr.recordRdr, lr.field.Type, lr.descr, lr.rctx.mem)
+ return
+}
+
+// BuildArray returns the chunked array produced by the most recent LoadBatch
+// call. The length bound argument is ignored and the error is always nil.
+func (lr *leafReader) BuildArray(_ int64) (*array.Chunked, error) {
+ return lr.out, nil
+}
+
+// Field returns the arrow field this reader was constructed with.
+func (lr *leafReader) Field() *arrow.Field { return lr.field }
+
+// nextRowGroup fetches the next chunk from the column iterator and installs
+// its page reader on the record reader. If the iterator returns an error the
+// record reader is left untouched and the error is returned.
+func (lr *leafReader) nextRowGroup() error {
+	pageRdr, err := lr.input.NextChunk()
+	if err == nil {
+		lr.recordRdr.SetPageReader(pageRdr)
+	}
+	return err
+}
+
+// structReader assembles a struct array from the arrays built by its child
+// column readers.
+type structReader struct {
+ rctx *readerCtx
+ // filtered is the arrow field returned by Field; its type is used for the
+ // arrays produced by BuildArray.
+ filtered *arrow.Field
+ levelInfo file.LevelInfo
+ children []*ColumnReader
+ // defRepLevelChild is the child whose definition/repetition levels are
+ // returned by GetDefLevels and GetRepLevels.
+ defRepLevelChild *ColumnReader
+ // hasRepeatedChild is set at construction time when no child reported
+ // IsOrHasRepeatedChild() == false.
+ hasRepeatedChild bool
+
+ // refCount is updated with sync/atomic by Retain and Release.
+ refCount int64
+}
+
+// Retain atomically increments the reference count of the reader.
+func (sr *structReader) Retain() {
+	atomic.AddInt64(&sr.refCount, 1)
+}
+
+// Release atomically decrements the reference count. When the count reaches
+// zero the level-providing child and every child reader are released and the
+// references to them are dropped.
+func (sr *structReader) Release() {
+	if atomic.AddInt64(&sr.refCount, -1) != 0 {
+		return
+	}
+
+	if sr.defRepLevelChild != nil {
+		sr.defRepLevelChild.Release()
+		sr.defRepLevelChild = nil
+	}
+	for idx := range sr.children {
+		sr.children[idx].Release()
+	}
+	sr.children = nil
+}
+
+// newStructReader wraps children in a reader for the struct described by
+// filtered and returns it as a ColumnReader with a reference count of one.
+//
+// One child is selected as the source of definition/repetition levels. There
+// could be a mix of children: some might be repeated and some might not be.
+// If possible a child that is not (and does not contain) a repeated field is
+// used, since that is guaranteed to have the least number of levels needed to
+// reconstruct a nullable bitmap; when several children qualify the last one
+// is chosen. Otherwise the first child is used. The selected child is retained.
+//
+// An empty children slice no longer panics: no level source is selected and
+// hasRepeatedChild stays set, so BuildArray goes through GetDefLevels, which
+// reports the missing children as an error.
+func newStructReader(rctx *readerCtx, filtered *arrow.Field, levelInfo file.LevelInfo, children []*ColumnReader) *ColumnReader {
+	ret := &structReader{
+		rctx:      rctx,
+		filtered:  filtered,
+		levelInfo: levelInfo,
+		children:  children,
+		refCount:  1,
+	}
+
+	var nonRepeated *ColumnReader
+	for _, child := range children {
+		if !child.IsOrHasRepeatedChild() {
+			nonRepeated = child
+		}
+	}
+
+	switch {
+	case nonRepeated != nil:
+		ret.defRepLevelChild = nonRepeated
+		ret.hasRepeatedChild = false
+	case len(children) > 0:
+		ret.defRepLevelChild = children[0]
+		ret.hasRepeatedChild = true
+	default:
+		// no children at all: nothing to take levels from
+		ret.hasRepeatedChild = true
+	}
+
+	if ret.defRepLevelChild != nil {
+		ret.defRepLevelChild.Retain()
+	}
+	return &ColumnReader{ret}
+}
+
+// IsOrHasRepeatedChild reports the hasRepeatedChild flag computed when the
+// reader was constructed.
+func (sr *structReader) IsOrHasRepeatedChild() bool { return sr.hasRepeatedChild }
+
+// GetDefLevels returns the definition levels of the child selected as the
+// level source for this struct. It returns an error when the struct reader
+// has no children.
+func (sr *structReader) GetDefLevels() ([]int16, error) {
+	if len(sr.children) == 0 {
+		return nil, xerrors.New("struct reader has no children")
+	}
+
+	// this method should only be called when this struct or one of its parents
+	// are optional/repeated or has a repeated child
+	// meaning all children must have rep/def levels associated with them
+	return sr.defRepLevelChild.GetDefLevels()
+}
+
+// GetRepLevels returns the repetition levels of the child selected as the
+// level source for this struct. It returns an error when the struct reader
+// has no children.
+func (sr *structReader) GetRepLevels() ([]int16, error) {
+	if len(sr.children) == 0 {
+		return nil, xerrors.New("struct reader has no children")
+	}
+
+	// this method should only be called when this struct or one of its parents
+	// are optional/repeated or has a repeated child
+	// meaning all children must have rep/def levels associated with them
+	return sr.defRepLevelChild.GetRepLevels()
+}
+
+// LoadBatch asks every child reader, in order, to load nrecords records and
+// stops at the first child that returns an error.
+func (sr *structReader) LoadBatch(nrecords int64) error {
+	for idx := range sr.children {
+		err := sr.children[idx].LoadBatch(nrecords)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// Field returns the arrow field this struct reader produces.
+func (sr *structReader) Field() *arrow.Field { return sr.filtered }
+
+func (sr *structReader) BuildArray(lenBound int64) (*array.Chunked, error) {
+ validityIO := file.ValidityBitmapInputOutput{
+ ReadUpperBound: lenBound,
+ Read: lenBound,
+ }
+
+ var nullBitmap *memory.Buffer
+
+ if sr.hasRepeatedChild {
+ nullBitmap = memory.NewResizableBuffer(sr.rctx.mem)
+ nullBitmap.Resize(int(bitutil.BytesForBits(lenBound)))
+ validityIO.ValidBits = nullBitmap.Bytes()
+ defLevels, err := sr.GetDefLevels()
+ if err != nil {
+ return nil, err
+ }
+ repLevels, err := sr.GetRepLevels()
+ if err != nil {
+ return nil, err
+ }
+
+ if err := file.DefRepLevelsToBitmap(defLevels, repLevels, sr.levelInfo, &validityIO); err != nil {
+ return nil, err
+ }
+
+ } else if sr.filtered.Nullable {
+ nullBitmap = memory.NewResizableBuffer(sr.rctx.mem)
+ nullBitmap.Resize(int(bitutil.BytesForBits(lenBound)))
+ validityIO.ValidBits = nullBitmap.Bytes()
+ defLevels, err := sr.GetDefLevels()
+ if err != nil {
+ return nil, err
+ }
+
+ file.DefLevelsToBitmap(defLevels, sr.levelInfo, &validityIO)
+ }
+
+ if nullBitmap != nil {
+ nullBitmap.Resize(int(bitutil.BytesForBits(validityIO.Read)))
+ }
+
+ childArrData := make([]*array.Data, 0)
+ // gather children arrays and def levels
+ for _, child := range sr.children {
+ field, err := child.BuildArray(validityIO.Read)
+ if err != nil {
+ return nil, err
+ }
+ arrdata, err := chunksToSingle(field)
+ if err != nil {
+ return nil, err
+ }
+ childArrData = append(childArrData, arrdata)
+ }
+
+ if !sr.filtered.Nullable && !sr.hasRepeatedChild {
+ validityIO.Read = int64(childArrData[0].Len())
+ }
+
+ buffers := make([]*memory.Buffer, 1)
+ if validityIO.NullCount > 0 {
+ buffers[0] = nullBitmap
+ }
+
+ data := array.NewData(sr.filtered.Type, int(validityIO.Read), buffers, childArrData, int(validityIO.NullCount), 0)
+ defer data.Release()
+ arr := array.MakeFromData(data)
+ defer arr.Release()
+ return array.NewChunked(sr.filtered.Type, []array.Interface{arr}), nil
+}
+
+// listReader builds list arrays from the definition/repetition levels and
+// values produced by a single item reader.
+type listReader struct {
+ rctx *readerCtx
+ field *arrow.Field
+ info file.LevelInfo
+ // itemRdr reads the list elements; all level and batch-loading calls are
+ // forwarded to it.
+ itemRdr *ColumnReader
+
+ // refCount is updated with sync/atomic by Retain and Release.
+ refCount int64
+}
+
+// newListReader returns a ColumnReader producing list arrays for field whose
+// elements are read by childRdr. The child reader is retained and the new
+// reader starts with a reference count of one.
+func newListReader(rctx *readerCtx, field *arrow.Field, info file.LevelInfo, childRdr *ColumnReader) *ColumnReader {
+	childRdr.Retain()
+	rdr := &listReader{
+		rctx:     rctx,
+		field:    field,
+		info:     info,
+		itemRdr:  childRdr,
+		refCount: 1,
+	}
+	return &ColumnReader{rdr}
+}
+
+// Retain atomically increments the reference count of the reader.
+func (lr *listReader) Retain() {
+	atomic.AddInt64(&lr.refCount, 1)
+}
+
+// Release atomically decrements the reference count. When the count reaches
+// zero the item reader is released and the reference to it is dropped.
+func (lr *listReader) Release() {
+	if atomic.AddInt64(&lr.refCount, -1) != 0 {
+		return
+	}
+	if lr.itemRdr == nil {
+		return
+	}
+	lr.itemRdr.Release()
+	lr.itemRdr = nil
+}
+
+// GetDefLevels returns the definition levels reported by the item reader.
+func (lr *listReader) GetDefLevels() ([]int16, error) {
+ return lr.itemRdr.GetDefLevels()
+}
+
+// GetRepLevels returns the repetition levels reported by the item reader.
+func (lr *listReader) GetRepLevels() ([]int16, error) {
+ return lr.itemRdr.GetRepLevels()
+}
+
+// Field returns the arrow field this list reader produces.
+func (lr *listReader) Field() *arrow.Field { return lr.field }
+
+// IsOrHasRepeatedChild always reports true for a list reader.
+func (lr *listReader) IsOrHasRepeatedChild() bool { return true }
+
+// LoadBatch forwards the request to load nrecords records to the item reader.
+func (lr *listReader) LoadBatch(nrecords int64) error {
+ return lr.itemRdr.LoadBatch(nrecords)
+}
+
+// BuildArray assembles the list array for the batch most recently loaded into
+// the item reader and returns it as a single-chunk chunked array.
+//
+// lenBound is the upper bound on the number of lists. Offsets (and, for a
+// nullable field, the validity bitmap) are reconstructed from the item
+// reader's definition and repetition levels; the last offset is then used as
+// the length bound for the item array, which must consist of at most one chunk.
+//
+// For a fixed size list every list, including the last one, must have exactly
+// the declared size, otherwise an error is returned. The sizes are verified
+// before the array data is created and the offsets buffer is simply left out
+// of the result, so the offsets buffer is released exactly once on both the
+// success and the error path.
+func (lr *listReader) BuildArray(lenBound int64) (*array.Chunked, error) {
+	defLevels, err := lr.itemRdr.GetDefLevels()
+	if err != nil {
+		return nil, err
+	}
+	repLevels, err := lr.itemRdr.GetRepLevels()
+	if err != nil {
+		return nil, err
+	}
+
+	validityIO := file.ValidityBitmapInputOutput{ReadUpperBound: lenBound}
+	var validityBuffer *memory.Buffer
+	if lr.field.Nullable {
+		validityBuffer = memory.NewResizableBuffer(lr.rctx.mem)
+		validityBuffer.Resize(int(bitutil.BytesForBits(lenBound)))
+		defer validityBuffer.Release()
+		validityIO.ValidBits = validityBuffer.Bytes()
+	}
+	offsetsBuffer := memory.NewResizableBuffer(lr.rctx.mem)
+	offsetsBuffer.Resize(arrow.Int32Traits.BytesRequired(int(lenBound) + 1))
+	defer offsetsBuffer.Release()
+
+	offsetData := arrow.Int32Traits.CastFromBytes(offsetsBuffer.Bytes())
+	if err = file.DefRepLevelsToListInfo(defLevels, repLevels, lr.info, &validityIO, offsetData); err != nil {
+		return nil, err
+	}
+
+	arr, err := lr.itemRdr.BuildArray(int64(offsetData[int(validityIO.Read)]))
+	if err != nil {
+		return nil, err
+	}
+
+	// resize to actual number of elems returned
+	offsetsBuffer.Resize(arrow.Int32Traits.BytesRequired(int(validityIO.Read) + 1))
+	if validityBuffer != nil {
+		validityBuffer.Resize(int(bitutil.BytesForBits(validityIO.Read)))
+	}
+
+	item, err := chunksToSingle(arr)
+	if err != nil {
+		return nil, err
+	}
+	defer item.Release()
+
+	buffers := []*memory.Buffer{nil, offsetsBuffer}
+	if validityIO.NullCount > 0 {
+		buffers[0] = validityBuffer
+	}
+
+	if lr.field.Type.ID() == arrow.FIXED_SIZE_LIST {
+		listSize := lr.field.Type.(*arrow.FixedSizeListType).Len()
+		// there are Read+1 offsets, so x must run up to and including Read
+		// for the final list to be checked as well
+		for x := 1; x <= int(validityIO.Read); x++ {
+			size := offsetData[x] - offsetData[x-1]
+			if size != listSize {
+				return nil, xerrors.Errorf("expected all lists to be of size=%d, but index %d had size=%d", listSize, x, size)
+			}
+		}
+		// fixed size lists carry no offsets buffer
+		buffers[1] = nil
+	}
+
+	data := array.NewData(lr.field.Type, int(validityIO.Read), buffers, []*array.Data{item}, int(validityIO.NullCount), 0)
+	defer data.Release()
+	// the chunked array takes its own reference on the chunk, so drop ours
+	out := array.MakeFromData(data)
+	defer out.Release()
+	return array.NewChunked(lr.field.Type, []array.Interface{out}), nil
+}
+
+// fixedSizeListReader is a listReader used for fixed size list fields; it
+// embeds listReader and adds no fields or methods of its own here.
+type fixedSizeListReader struct {
+ listReader
+}
+
+// newFixedSizeListReader returns a ColumnReader producing fixed size list
+// arrays for field whose elements are read by childRdr. The child reader is
+// retained and the new reader starts with a reference count of one.
+func newFixedSizeListReader(rctx *readerCtx, field *arrow.Field, info file.LevelInfo, childRdr *ColumnReader) *ColumnReader {
+	childRdr.Retain()
+	base := listReader{
+		rctx:     rctx,
+		field:    field,
+		info:     info,
+		itemRdr:  childRdr,
+		refCount: 1,
+	}
+	return &ColumnReader{&fixedSizeListReader{base}}
+}
+
+// chunksToSingle returns the array data of a chunked array that has at most
+// one chunk: a new empty data object of the chunked array's type when there
+// are no chunks, or the data of the only chunk. Chunked arrays with more than
+// one chunk are not supported and yield an error.
+func chunksToSingle(chunked *array.Chunked) (*array.Data, error) {
+	numChunks := len(chunked.Chunks())
+	if numChunks > 1 {
+		return nil, xerrors.New("not implemented")
+	}
+	if numChunks == 1 {
+		return chunked.Chunk(0).Data(), nil
+	}
+	return array.NewData(chunked.DataType(), 0, []*memory.Buffer{nil, nil}, nil, 0, 0), nil
+}
+
+// transferColumnData converts the values accumulated in rdr into a chunked
+// arrow array of valueType, picking the conversion routine from the arrow
+// type id (and, for decimals and nanosecond timestamps, from the parquet
+// physical type in descr). Unsupported arrow types, decimal physical types and
+// timestamp units yield an error. The mem argument is not used.
+func transferColumnData(rdr file.RecordReader, valueType arrow.DataType, descr *schema.Column, mem memory.Allocator) (*array.Chunked, error) {
+	var arr array.Interface
+
+	switch valueType.ID() {
+	// case arrow.DICTIONARY:
+	case arrow.NULL:
+		return array.NewChunked(arrow.Null, []array.Interface{array.NewNull(rdr.ValuesWritten())}), nil
+	case arrow.FIXED_SIZE_BINARY, arrow.BINARY, arrow.STRING:
+		return transferBinary(rdr, valueType), nil
+	case arrow.BOOL:
+		arr = transferBool(rdr)
+	case arrow.INT32, arrow.INT64, arrow.FLOAT32, arrow.FLOAT64:
+		arr = transferZeroCopy(rdr, valueType)
+	case arrow.DATE64:
+		arr = transferDate64(rdr, valueType)
+	case arrow.UINT8, arrow.UINT16, arrow.UINT32, arrow.UINT64,
+		arrow.INT8, arrow.INT16,
+		arrow.DATE32, arrow.TIME32, arrow.TIME64:
+		arr = transferInt(rdr, valueType)
+	case arrow.DECIMAL:
+		switch descr.PhysicalType() {
+		case parquet.Types.ByteArray, parquet.Types.FixedLenByteArray:
+			return transferDecimalBytes(rdr.(file.BinaryRecordReader), valueType)
+		case parquet.Types.Int32, parquet.Types.Int64:
+			arr = transferDecimalInteger(rdr, valueType)
+		default:
+			return nil, xerrors.New("physical type for decimal128 must be int32, int64, bytearray or fixed len byte array")
+		}
+	case arrow.TIMESTAMP:
+		unit := valueType.(*arrow.TimestampType).Unit
+		if unit != arrow.Millisecond && unit != arrow.Microsecond && unit != arrow.Nanosecond {
+			return nil, xerrors.New("time unit not supported")
+		}
+		// only nanosecond timestamps stored as int96 need converting
+		if unit == arrow.Nanosecond && descr.PhysicalType() == parquet.Types.Int96 {
+			arr = transferInt96(rdr, valueType)
+		} else {
+			arr = transferZeroCopy(rdr, valueType)
+		}
+	default:
+		return nil, xerrors.Errorf("no support for reading columns of type: %s", valueType.Name())
+	}
+
+	chunked := array.NewChunked(valueType, []array.Interface{arr})
+	arr.Release()
+	return chunked, nil
+}
+
+// transferZeroCopy builds an array of type dt directly on top of the validity
+// and value buffers handed over by rdr, without copying the values.
+//
+// The array takes its own reference on the array data, so the reference held
+// from creating the data is released before returning; otherwise the data
+// (and the buffers it retains) could never be freed.
+func transferZeroCopy(rdr file.RecordReader, dt arrow.DataType) array.Interface {
+	data := array.NewData(dt, rdr.ValuesWritten(), []*memory.Buffer{
+		rdr.ReleaseValidBits(), rdr.ReleaseValues(),
+	}, nil, int(rdr.NullCount()), 0)
+	defer data.Release()
+	return array.MakeFromData(data)
+}
+
+// transferBinary returns the chunks accumulated by the binary record reader
+// as a chunked array of type dt. rdr must implement file.BinaryRecordReader;
+// the type assertion panics otherwise.
+func transferBinary(rdr file.RecordReader, dt arrow.DataType) *array.Chunked {
+ brdr := rdr.(file.BinaryRecordReader)
+ chunks := brdr.GetBuilderChunks()
+ if dt == arrow.BinaryTypes.String {
+ // convert chunks from binary to string without copying data
+ // NOTE(review): the replaced chunk is not released here and the new
+ // array is built from the same data object — confirm the reference
+ // counting and resulting array type against GetBuilderChunks.
+ for idx := range chunks {
+ chunks[idx] = array.MakeFromData(chunks[idx].Data())
+ }
+ }
+ return array.NewChunked(dt, chunks)
+}
+
+// transferInt copies the int32 or int64 values read by rdr into a newly
+// allocated buffer of the (narrower, unsigned or date/time) arrow type dt and
+// returns the resulting array. dt must be a fixed width type and one of the
+// types handled in the switch below.
+//
+// The array takes its own reference on the array data, so the reference held
+// from creating the data is released before returning.
+func transferInt(rdr file.RecordReader, dt arrow.DataType) array.Interface {
+	var output reflect.Value
+
+	signed := true
+	length := rdr.ValuesWritten()
+	data := make([]byte, length*int(bitutil.BytesForBits(int64(dt.(arrow.FixedWidthDataType).BitWidth()))))
+	// view the output bytes as a slice of the destination element type
+	switch dt.ID() {
+	case arrow.INT8:
+		output = reflect.ValueOf(arrow.Int8Traits.CastFromBytes(data))
+	case arrow.UINT8:
+		signed = false
+		output = reflect.ValueOf(arrow.Uint8Traits.CastFromBytes(data))
+	case arrow.INT16:
+		output = reflect.ValueOf(arrow.Int16Traits.CastFromBytes(data))
+	case arrow.UINT16:
+		signed = false
+		output = reflect.ValueOf(arrow.Uint16Traits.CastFromBytes(data))
+	case arrow.UINT32:
+		signed = false
+		output = reflect.ValueOf(arrow.Uint32Traits.CastFromBytes(data))
+	case arrow.UINT64:
+		signed = false
+		output = reflect.ValueOf(arrow.Uint64Traits.CastFromBytes(data))
+	case arrow.DATE32:
+		output = reflect.ValueOf(arrow.Date32Traits.CastFromBytes(data))
+	case arrow.TIME32:
+		output = reflect.ValueOf(arrow.Time32Traits.CastFromBytes(data))
+	case arrow.TIME64:
+		output = reflect.ValueOf(arrow.Time64Traits.CastFromBytes(data))
+	}
+
+	// copy element by element from the physical parquet representation
+	switch rdr.Type() {
+	case parquet.Types.Int32:
+		values := arrow.Int32Traits.CastFromBytes(rdr.Values())
+		if signed {
+			for idx, v := range values[:length] {
+				output.Index(idx).SetInt(int64(v))
+			}
+		} else {
+			for idx, v := range values[:length] {
+				output.Index(idx).SetUint(uint64(v))
+			}
+		}
+	case parquet.Types.Int64:
+		values := arrow.Int64Traits.CastFromBytes(rdr.Values())
+		if signed {
+			for idx, v := range values[:length] {
+				output.Index(idx).SetInt(v)
+			}
+		} else {
+			for idx, v := range values[:length] {
+				output.Index(idx).SetUint(uint64(v))
+			}
+		}
+	}
+
+	arrData := array.NewData(dt, length, []*memory.Buffer{
+		rdr.ReleaseValidBits(), memory.NewBufferBytes(data),
+	}, nil, int(rdr.NullCount()), 0)
+	defer arrData.Release()
+	return array.MakeFromData(arrData)
+}
+
+// transferBool packs the values read by rdr, reinterpreted as one bool per
+// byte, into a bitmap and returns the resulting boolean array.
+//
+// The array takes its own reference on the array data, so the reference held
+// from creating the data is released before returning.
+func transferBool(rdr file.RecordReader) array.Interface {
+	// TODO(mtopol): optimize this so we don't convert bitmap to []bool back to bitmap
+	length := rdr.ValuesWritten()
+	data := make([]byte, int(bitutil.BytesForBits(int64(length))))
+	bytedata := rdr.Values()
+	// view the reader's value bytes as a []bool without copying
+	values := *(*[]bool)(unsafe.Pointer(&bytedata))
+
+	for idx, v := range values[:length] {
+		if v {
+			bitutil.SetBit(data, idx)
+		}
+	}
+
+	arrData := array.NewData(&arrow.BooleanType{}, length, []*memory.Buffer{
+		rdr.ReleaseValidBits(), memory.NewBufferBytes(data),
+	}, nil, int(rdr.NullCount()), 0)
+	defer arrData.Release()
+	return array.MakeFromData(arrData)
+}
+
+// milliPerDay is the number of milliseconds in 24 hours.
+var milliPerDay = time.Duration(24 * time.Hour).Milliseconds()
+
+// transferDate64 reads the values of rdr as int32 and widens each one to an
+// int64 multiplied by milliPerDay, returning the result as an array of type dt.
+//
+// The array takes its own reference on the array data, so the reference held
+// from creating the data is released before returning.
+func transferDate64(rdr file.RecordReader, dt arrow.DataType) array.Interface {
+	length := rdr.ValuesWritten()
+	values := arrow.Int32Traits.CastFromBytes(rdr.Values())
+
+	data := make([]byte, arrow.Int64Traits.BytesRequired(length))
+	out := arrow.Int64Traits.CastFromBytes(data)
+	for idx, val := range values[:length] {
+		out[idx] = int64(val) * milliPerDay
+	}
+
+	arrData := array.NewData(dt, length, []*memory.Buffer{
+		rdr.ReleaseValidBits(), memory.NewBufferBytes(data),
+	}, nil, int(rdr.NullCount()), 0)
+	defer arrData.Release()
+	return array.MakeFromData(arrData)
+}
+
+// transferInt96 converts the int96 values read by rdr into int64 values
+// holding UnixNano timestamps and returns them as an array of type dt. A value
+// whose last four bytes (read as a little endian uint32) are zero is written
+// as 0 instead of being converted.
+//
+// The array takes its own reference on the array data, so the reference held
+// from creating the data is released before returning.
+func transferInt96(rdr file.RecordReader, dt arrow.DataType) array.Interface {
+	length := rdr.ValuesWritten()
+	values := parquet.Int96Traits.CastFromBytes(rdr.Values())
+
+	data := make([]byte, arrow.Int64SizeBytes*length)
+	out := arrow.Int64Traits.CastFromBytes(data)
+
+	for idx, val := range values[:length] {
+		if binary.LittleEndian.Uint32(val[8:]) == 0 {
+			out[idx] = 0
+		} else {
+			out[idx] = val.ToTime().UnixNano()
+		}
+	}
+
+	arrData := array.NewData(dt, length, []*memory.Buffer{
+		rdr.ReleaseValidBits(), memory.NewBufferBytes(data),
+	}, nil, int(rdr.NullCount()), 0)
+	defer arrData.Release()
+	return array.MakeFromData(arrData)
+}
+
+// transferDecimalInteger converts the int32 or int64 values read by rdr into
+// decimal128 values and returns them as an array of type dt. The validity
+// bitmap is taken from rdr only when it reports at least one null.
+//
+// The values are converted with plain typed loops instead of per-element
+// reflection. The physical type of rdr must be Int32 or Int64; any other type
+// panics, as it did before (previously from inside the reflect package).
+//
+// The array takes its own reference on the array data, so the reference held
+// from creating the data is released before returning.
+func transferDecimalInteger(rdr file.RecordReader, dt arrow.DataType) array.Interface {
+	length := rdr.ValuesWritten()
+
+	data := make([]byte, arrow.Decimal128Traits.BytesRequired(length))
+	out := arrow.Decimal128Traits.CastFromBytes(data)
+
+	switch rdr.Type() {
+	case parquet.Types.Int32:
+		for i, v := range arrow.Int32Traits.CastFromBytes(rdr.Values())[:length] {
+			out[i] = decimal128.FromI64(int64(v))
+		}
+	case parquet.Types.Int64:
+		for i, v := range arrow.Int64Traits.CastFromBytes(rdr.Values())[:length] {
+			out[i] = decimal128.FromI64(v)
+		}
+	default:
+		panic("pqarrow: transferDecimalInteger requires an int32 or int64 physical type")
+	}
+
+	var nullmap *memory.Buffer
+	if rdr.NullCount() > 0 {
+		nullmap = rdr.ReleaseValidBits()
+	}
+
+	arrData := array.NewData(dt, length, []*memory.Buffer{
+		nullmap, memory.NewBufferBytes(data),
+	}, nil, int(rdr.NullCount()), 0)
+	defer arrData.Release()
+	return array.MakeFromData(arrData)
+}
+
+// uint64FromBigEndianShifted interprets buf as a big endian unsigned integer
+// of up to 8 bytes: the bytes are right-aligned in a zeroed 8 byte array which
+// is then decoded as a big endian uint64. An empty buf yields 0. buf must not
+// be longer than 8 bytes, otherwise the slice expression below panics.
+func uint64FromBigEndianShifted(buf []byte) uint64 {
+ var (
+ bytes [8]byte
+ )
+ copy(bytes[8-len(buf):], buf)
+ return binary.BigEndian.Uint64(bytes[:])
+}
+
+// bigEndianToDecimal128 converts a big endian two's complement integer of 1 to
+// 16 bytes into a decimal128.Num, sign extending inputs shorter than 16 bytes.
+// It returns an error when the length of buf is outside that range.
+func bigEndianToDecimal128(buf []byte) (decimal128.Num, error) {
+ const (
+ minDecimalBytes = 1
+ maxDecimalBytes = 16
+ )
+
+ if len(buf) < minDecimalBytes || len(buf) > maxDecimalBytes {
+ return decimal128.Num{}, xerrors.Errorf("length of byte array passed to bigEndianToDecimal128 was %d but must be between %d and %d",
+ len(buf), minDecimalBytes, maxDecimalBytes)
+ }
+
+ // bytes are big endian so first byte is MSB and holds the sign bit
+ isNeg := int8(buf[0]) < 0
+
+ // 1. extract high bits
+ // everything before the last 8 bytes of buf belongs to the high word
+ highBitsOffset := utils.MaxInt(0, len(buf)-8)
+ var (
+ highBits uint64
+ lowBits uint64
+ hi int64
+ lo int64
+ )
+ highBits = uint64FromBigEndianShifted(buf[:highBitsOffset])
+
+ if highBitsOffset == 8 {
+ // full 16 byte input: the high word is taken as is
+ hi = int64(highBits)
+ } else {
+ // start from all ones for a negative value so that the bytes not
+ // present in the input end up sign extended
+ if isNeg && len(buf) < maxDecimalBytes {
+ hi = -1
+ }
+
+ // make room for the provided high bytes and merge them in
+ hi = int64(uint64(hi) << (uint64(highBitsOffset) * 8))
+ hi |= int64(highBits)
+ }
+
+ // 2. extract lower bits
+ // the low word is made of the last (at most 8) bytes of buf
+ lowBitsOffset := utils.MinInt(len(buf), 8)
+ lowBits = uint64FromBigEndianShifted(buf[highBitsOffset:])
+
+ if lowBitsOffset == 8 {
+ lo = int64(lowBits)
+ } else {
+ // fewer than 8 bytes in total: sign extend the low word as well
+ if isNeg && len(buf) < 8 {
+ lo = -1
+ }
+
+ lo = int64(uint64(lo) << (uint64(lowBitsOffset) * 8))
+ lo |= int64(lowBits)
+ }
+
+ return decimal128.New(hi, uint64(lo)), nil
+}
+
+// varOrFixedBin is an array whose elements can be read as byte slices; it is
+// the view transferDecimalBytes uses on the chunks of a binary record reader.
+type varOrFixedBin interface {
+ array.Interface
+ // Value returns the bytes of the element at index i.
+ Value(i int) []byte
+}
+
+// transferDecimalBytes converts every chunk accumulated by the binary record
+// reader from big endian byte strings into decimal128 values and returns the
+// converted chunks as a chunked array of type dt.
+//
+// Null entries are skipped (their slots stay zero) and the validity bitmap of
+// the source chunk is shared with the converted chunk. An empty byte string or
+// one longer than 16 bytes yields an error.
+//
+// Reference counting: each source chunk is released as soon as it has been
+// converted, the array data created for a converted chunk is released once
+// the array holds it, and the converted chunks are released after the chunked
+// array has taken its own references. On error every chunk that is still
+// held, converted or not, is released before returning.
+func transferDecimalBytes(rdr file.BinaryRecordReader, dt arrow.DataType) (*array.Chunked, error) {
+	convert := func(arr array.Interface) (array.Interface, error) {
+		length := arr.Len()
+		raw := make([]byte, arrow.Decimal128Traits.BytesRequired(length))
+		out := arrow.Decimal128Traits.CastFromBytes(raw)
+
+		input := arr.(varOrFixedBin)
+		nullCount := input.NullN()
+
+		var err error
+		for i := 0; i < length; i++ {
+			if nullCount > 0 && input.IsNull(i) {
+				continue
+			}
+
+			rec := input.Value(i)
+			if len(rec) == 0 {
+				return nil, xerrors.Errorf("invalid BYTEARRAY length for type: %s", dt)
+			}
+			if out[i], err = bigEndianToDecimal128(rec); err != nil {
+				return nil, err
+			}
+		}
+
+		data := array.NewData(dt, length, []*memory.Buffer{
+			input.Data().Buffers()[0], memory.NewBufferBytes(raw),
+		}, nil, nullCount, 0)
+		defer data.Release()
+		return array.MakeFromData(data), nil
+	}
+
+	chunks := rdr.GetBuilderChunks()
+	for idx, chunk := range chunks {
+		converted, err := convert(chunk)
+		// the converted array keeps its own reference on the shared validity
+		// buffer, so the source chunk is no longer needed
+		chunk.Release()
+		if err != nil {
+			for _, done := range chunks[:idx] {
+				done.Release()
+			}
+			for _, pending := range chunks[idx+1:] {
+				pending.Release()
+			}
+			return nil, err
+		}
+		chunks[idx] = converted
+	}
+
+	result := array.NewChunked(dt, chunks)
+	for _, converted := range chunks {
+		converted.Release()
+	}
+	return result, nil
+}
diff --git a/go/parquet/pqarrow/doc.go b/go/parquet/pqarrow/doc.go
new file mode 100644
index 0000000..488e12e
--- /dev/null
+++ b/go/parquet/pqarrow/doc.go
@@ -0,0 +1,21 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pqarrow provides the implementation for connecting Arrow directly
+// with the Parquet implementation, allowing isolation of all the explicitly
+// arrow related code to this package which has the interfaces for reading and
+// writing directly to and from arrow Arrays/Tables/Records
+package pqarrow
diff --git a/go/parquet/pqarrow/encode_arrow.go b/go/parquet/pqarrow/encode_arrow.go
new file mode 100644
index 0000000..65bee37
--- /dev/null
+++ b/go/parquet/pqarrow/encode_arrow.go
@@ -0,0 +1,586 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow
+
+import (
+ "context"
+ "encoding/binary"
+ "time"
+ "unsafe"
+
+ "github.com/apache/arrow/go/arrow"
+ "github.com/apache/arrow/go/arrow/array"
+ "github.com/apache/arrow/go/arrow/bitutil"
+ "github.com/apache/arrow/go/arrow/decimal128"
+ "github.com/apache/arrow/go/arrow/memory"
+ "github.com/apache/arrow/go/parquet"
+ "github.com/apache/arrow/go/parquet/file"
+ "github.com/apache/arrow/go/parquet/internal/utils"
+ "golang.org/x/xerrors"
+)
+
+// calcLeafCount reports how many leaf columns the arrow data type dt expands
+// to. Lists and fixed size lists count the leaves of their element type, maps
+// count the leaves of their ValueType(), structs sum the leaves of all of their
+// fields, and every other type counts as a single leaf.
+//
+// Extension and union types are not implemented and cause a panic.
+func calcLeafCount(dt arrow.DataType) int {
+	switch id := dt.ID(); id {
+	case arrow.STRUCT:
+		total := 0
+		fields := dt.(*arrow.StructType).Fields()
+		for i := range fields {
+			total += calcLeafCount(fields[i].Type)
+		}
+		return total
+	case arrow.MAP:
+		return calcLeafCount(dt.(*arrow.MapType).ValueType())
+	case arrow.FIXED_SIZE_LIST:
+		return calcLeafCount(dt.(*arrow.FixedSizeListType).Elem())
+	case arrow.LIST:
+		return calcLeafCount(dt.(*arrow.ListType).Elem())
+	case arrow.EXTENSION, arrow.UNION:
+		panic("arrow type not implemented")
+	}
+	return 1
+}
+
+// nullableRoot walks from field up through its ancestors, as reported by
+// manifest.GetParent, and returns the Nullable flag of the last field visited,
+// i.e. the one for which GetParent yields nil.
+func nullableRoot(manifest *SchemaManifest, field *SchemaField) bool {
+	result := field.Field.Nullable
+	for node := field; node != nil; node = manifest.GetParent(node) {
+		result = node.Field.Nullable
+	}
+	return result
+}
+
+// ArrowColumnWriter is a convenience object for easily writing arrow data to a specific
+// set of columns in a parquet file. Since a single arrow array can itself be a nested type
+// consisting of multiple columns of data, this will write to all of the appropriate leaves in
+// the parquet file, allowing easy writing of nested columns.
+type ArrowColumnWriter struct {
+ // builders holds the level builders created for the array slices that
+ // are to be written; Write iterates over them for every leaf.
+ builders []*multipathLevelBuilder
+ // leafCount is the number of leaf columns Write iterates over.
+ leafCount int
+ // colIdx is the base column index; Write adds the leaf index to it when
+ // requesting a column from a buffered row group writer.
+ colIdx int
+ // rgw is the row group writer that the column writers are obtained from.
+ rgw file.RowGroupWriter
+}
+
+// NewArrowColumnWriter returns a new writer using the chunked array to determine the number of leaf columns,
+// and the provided schema manifest to determine the paths for writing the columns.
+//
+// Using an arrow column writer is a convenience to avoid having to process the arrow array yourself
+// and determine the correct definition and repetition levels manually.
+//
+// offset and size select the rows [offset, offset+size) of data that will be
+// written, and col is the column index that is looked up in the manifest and
+// used as the base index when writing to a buffered row group writer.
+//
+// An error is returned if offset is past the end of the chunked array, if the
+// requested range extends beyond the available data, if col cannot be resolved
+// through the manifest, or if a level builder cannot be created for one of the
+// chunks.
+func NewArrowColumnWriter(data *array.Chunked, offset, size int64, manifest *SchemaManifest, rgw file.RowGroupWriter, col int) (ArrowColumnWriter, error) {
+	if data.Len() == 0 {
+		return ArrowColumnWriter{leafCount: calcLeafCount(data.DataType()), rgw: rgw, colIdx: col}, nil
+	}
+
+	var (
+		absPos      int64
+		chunkOffset int64
+		chunkIdx    int
+		values      int64
+	)
+
+	// locate the chunk which contains offset and the position of offset inside it
+	for idx, chnk := range data.Chunks() {
+		chunkIdx = idx
+		if absPos >= offset {
+			break
+		}
+
+		chunkLen := int64(chnk.Len())
+		if absPos+chunkLen > offset {
+			chunkOffset = offset - absPos
+			break
+		}
+
+		absPos += chunkLen
+	}
+
+	if absPos >= int64(data.Len()) {
+		return ArrowColumnWriter{}, xerrors.New("cannot write data at offset past end of chunked array")
+	}
+
+	leafCount := calcLeafCount(data.DataType())
+
+	schemaField, err := manifest.GetColumnField(col)
+	if err != nil {
+		return ArrowColumnWriter{}, err
+	}
+	isNullable := nullableRoot(manifest, schemaField)
+
+	chunks := data.Chunks()
+	builders := make([]*multipathLevelBuilder, 0)
+	// releaseBuilders frees the builders created so far; it is only used on
+	// error paths, where no ArrowColumnWriter takes ownership of them.
+	releaseBuilders := func() {
+		for _, b := range builders {
+			b.Release()
+		}
+	}
+
+	for values < size {
+		if chunkIdx >= len(chunks) {
+			// without this check the chunk lookup below would panic
+			releaseBuilders()
+			return ArrowColumnWriter{}, xerrors.Errorf("cannot write %d values from offset %d: chunked array only has %d values", size, offset, data.Len())
+		}
+
+		chunk := chunks[chunkIdx]
+		available := int64(chunk.Len() - int(chunkOffset))
+		chunkWriteSize := utils.Min(size-values, available)
+
+		// the chunk offset will be 0 here except for possibly the first chunk
+		// because of the above advancing logic
+		arrToWrite := array.NewSlice(chunk, chunkOffset, chunkOffset+chunkWriteSize)
+
+		if arrToWrite.Len() > 0 {
+			bldr, err := newMultipathLevelBuilder(arrToWrite, isNullable)
+			if err != nil {
+				// propagate the failure instead of reporting success with an empty writer
+				releaseBuilders()
+				return ArrowColumnWriter{}, err
+			}
+			if leafCount != bldr.leafCount() {
+				mismatch := xerrors.Errorf("data type leaf_count != builder leafcount: %d - %d", leafCount, bldr.leafCount())
+				bldr.Release()
+				releaseBuilders()
+				return ArrowColumnWriter{}, mismatch
+			}
+			builders = append(builders, bldr)
+		}
+
+		if chunkWriteSize == available {
+			chunkOffset = 0
+			chunkIdx++
+		}
+		values += chunkWriteSize
+	}
+
+	return ArrowColumnWriter{builders: builders, leafCount: leafCount, rgw: rgw, colIdx: col}, nil
+}
+
+// Write writes every leaf column of the arrow data held by this writer to the
+// underlying row group writer.
+//
+// For each leaf, the column writer is obtained by index (colIdx + leaf) when
+// the row group writer is buffered, or by advancing to the next column
+// otherwise. Each builder then produces the levels and leaf array for that
+// leaf, which are handed to WriteArrowToColumn. The builders, their results and
+// the sliced leaf arrays are released when Write returns.
+func (acw *ArrowColumnWriter) Write(ctx context.Context) error {
+	writeCtx := arrowCtxFromContext(ctx)
+	for leaf := 0; leaf < acw.leafCount; leaf++ {
+		var (
+			colWriter file.ColumnWriter
+			err       error
+		)
+
+		switch {
+		case acw.rgw.Buffered():
+			colWriter, err = acw.rgw.(file.BufferedRowGroupWriter).Column(acw.colIdx + leaf)
+		default:
+			colWriter, err = acw.rgw.(file.SerialRowGroupWriter).NextColumn()
+		}
+		if err != nil {
+			return err
+		}
+
+		for _, b := range acw.builders {
+			// every builder is visited once per leaf, so schedule its release
+			// only on the first pass
+			if leaf == 0 {
+				defer b.Release()
+			}
+
+			result, err := b.write(leaf, writeCtx)
+			if err != nil {
+				return err
+			}
+			defer result.Release()
+
+			if len(result.postListVisitedElems) != 1 {
+				return xerrors.New("lists with non-zero length null components are not supported")
+			}
+
+			span := result.postListVisitedElems[0]
+			leafValues := array.NewSlice(result.leafArr, span.start, span.end)
+			defer leafValues.Release()
+
+			err = WriteArrowToColumn(ctx, colWriter, leafValues, result.defLevels, result.repLevels, result.leafIsNullable)
+			if err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+// WriteArrowToColumn writes apache arrow columnar data directly to a ColumnWriter.
+// Returns non-nil error if the array data type is not compatible with the concrete
+// writer type.
+//
+// leafArr is always a primitive (possibly dictionary encoded type).
+// leafFieldNullable indicates whether the leaf array is considered nullable
+// according to its schema in a Table or its parent array.
+func WriteArrowToColumn(ctx context.Context, cw file.ColumnWriter, leafArr array.Interface, defLevels, repLevels []int16, leafFieldNullable bool) error {
+	// Leaf nulls are canonical when there is only a single null element after a list
+	// and it is at the leaf.
+	info := cw.LevelInfo()
+	leafOnlyNullable := leafFieldNullable && info.DefLevel == info.RepeatedAncestorDefLevel+1
+	parentMayBeNull := info.HasNullableValues() && !leafOnlyNullable
+
+	if parentMayBeNull {
+		// give the column writer a bitmap buffer sized for one write batch
+		props := cw.Properties()
+		bits := memory.NewResizableBuffer(props.Allocator())
+		bits.Resize(int(bitutil.BytesForBits(props.WriteBatchSize())))
+		cw.SetBitsBuffer(bits)
+	}
+
+	// TODO(mtopol): write arrow dictionary ARROW-7283
+	// (arrays with an arrow.DICTIONARY data type currently take the dense path)
+	return writeDenseArrow(arrowCtxFromContext(ctx), cw, leafArr, defLevels, repLevels, parentMayBeNull)
+}
+
+// binaryarr is the subset of the arrow string/binary array API that
+// writeDenseArrow needs in order to slice individual values out of the value
+// buffer.
+type binaryarr interface {
+ ValueOffsets() []int32
+}
+
+// writeDenseArrow writes the values of leafArr, together with the given
+// definition and repetition levels, to the column writer cw.
+//
+// The concrete type of cw selects the conversion that is applied to the arrow
+// data. Where the arrow representation cannot be written as-is, the values are
+// converted into ctx.dataBuffer, which is allocated lazily and reused as scratch
+// space. For most writers WriteBatch is used when neither parent nulls nor leaf
+// nulls are possible, and WriteBatchSpaced with the array's validity bitmap
+// otherwise.
+//
+// An error is returned when the arrow data type is not compatible with the
+// column writer, or when a timestamp coercion fails.
+func writeDenseArrow(ctx *arrowWriteContext, cw file.ColumnWriter, leafArr array.Interface, defLevels, repLevels []int16, maybeParentNulls bool) (err error) {
+ // the leaf has no nulls if the column is required or the array reports none
+ noNulls := cw.Descr().SchemaNode().RepetitionType() == parquet.Repetitions.Required || leafArr.NullN() == 0
+
+ if ctx.dataBuffer == nil {
+ ctx.dataBuffer = memory.NewResizableBuffer(cw.Properties().Allocator())
+ }
+
+ switch wr := cw.(type) {
+ case *file.BooleanColumnWriter:
+ if leafArr.DataType().ID() != arrow.BOOL {
+ return xerrors.Errorf("type mismatch, column is %s, array is %s", cw.Type(), leafArr.DataType().ID())
+ }
+
+ if leafArr.Len() == 0 {
+ wr.WriteBatch(nil, defLevels, repLevels)
+ break
+ }
+
+ // expand the array into one bool per value by reinterpreting the
+ // scratch buffer's bytes as a []bool
+ ctx.dataBuffer.ResizeNoShrink(leafArr.Len())
+ buf := ctx.dataBuffer.Bytes()
+ data := *(*[]bool)(unsafe.Pointer(&buf))
+ for idx := range data {
+ data[idx] = leafArr.(*array.Boolean).Value(idx)
+ }
+ if !maybeParentNulls && noNulls {
+ wr.WriteBatch(data, defLevels, repLevels)
+ } else {
+ wr.WriteBatchSpaced(data, defLevels, repLevels, leafArr.NullBitmapBytes(), int64(leafArr.Data().Offset()))
+ }
+ case *file.Int32ColumnWriter:
+ var data []int32
+ switch leafArr.DataType().ID() {
+ case arrow.INT32:
+ data = leafArr.(*array.Int32).Int32Values()
+ case arrow.DATE32, arrow.UINT32:
+ // reinterpret the raw value buffer as int32 and apply the array's offset/length
+ data = arrow.Int32Traits.CastFromBytes(leafArr.Data().Buffers()[1].Bytes())
+ data = data[leafArr.Data().Offset() : leafArr.Data().Offset()+leafArr.Len()]
+ case arrow.TIME32:
+ if leafArr.DataType().(*arrow.Time32Type).Unit != arrow.Second {
+ data = arrow.Int32Traits.CastFromBytes(leafArr.Data().Buffers()[1].Bytes())
+ data = data[leafArr.Data().Offset() : leafArr.Data().Offset()+leafArr.Len()]
+ } else {
+ // values with a unit of seconds are multiplied by 1000 before being written
+ ctx.dataBuffer.ResizeNoShrink(arrow.Int32Traits.BytesRequired(leafArr.Len()))
+ data = arrow.Int32Traits.CastFromBytes(ctx.dataBuffer.Bytes())
+ for idx, val := range leafArr.(*array.Time32).Time32Values() {
+ data[idx] = int32(val) * 1000
+ }
+ }
+
+ default:
+ // the remaining supported types are converted value by value to
+ // int32 in the scratch buffer
+ ctx.dataBuffer.ResizeNoShrink(arrow.Int32Traits.BytesRequired(leafArr.Len()))
+ data = arrow.Int32Traits.CastFromBytes(ctx.dataBuffer.Bytes())
+ switch leafArr.DataType().ID() {
+ case arrow.UINT8:
+ for idx, val := range leafArr.(*array.Uint8).Uint8Values() {
+ data[idx] = int32(val)
+ }
+ case arrow.INT8:
+ for idx, val := range leafArr.(*array.Int8).Int8Values() {
+ data[idx] = int32(val)
+ }
+ case arrow.UINT16:
+ for idx, val := range leafArr.(*array.Uint16).Uint16Values() {
+ data[idx] = int32(val)
+ }
+ case arrow.INT16:
+ for idx, val := range leafArr.(*array.Int16).Int16Values() {
+ data[idx] = int32(val)
+ }
+ case arrow.DATE64:
+ // 86400000 is the number of milliseconds in a day
+ for idx, val := range leafArr.(*array.Date64).Date64Values() {
+ data[idx] = int32(val / 86400000)
+ }
+ default:
+ return xerrors.Errorf("type mismatch, column is int32 writer, arrow array is %s, and not a compatible type", leafArr.DataType().Name())
+ }
+ }
+
+ if !maybeParentNulls && noNulls {
+ wr.WriteBatch(data, defLevels, repLevels)
+ } else {
+ nulls := leafArr.NullBitmapBytes()
+ wr.WriteBatchSpaced(data, defLevels, repLevels, nulls, int64(leafArr.Data().Offset()))
+ }
+ case *file.Int64ColumnWriter:
+ var data []int64
+ switch leafArr.DataType().ID() {
+ case arrow.TIMESTAMP:
+ tstype := leafArr.DataType().(*arrow.TimestampType)
+ if ctx.props.coerceTimestamps {
+ // user explicitly requested coercion to specific unit
+ if tstype.Unit == ctx.props.coerceTimestampUnit {
+ // no conversion necessary
+ data = arrow.Int64Traits.CastFromBytes(leafArr.Data().Buffers()[1].Bytes())
+ data = data[leafArr.Data().Offset() : leafArr.Data().Offset()+leafArr.Len()]
+ } else {
+ ctx.dataBuffer.ResizeNoShrink(arrow.Int64Traits.BytesRequired(leafArr.Len()))
+ data = arrow.Int64Traits.CastFromBytes(ctx.dataBuffer.Bytes())
+ if err := writeCoerceTimestamps(leafArr.(*array.Timestamp), &ctx.props, data); err != nil {
+ return err
+ }
+ }
+ } else if (cw.Properties().Version() == parquet.V1_0 || cw.Properties().Version() == parquet.V2_4) && tstype.Unit == arrow.Nanosecond {
+ // absent superseding user instructions, when writing a Parquet Version <=2.4 File,
+ // timestamps in nano seconds are coerced to microseconds
+ ctx.dataBuffer.ResizeNoShrink(arrow.Int64Traits.BytesRequired(leafArr.Len()))
+ data = arrow.Int64Traits.CastFromBytes(ctx.dataBuffer.Bytes())
+ p := NewArrowWriterProperties(WithCoerceTimestamps(arrow.Microsecond), WithTruncatedTimestamps(true))
+ if err := writeCoerceTimestamps(leafArr.(*array.Timestamp), &p, data); err != nil {
+ return err
+ }
+ } else if tstype.Unit == arrow.Second {
+ // absent superseding user instructions, timestamps in seconds are coerced
+ // to milliseconds
+ p := NewArrowWriterProperties(WithCoerceTimestamps(arrow.Millisecond))
+ ctx.dataBuffer.ResizeNoShrink(arrow.Int64Traits.BytesRequired(leafArr.Len()))
+ data = arrow.Int64Traits.CastFromBytes(ctx.dataBuffer.Bytes())
+ if err := writeCoerceTimestamps(leafArr.(*array.Timestamp), &p, data); err != nil {
+ return err
+ }
+ } else {
+ // no data conversion necessary
+ data = arrow.Int64Traits.CastFromBytes(leafArr.Data().Buffers()[1].Bytes())
+ data = data[leafArr.Data().Offset() : leafArr.Data().Offset()+leafArr.Len()]
+ }
+ case arrow.UINT32:
+ // uint32 values are widened to int64 in the scratch buffer
+ ctx.dataBuffer.ResizeNoShrink(arrow.Int64Traits.BytesRequired(leafArr.Len()))
+ data = arrow.Int64Traits.CastFromBytes(ctx.dataBuffer.Bytes())
+ for idx, val := range leafArr.(*array.Uint32).Uint32Values() {
+ data[idx] = int64(val)
+ }
+ case arrow.INT64:
+ data = leafArr.(*array.Int64).Int64Values()
+ case arrow.UINT64, arrow.TIME64, arrow.DATE64:
+ // reinterpret the raw value buffer as int64 and apply the array's offset/length
+ data = arrow.Int64Traits.CastFromBytes(leafArr.Data().Buffers()[1].Bytes())
+ data = data[leafArr.Data().Offset() : leafArr.Data().Offset()+leafArr.Len()]
+ default:
+ return xerrors.Errorf("unimplemented arrow type to write to int64 column: %s", leafArr.DataType().Name())
+ }
+
+ if !maybeParentNulls && noNulls {
+ wr.WriteBatch(data, defLevels, repLevels)
+ } else {
+ nulls := leafArr.NullBitmapBytes()
+ wr.WriteBatchSpaced(data, defLevels, repLevels, nulls, int64(leafArr.Data().Offset()))
+ }
+ case *file.Int96ColumnWriter:
+ if leafArr.DataType().ID() != arrow.TIMESTAMP {
+ return xerrors.New("unsupported arrow type to write to Int96 column")
+ }
+ // every timestamp is converted to an Int96 value in the scratch buffer
+ ctx.dataBuffer.ResizeNoShrink(parquet.Int96Traits.BytesRequired(leafArr.Len()))
+ data := parquet.Int96Traits.CastFromBytes(ctx.dataBuffer.Bytes())
+ input := leafArr.(*array.Timestamp).TimestampValues()
+ unit := leafArr.DataType().(*arrow.TimestampType).Unit
+ for idx, val := range input {
+ arrowTimestampToImpalaTimestamp(unit, int64(val), &data[idx])
+ }
+
+ if !maybeParentNulls && noNulls {
+ wr.WriteBatch(data, defLevels, repLevels)
+ } else {
+ nulls := leafArr.NullBitmapBytes()
+ wr.WriteBatchSpaced(data, defLevels, repLevels, nulls, int64(leafArr.Data().Offset()))
+ }
+ case *file.Float32ColumnWriter:
+ if leafArr.DataType().ID() != arrow.FLOAT32 {
+ return xerrors.New("invalid column type to write to Float")
+ }
+ if !maybeParentNulls && noNulls {
+ wr.WriteBatch(leafArr.(*array.Float32).Float32Values(), defLevels, repLevels)
+ } else {
+ wr.WriteBatchSpaced(leafArr.(*array.Float32).Float32Values(), defLevels, repLevels, leafArr.NullBitmapBytes(), int64(leafArr.Data().Offset()))
+ }
+ case *file.Float64ColumnWriter:
+ if leafArr.DataType().ID() != arrow.FLOAT64 {
+ return xerrors.New("invalid column type to write to Float")
+ }
+ if !maybeParentNulls && noNulls {
+ wr.WriteBatch(leafArr.(*array.Float64).Float64Values(), defLevels, repLevels)
+ } else {
+ wr.WriteBatchSpaced(leafArr.(*array.Float64).Float64Values(), defLevels, repLevels, leafArr.NullBitmapBytes(), int64(leafArr.Data().Offset()))
+ }
+ case *file.ByteArrayColumnWriter:
+ if leafArr.DataType().ID() != arrow.STRING && leafArr.DataType().ID() != arrow.BINARY {
+ return xerrors.New("invalid column type to write to ByteArray")
+ }
+
+ var (
+ offsets = leafArr.(binaryarr).ValueOffsets()
+ buffer = leafArr.Data().Buffers()[2]
+ valueBuf []byte
+ )
+
+ // the value buffer may be nil; use an empty slice in that case
+ if buffer == nil {
+ valueBuf = []byte{}
+ } else {
+ valueBuf = buffer.Bytes()
+ }
+
+ // each ByteArray is a sub-slice of the value buffer; no bytes are copied here
+ data := make([]parquet.ByteArray, leafArr.Len())
+ for i := range data {
+ data[i] = parquet.ByteArray(valueBuf[offsets[i]:offsets[i+1]])
+ }
+ if !maybeParentNulls && noNulls {
+ wr.WriteBatch(data, defLevels, repLevels)
+ } else {
+ wr.WriteBatchSpaced(data, defLevels, repLevels, leafArr.NullBitmapBytes(), int64(leafArr.Data().Offset()))
+ }
+
+ case *file.FixedLenByteArrayColumnWriter:
+ switch dt := leafArr.DataType().(type) {
+ case *arrow.FixedSizeBinaryType:
+ data := make([]parquet.FixedLenByteArray, leafArr.Len())
+ for idx := range data {
+ data[idx] = leafArr.(*array.FixedSizeBinary).Value(idx)
+ }
+ if !maybeParentNulls && noNulls {
+ wr.WriteBatch(data, defLevels, repLevels)
+ } else {
+ wr.WriteBatchSpaced(data, defLevels, repLevels, leafArr.NullBitmapBytes(), int64(leafArr.Data().Offset()))
+ }
+ case *arrow.Decimal128Type:
+ // parquet decimal are stored with FixedLength values where the length is
+ // proportional to the precision. Arrow's Decimal are always stored with 16/32
+ // bytes. thus the internal FLBA must be adjusted by the offset calculation
+ // NOTE(review): offset and the scratch size below are derived from
+ // dt.BitWidth(); this presumably expects a width in bytes - confirm
+ // against the Decimal128Type implementation.
+ offset := dt.BitWidth() - int(DecimalSize(dt.Precision))
+ ctx.dataBuffer.ResizeNoShrink((leafArr.Len() - leafArr.NullN()) * dt.BitWidth())
+ scratch := ctx.dataBuffer.Bytes()
+ typeLen := wr.Descr().TypeLength()
+ // fixDecimalEndianness writes the value big endian into the next 16
+ // bytes of scratch, advances scratch past them, and returns the
+ // typeLen bytes starting at offset within those bytes
+ fixDecimalEndianness := func(in decimal128.Num) parquet.FixedLenByteArray {
+ out := scratch[offset : offset+typeLen]
+ binary.BigEndian.PutUint64(scratch, uint64(in.HighBits()))
+ binary.BigEndian.PutUint64(scratch[arrow.Uint64SizeBytes:], in.LowBits())
+ scratch = scratch[2*arrow.Uint64SizeBytes:]
+ return out
+ }
+
+ data := make([]parquet.FixedLenByteArray, leafArr.Len())
+ arr := leafArr.(*array.Decimal128)
+ if leafArr.NullN() == 0 {
+ for idx := range data {
+ data[idx] = fixDecimalEndianness(arr.Value(idx))
+ }
+ wr.WriteBatch(data, defLevels, repLevels)
+ } else {
+ // null slots are left as nil entries; only valid values use scratch space
+ for idx := range data {
+ if arr.IsValid(idx) {
+ data[idx] = fixDecimalEndianness(arr.Value(idx))
+ }
+ }
+ wr.WriteBatchSpaced(data, defLevels, repLevels, arr.NullBitmapBytes(), int64(arr.Data().Offset()))
+ }
+ default:
+ return xerrors.New("unimplemented")
+ }
+ default:
+ return xerrors.New("unknown column writer physical type")
+ }
+ return
+}
+
+// coerceType identifies the arithmetic operation used to convert a timestamp
+// from one time unit to another.
+type coerceType int8
+
+const (
+ // coerceInvalid is the zero value and marks a unit pair with no conversion.
+ coerceInvalid coerceType = iota
+ // coerceDivide converts by dividing the value by the factor.
+ coerceDivide
+ // coerceMultiply converts by multiplying the value by the factor.
+ coerceMultiply
+)
+
+// coercePair describes a single unit conversion: the operation to apply and
+// the factor to apply it with.
+type coercePair struct {
+ typ coerceType
+ factor int64
+}
+
+// factors maps a source time unit (outer key) and a target time unit (inner
+// key) to the conversion that has to be applied to the timestamp values.
+// Every conversion to arrow.Second is marked as coerceInvalid.
+var factors = map[arrow.TimeUnit]map[arrow.TimeUnit]coercePair{
+ arrow.Second: {
+ arrow.Second: {coerceInvalid, 0},
+ arrow.Millisecond: {coerceMultiply, 1000},
+ arrow.Microsecond: {coerceMultiply, 1000000},
+ arrow.Nanosecond: {coerceMultiply, 1000000000},
+ },
+ arrow.Millisecond: {
+ arrow.Second: {coerceInvalid, 0},
+ arrow.Millisecond: {coerceMultiply, 1},
+ arrow.Microsecond: {coerceMultiply, 1000},
+ arrow.Nanosecond: {coerceMultiply, 1000000},
+ },
+ arrow.Microsecond: {
+ arrow.Second: {coerceInvalid, 0},
+ arrow.Millisecond: {coerceDivide, 1000},
+ arrow.Microsecond: {coerceMultiply, 1},
+ arrow.Nanosecond: {coerceMultiply, 1000},
+ },
+ arrow.Nanosecond: {
+ arrow.Second: {coerceInvalid, 0},
+ arrow.Millisecond: {coerceDivide, 1000000},
+ arrow.Microsecond: {coerceDivide, 1000},
+ arrow.Nanosecond: {coerceMultiply, 1},
+ },
+}
+
+// writeCoerceTimestamps converts the timestamp values of arr from the array's
+// own time unit to props.coerceTimestampUnit and stores the results in out,
+// which must have room for at least as many values as arr holds.
+//
+// When the conversion requires a division and truncation is not allowed by
+// props, an error is returned for the first valid value that cannot be
+// represented exactly in the target unit. An error is also returned when the
+// factors table has no conversion for the unit pair (e.g. coercing to
+// seconds); this is reported as an error rather than a panic so that callers
+// can handle it like any other write failure.
+func writeCoerceTimestamps(arr *array.Timestamp, props *ArrowWriterProperties, out []int64) error {
+	source := arr.DataType().(*arrow.TimestampType).Unit
+	target := props.coerceTimestampUnit
+
+	// a missing map entry yields the zero coercePair, whose typ is coerceInvalid
+	coerce := factors[source][target]
+	vals := arr.TimestampValues()
+
+	switch coerce.typ {
+	case coerceMultiply:
+		for idx, val := range vals {
+			out[idx] = int64(val) * coerce.factor
+		}
+	case coerceDivide:
+		for idx, val := range vals {
+			// only valid slots can lose data; null slots hold arbitrary values
+			if !props.allowTruncatedTimestamps && arr.IsValid(idx) && int64(val)%coerce.factor != 0 {
+				return xerrors.Errorf("casting from %s to %s would lose data", source, target)
+			}
+			out[idx] = int64(val) / coerce.factor
+		}
+	default:
+		return xerrors.Errorf("invalid timestamp coercion from %s to %s", source, target)
+	}
+	return nil
+}
+
+const (
+ // julianEpochOffsetDays is added to the number of days since the unix epoch
+ // to obtain the day number stored in an Int96 timestamp.
+ julianEpochOffsetDays int64 = 2440588
+ // nanoSecondsPerDay is the number of nanoseconds in 24 hours.
+ nanoSecondsPerDay = 24 * 60 * 60 * 1000 * 1000 * 1000
+)
+
+func arrowTimestampToImpalaTimestamp(unit arrow.TimeUnit, t int64, out *parquet.Int96) {
+ var d time.Duration
+ switch unit {
+ case arrow.Second:
+ d = time.Duration(t) * time.Second
+ case arrow.Microsecond:
+ d = time.Duration(t) * time.Microsecond
+ case arrow.Millisecond:
+ d = time.Duration(t) * time.Millisecond
+ case arrow.Nanosecond:
+ d = time.Duration(t) * time.Nanosecond
+ }
+
+ julianDays := (int64(d.Hours()) / 24) + julianEpochOffsetDays
+ lastDayNanos := t % (nanoSecondsPerDay)
+ binary.LittleEndian.PutUint64((*out)[:8], uint64(lastDayNanos))
+ binary.LittleEndian.PutUint32((*out)[8:], uint32(julianDays))
+}
diff --git a/go/parquet/pqarrow/encode_arrow_test.go b/go/parquet/pqarrow/encode_arrow_test.go
new file mode 100644
index 0000000..2146d42
--- /dev/null
+++ b/go/parquet/pqarrow/encode_arrow_test.go
@@ -0,0 +1,1379 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow_test
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "math"
+ "strconv"
+ "testing"
+
+ "github.com/apache/arrow/go/arrow"
+ "github.com/apache/arrow/go/arrow/array"
+ "github.com/apache/arrow/go/arrow/bitutil"
+ "github.com/apache/arrow/go/arrow/decimal128"
+ "github.com/apache/arrow/go/arrow/memory"
+ "github.com/apache/arrow/go/parquet"
+ "github.com/apache/arrow/go/parquet/compress"
+ "github.com/apache/arrow/go/parquet/file"
+ "github.com/apache/arrow/go/parquet/internal/encoding"
+ "github.com/apache/arrow/go/parquet/internal/testutils"
+ "github.com/apache/arrow/go/parquet/internal/utils"
+ "github.com/apache/arrow/go/parquet/pqarrow"
+ "github.com/apache/arrow/go/parquet/schema"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/stretchr/testify/suite"
+)
+
+// makeSimpleTable wraps values in a table with a single column named "col"
+// whose field has the data type of values and the requested nullability.
+func makeSimpleTable(values *array.Chunked, nullable bool) array.Table {
+	field := arrow.Field{Name: "col", Type: values.DataType(), Nullable: nullable}
+	sc := arrow.NewSchema([]arrow.Field{field}, nil)
+
+	col := array.NewColumn(sc.Field(0), values)
+	defer col.Release()
+
+	return array.NewTable(sc, []array.Column{*col}, -1)
+}
+
+func makeDateTimeTypesTable(expected bool, addFieldMeta bool) array.Table {
+ isValid := []bool{true, true, true, false, true, true}
+
+ // roundtrip without modification
+ f0 := arrow.Field{Name: "f0", Type: arrow.FixedWidthTypes.Date32, Nullable: true}
+ f1 := arrow.Field{Name: "f1", Type: arrow.FixedWidthTypes.Timestamp_ms, Nullable: true}
+ f2 := arrow.Field{Name: "f2", Type: arrow.FixedWidthTypes.Timestamp_us, Nullable: true}
+ f3 := arrow.Field{Name: "f3", Type: arrow.FixedWidthTypes.Timestamp_ns, Nullable: true}
+ f3X := arrow.Field{Name: "f3", Type: arrow.FixedWidthTypes.Timestamp_us, Nullable: true}
+ f4 := arrow.Field{Name: "f4", Type: arrow.FixedWidthTypes.Time32ms, Nullable: true}
+ f5 := arrow.Field{Name: "f5", Type: arrow.FixedWidthTypes.Time64us, Nullable: true}
+ f6 := arrow.Field{Name: "f6", Type: arrow.FixedWidthTypes.Time64ns, Nullable: true}
+
+ fieldList := []arrow.Field{f0, f1, f2}
+ if expected {
+ fieldList = append(fieldList, f3X)
+ } else {
+ fieldList = append(fieldList, f3)
+ }
+ fieldList = append(fieldList, f4, f5, f6)
+
+ if addFieldMeta {
+ for idx := range fieldList {
+ fieldList[idx].Metadata = arrow.NewMetadata([]string{"PARQUET:field_id"}, []string{strconv.Itoa(idx + 1)})
+ }
+ }
+ arrsc := arrow.NewSchema(fieldList, nil)
+
+ d32Values := []arrow.Date32{1489269000, 1489270000, 1489271000, 1489272000, 1489272000, 1489273000}
+ ts64nsValues := []arrow.Timestamp{1489269000000, 1489270000000, 1489271000000, 1489272000000, 1489272000000, 1489273000000}
+ ts64usValues := []arrow.Timestamp{1489269000, 1489270000, 1489271000, 1489272000, 1489272000, 1489273000}
+ ts64msValues := []arrow.Timestamp{1489269, 1489270, 1489271, 1489272, 1489272, 1489273}
+ t32Values := []arrow.Time32{1489269000, 1489270000, 1489271000, 1489272000, 1489272000, 1489273000}
+ t64nsValues := []arrow.Time64{1489269000000, 1489270000000, 1489271000000, 1489272000000, 1489272000000, 1489273000000}
+ t64usValues := []arrow.Time64{1489269000, 1489270000, 1489271000, 1489272000, 1489272000, 1489273000}
+
+ builders := make([]array.Builder, 0, len(fieldList))
+ for _, f := range fieldList {
+ bldr := array.NewBuilder(memory.DefaultAllocator, f.Type)
+ defer bldr.Release()
+ builders = append(builders, bldr)
+ }
+
+ builders[0].(*array.Date32Builder).AppendValues(d32Values, isValid)
+ builders[1].(*array.TimestampBuilder).AppendValues(ts64msValues, isValid)
+ builders[2].(*array.TimestampBuilder).AppendValues(ts64usValues, isValid)
+ if expected {
+ builders[3].(*array.TimestampBuilder).AppendValues(ts64usValues, isValid)
+ } else {
+ builders[3].(*array.TimestampBuilder).AppendValues(ts64nsValues, isValid)
+ }
+ builders[4].(*array.Time32Builder).AppendValues(t32Values, isValid)
+ builders[5].(*array.Time64Builder).AppendValues(t64usValues, isValid)
+ builders[6].(*array.Time64Builder).AppendValues(t64nsValues, isValid)
+
+ cols := make([]array.Column, 0, len(fieldList))
+ for idx, field := range fieldList {
+ arr := builders[idx].NewArray()
+ defer arr.Release()
+
+ cols = append(cols, *array.NewColumn(field, array.NewChunked(field.Type, []array.Interface{arr})))
+ }
+
+ return array.NewTable(arrsc, cols, int64(len(isValid)))
+}
+
+// TestWriteArrowCols writes the date/time table column by column with
+// ArrowColumnWriter into a single row group of a parquet file written with
+// version 2.4, then reads the raw column values back with the parquet column
+// readers and compares them against the expected table.
+func TestWriteArrowCols(t *testing.T) {
+ tbl := makeDateTimeTypesTable(false, false)
+ defer tbl.Release()
+
+ psc, err := pqarrow.ToParquet(tbl.Schema(), nil, pqarrow.DefaultWriterProps())
+ require.NoError(t, err)
+
+ manifest, err := pqarrow.NewSchemaManifest(psc, nil, nil)
+ require.NoError(t, err)
+
+ sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
+ writer := file.NewParquetWriter(sink, psc.Root(), file.WithWriterProps(parquet.NewWriterProperties(parquet.WithVersion(parquet.V2_4))))
+
+ srgw := writer.AppendRowGroup()
+ ctx := pqarrow.NewArrowWriteContext(context.TODO(), nil)
+
+ // write every column of the table in full into the row group
+ for i := int64(0); i < tbl.NumCols(); i++ {
+ acw, err := pqarrow.NewArrowColumnWriter(tbl.Column(int(i)).Data(), 0, tbl.NumRows(), manifest, srgw, int(i))
+ require.NoError(t, err)
+ require.NoError(t, acw.Write(ctx))
+ }
+ require.NoError(t, srgw.Close())
+ require.NoError(t, writer.Close())
+
+ // the expected table uses microsecond timestamps for column f3
+ expected := makeDateTimeTypesTable(true, false)
+ defer expected.Release()
+
+ reader, err := file.NewParquetReader(bytes.NewReader(sink.Bytes()))
+ require.NoError(t, err)
+
+ assert.EqualValues(t, expected.NumCols(), reader.MetaData().Schema.NumColumns())
+ assert.EqualValues(t, expected.NumRows(), reader.NumRows())
+ assert.EqualValues(t, 1, reader.NumRowGroups())
+
+ rgr := reader.RowGroup(0)
+
+ for i := 0; i < int(expected.NumCols()); i++ {
+ var (
+ total int64
+ read int
+ err error
+ defLevelsOut = make([]int16, int(expected.NumRows()))
+ arr = expected.Column(i).Data().Chunk(0)
+ )
+ // pick the column reader type based on the bit width of the arrow type
+ switch expected.Schema().Field(i).Type.(arrow.FixedWidthDataType).BitWidth() {
+ case 32:
+ colReader := rgr.Column(i).(*file.Int32ColumnReader)
+ vals := make([]int32, int(expected.NumRows()))
+ total, read, err = colReader.ReadBatch(expected.NumRows(), vals, defLevelsOut, nil)
+ require.NoError(t, err)
+
+ // vals is indexed with j-nulls, skipping one slot for every null seen so far
+ nulls := 0
+ for j := 0; j < arr.Len(); j++ {
+ if arr.IsNull(j) {
+ nulls++
+ continue
+ }
+
+ switch v := arr.(type) {
+ case *array.Date32:
+ assert.EqualValues(t, v.Value(j), vals[j-nulls])
+ case *array.Time32:
+ assert.EqualValues(t, v.Value(j), vals[j-nulls])
+ }
+ }
+ case 64:
+ colReader := rgr.Column(i).(*file.Int64ColumnReader)
+ vals := make([]int64, int(expected.NumRows()))
+ total, read, err = colReader.ReadBatch(expected.NumRows(), vals, defLevelsOut, nil)
+ require.NoError(t, err)
+
+ // vals is indexed with j-nulls, skipping one slot for every null seen so far
+ nulls := 0
+ for j := 0; j < arr.Len(); j++ {
+ if arr.IsNull(j) {
+ nulls++
+ continue
+ }
+
+ switch v := arr.(type) {
+ case *array.Date64:
+ assert.EqualValues(t, v.Value(j), vals[j-nulls])
+ case *array.Time64:
+ assert.EqualValues(t, v.Value(j), vals[j-nulls])
+ case *array.Timestamp:
+ assert.EqualValues(t, v.Value(j), vals[j-nulls])
+ }
+ }
+ }
+ // every column has exactly one null, at index 3
+ assert.EqualValues(t, expected.NumRows(), total)
+ assert.EqualValues(t, expected.NumRows()-1, read)
+ assert.Equal(t, []int16{1, 1, 1, 0, 1, 1}, defLevelsOut)
+ }
+}
+
+// TestWriteArrowInt96 writes the date/time table with deprecated Int96
+// timestamps enabled and verifies that the nanosecond timestamp column (f3) is
+// stored as Int96 and that its non-null values convert back to the original
+// nanosecond timestamps.
+func TestWriteArrowInt96(t *testing.T) {
+	tbl := makeDateTimeTypesTable(false, false)
+	defer tbl.Release()
+
+	props := pqarrow.NewArrowWriterProperties(pqarrow.WithDeprecatedInt96Timestamps(true))
+	psc, err := pqarrow.ToParquet(tbl.Schema(), nil, props)
+	require.NoError(t, err)
+
+	manifest, err := pqarrow.NewSchemaManifest(psc, nil, nil)
+	require.NoError(t, err)
+
+	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
+	writer := file.NewParquetWriter(sink, psc.Root())
+
+	srgw := writer.AppendRowGroup()
+	ctx := pqarrow.NewArrowWriteContext(context.TODO(), &props)
+
+	for i := int64(0); i < tbl.NumCols(); i++ {
+		acw, err := pqarrow.NewArrowColumnWriter(tbl.Column(int(i)).Data(), 0, tbl.NumRows(), manifest, srgw, int(i))
+		require.NoError(t, err)
+		require.NoError(t, acw.Write(ctx))
+	}
+	require.NoError(t, srgw.Close())
+	require.NoError(t, writer.Close())
+
+	expected := makeDateTimeTypesTable(false, false)
+	defer expected.Release()
+
+	reader, err := file.NewParquetReader(bytes.NewReader(sink.Bytes()))
+	require.NoError(t, err)
+
+	assert.EqualValues(t, expected.NumCols(), reader.MetaData().Schema.NumColumns())
+	assert.EqualValues(t, expected.NumRows(), reader.NumRows())
+	assert.EqualValues(t, 1, reader.NumRowGroups())
+
+	rgr := reader.RowGroup(0)
+	tsRdr := rgr.Column(3)
+	assert.Equal(t, parquet.Types.Int96, tsRdr.Type())
+
+	rdr := tsRdr.(*file.Int96ColumnReader)
+	vals := make([]parquet.Int96, expected.NumRows())
+	defLevels := make([]int16, int(expected.NumRows()))
+
+	// fail immediately on a read error instead of comparing against garbage
+	total, read, err := rdr.ReadBatch(expected.NumRows(), vals, defLevels, nil)
+	require.NoError(t, err)
+	assert.EqualValues(t, expected.NumRows(), total)
+	assert.EqualValues(t, expected.NumRows()-1, read)
+	assert.Equal(t, []int16{1, 1, 1, 0, 1, 1}, defLevels)
+
+	// vals only holds the non-null values, so indexes after the null at
+	// position 3 of the arrow array are shifted down by one
+	data := expected.Column(3).Data().Chunk(0).(*array.Timestamp)
+	assert.EqualValues(t, data.Value(0), vals[0].ToTime().UnixNano())
+	assert.EqualValues(t, data.Value(1), vals[1].ToTime().UnixNano())
+	assert.EqualValues(t, data.Value(2), vals[2].ToTime().UnixNano())
+	assert.EqualValues(t, data.Value(4), vals[3].ToTime().UnixNano())
+	assert.EqualValues(t, data.Value(5), vals[4].ToTime().UnixNano())
+}
+
+// writeTableToBuffer writes tbl as a parquet version 1.0 file into an in-memory
+// buffer, splitting the rows into row groups of at most rowGroupSize rows, and
+// returns the finished buffer. Any failure while writing a column or closing a
+// row group or the file fails the test immediately.
+func writeTableToBuffer(t *testing.T, tbl array.Table, rowGroupSize int64, props pqarrow.ArrowWriterProperties) *memory.Buffer {
+	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
+	wrprops := parquet.NewWriterProperties(parquet.WithVersion(parquet.V1_0))
+	psc, err := pqarrow.ToParquet(tbl.Schema(), wrprops, props)
+	require.NoError(t, err)
+
+	manifest, err := pqarrow.NewSchemaManifest(psc, nil, nil)
+	require.NoError(t, err)
+
+	writer := file.NewParquetWriter(sink, psc.Root(), file.WithWriterProps(wrprops))
+	ctx := pqarrow.NewArrowWriteContext(context.TODO(), &props)
+
+	offset := int64(0)
+	for offset < tbl.NumRows() {
+		sz := utils.Min(rowGroupSize, tbl.NumRows()-offset)
+		srgw := writer.AppendRowGroup()
+		for i := 0; i < int(tbl.NumCols()); i++ {
+			col := tbl.Column(i)
+			acw, err := pqarrow.NewArrowColumnWriter(col.Data(), offset, sz, manifest, srgw, i)
+			require.NoError(t, err)
+			require.NoError(t, acw.Write(ctx))
+		}
+		// closing flushes the row group; surface failures instead of
+		// returning a truncated buffer
+		require.NoError(t, srgw.Close())
+		offset += sz
+	}
+	require.NoError(t, writer.Close())
+
+	return sink.Finish()
+}
+
+// simpleRoundTrip writes tbl to a parquet buffer using row groups of
+// rowGroupSize rows, reads every column back through the pqarrow file reader
+// and checks that the data read matches the chunks of the original column.
+func simpleRoundTrip(t *testing.T, tbl array.Table, rowGroupSize int64) {
+	buf := writeTableToBuffer(t, tbl, rowGroupSize, pqarrow.DefaultWriterProps())
+	defer buf.Release()
+
+	rdr, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()))
+	require.NoError(t, err)
+
+	ardr, err := pqarrow.NewFileReader(rdr, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
+	require.NoError(t, err)
+
+	for i := 0; i < int(tbl.NumCols()); i++ {
+		crdr, err := ardr.GetColumn(context.TODO(), i)
+		require.NoError(t, err)
+
+		chunked, err := crdr.NextBatch(tbl.NumRows())
+		require.NoError(t, err)
+
+		require.EqualValues(t, tbl.NumRows(), chunked.Len())
+
+		chunkList := tbl.Column(i).Data().Chunks()
+		offset := int64(0)
+		for _, chnk := range chunkList {
+			slc := chunked.NewSlice(offset, offset+int64(chnk.Len()))
+			defer slc.Release()
+
+			assert.EqualValues(t, chnk.Len(), slc.Len())
+			// the direct comparison is only possible when the slice of the
+			// data read back consists of a single chunk
+			if len(slc.Chunks()) == 1 {
+				assert.True(t, array.ArrayEqual(chnk, slc.Chunk(0)))
+			}
+			// always move past this chunk so that the next slice stays
+			// aligned with the original data, even if the comparison above
+			// was skipped
+			offset += int64(chnk.Len())
+		}
+	}
+}
+
+// TestArrowReadWriteTableChunkedCols round-trips a single nullable int32
+// column that is split into consecutive chunks of 2, 4, 10 and 2 values, once
+// with row groups of 2 rows and once with row groups of 10 rows.
+func TestArrowReadWriteTableChunkedCols(t *testing.T) {
+	chunkSizes := []int{2, 4, 10, 2}
+	const totalLen = int64(18)
+
+	rng := testutils.NewRandomArrayGenerator(0)
+
+	arr := rng.Int32(totalLen, 0, math.MaxInt32/2, 0.9)
+	defer arr.Release()
+
+	offset := int64(0)
+	chunks := make([]array.Interface, 0, len(chunkSizes))
+	for _, chnksize := range chunkSizes {
+		chk := array.NewSlice(arr, offset, offset+int64(chnksize))
+		defer chk.Release()
+		chunks = append(chunks, chk)
+		// step to the next region so the chunks partition arr instead of all
+		// starting at index 0
+		offset += int64(chnksize)
+	}
+
+	sc := arrow.NewSchema([]arrow.Field{{Name: "field", Type: arr.DataType(), Nullable: true}}, nil)
+	tbl := array.NewTable(sc, []array.Column{*array.NewColumn(sc.Field(0), array.NewChunked(arr.DataType(), chunks))}, -1)
+	defer tbl.Release()
+
+	simpleRoundTrip(t, tbl, 2)
+	simpleRoundTrip(t, tbl, 10)
+}
+
+// getLogicalType returns the parquet logical type the tests expect for an
+// arrow data type. It is a test-local reference used to build expected
+// schemas, so that the exported conversion functions can be checked against
+// it. Types without a dedicated mapping yield schema.NoLogicalType.
+//
+// It panics for timestamp units other than milli/micro/nano and for TIME64
+// units other than micro/nano.
+func getLogicalType(typ arrow.DataType) schema.LogicalType {
+	switch typ.ID() {
+	// integers: bit width plus signedness
+	case arrow.INT8:
+		return schema.NewIntLogicalType(8, true)
+	case arrow.INT16:
+		return schema.NewIntLogicalType(16, true)
+	case arrow.INT32:
+		return schema.NewIntLogicalType(32, true)
+	case arrow.INT64:
+		return schema.NewIntLogicalType(64, true)
+	case arrow.UINT8:
+		return schema.NewIntLogicalType(8, false)
+	case arrow.UINT16:
+		return schema.NewIntLogicalType(16, false)
+	case arrow.UINT32:
+		return schema.NewIntLogicalType(32, false)
+	case arrow.UINT64:
+		return schema.NewIntLogicalType(64, false)
+	case arrow.STRING:
+		return schema.StringLogicalType{}
+	case arrow.DATE32, arrow.DATE64:
+		return schema.DateLogicalType{}
+	case arrow.TIMESTAMP:
+		tsType := typ.(*arrow.TimestampType)
+		noZone := len(tsType.TimeZone) == 0
+		switch tsType.Unit {
+		case arrow.Millisecond:
+			return schema.NewTimestampLogicalType(noZone, schema.TimeUnitMillis)
+		case arrow.Microsecond:
+			return schema.NewTimestampLogicalType(noZone, schema.TimeUnitMicros)
+		case arrow.Nanosecond:
+			return schema.NewTimestampLogicalType(noZone, schema.TimeUnitNanos)
+		}
+		panic("only milli, micro and nano units supported for arrow timestamp")
+	case arrow.TIME32:
+		return schema.NewTimeLogicalType(false, schema.TimeUnitMillis)
+	case arrow.TIME64:
+		switch typ.(*arrow.Time64Type).Unit {
+		case arrow.Microsecond:
+			return schema.NewTimeLogicalType(false, schema.TimeUnitMicros)
+		case arrow.Nanosecond:
+			return schema.NewTimeLogicalType(false, schema.TimeUnitNanos)
+		}
+		panic("only micro and nano seconds are supported for arrow TIME64")
+	case arrow.DECIMAL:
+		decType := typ.(*arrow.Decimal128Type)
+		return schema.NewDecimalLogicalType(decType.Precision, decType.Scale)
+	default:
+		return schema.NoLogicalType{}
+	}
+}
+
+// getPhysicalType returns the parquet physical type the tests expect for an
+// arrow data type. Anything not listed explicitly falls back to Int32.
+func getPhysicalType(typ arrow.DataType) parquet.Type {
+	switch typ.ID() {
+	case arrow.BOOL:
+		return parquet.Types.Boolean
+	case arrow.FLOAT32:
+		return parquet.Types.Float
+	case arrow.FLOAT64:
+		return parquet.Types.Double
+	case arrow.INT64, arrow.UINT64, arrow.TIME64, arrow.TIMESTAMP:
+		return parquet.Types.Int64
+	case arrow.BINARY, arrow.STRING:
+		return parquet.Types.ByteArray
+	case arrow.FIXED_SIZE_BINARY, arrow.DECIMAL:
+		return parquet.Types.FixedLenByteArray
+	case arrow.UINT8, arrow.INT8, arrow.UINT16, arrow.INT16, arrow.UINT32, arrow.INT32,
+		arrow.DATE32, arrow.TIME32,
+		arrow.DATE64: // date64 is converted to date32 internally
+		return parquet.Types.Int32
+	}
+	return parquet.Types.Int32
+}
+
+// Fixed per-type sample values used when building test columns in this file.
+const (
+ boolTestValue = true
+ uint8TestVal = uint8(64)
+ int8TestVal = int8(-64)
+ uint16TestVal = uint16(1024)
+ int16TestVal = int16(-1024)
+ uint32TestVal = uint32(1024)
+ int32TestVal = int32(-1024)
+ uint64TestVal = uint64(1024)
+ int64TestVal = int64(-1024)
+ tsTestValue = arrow.Timestamp(14695634030000)
+ date32TestVal = arrow.Date32(170000)
+ floatTestVal = float32(2.1)
+ doubleTestVal = float64(4.2)
+ strTestVal = "Test"
+
+ // smallSize is the number of rows used by the single-column and list
+ // round-trip tests.
+ smallSize = 100
+)
+
+// Sample byte-slice values; declared as vars because slices cannot be
+// constants.
+var (
+ binTestVal = []byte{0, 1, 2, 3}
+ flbaTestVal = []byte("Fixed")
+)
+
+// ParquetIOTestSuite groups the parquet <-> arrow read/write round-trip tests.
+// It embeds suite.Suite and is executed via suite.Run.
+type ParquetIOTestSuite struct {
+ suite.Suite
+}
+
+// makeSimpleSchema builds a parquet schema consisting of a required root group
+// named "schema" with a single primitive column "column1" of repetition rep.
+// The column's logical and physical types are derived from typ via
+// getLogicalType and getPhysicalType. For fixed-size binary and decimal types
+// the type length is set from the arrow type; otherwise it is -1.
+//
+// The test is aborted if the primitive node cannot be constructed.
+func (ps *ParquetIOTestSuite) makeSimpleSchema(typ arrow.DataType, rep parquet.Repetition) *schema.GroupNode {
+	byteWidth := int32(-1)
+
+	switch typ := typ.(type) {
+	case *arrow.FixedSizeBinaryType:
+		byteWidth = int32(typ.ByteWidth)
+	case *arrow.Decimal128Type:
+		byteWidth = pqarrow.DecimalSize(typ.Precision)
+	}
+
+	pnode, err := schema.NewPrimitiveNodeLogical("column1", rep, getLogicalType(typ), getPhysicalType(typ), int(byteWidth), -1)
+	// fail here with the real cause instead of handing a nil node to the group
+	ps.Require().NoError(err)
+	return schema.MustGroup(schema.NewGroupNode("schema", parquet.Repetitions.Required, schema.FieldList{pnode}, -1))
+}
+
+// makePrimitiveTestCol builds an array of length size in which every element
+// is the fixed test constant for typ (boolTestValue, int8TestVal, ...).
+//
+// Only bool, the signed and unsigned 8/16/32/64-bit integers, float32 and
+// float64 are handled; for any other type the function returns nil. Each
+// builder is released on return; the returned array is owned by the caller.
+func (ps *ParquetIOTestSuite) makePrimitiveTestCol(size int, typ arrow.DataType) array.Interface {
+ switch typ.ID() {
+ case arrow.BOOL:
+ bldr := array.NewBooleanBuilder(memory.DefaultAllocator)
+ defer bldr.Release()
+ for i := 0; i < size; i++ {
+ bldr.Append(boolTestValue)
+ }
+ return bldr.NewArray()
+ case arrow.INT8:
+ bldr := array.NewInt8Builder(memory.DefaultAllocator)
+ defer bldr.Release()
+ for i := 0; i < size; i++ {
+ bldr.Append(int8TestVal)
+ }
+ return bldr.NewArray()
+ case arrow.UINT8:
+ bldr := array.NewUint8Builder(memory.DefaultAllocator)
+ defer bldr.Release()
+ for i := 0; i < size; i++ {
+ bldr.Append(uint8TestVal)
+ }
+ return bldr.NewArray()
+ case arrow.INT16:
+ bldr := array.NewInt16Builder(memory.DefaultAllocator)
+ defer bldr.Release()
+ for i := 0; i < size; i++ {
+ bldr.Append(int16TestVal)
+ }
+ return bldr.NewArray()
+ case arrow.UINT16:
+ bldr := array.NewUint16Builder(memory.DefaultAllocator)
+ defer bldr.Release()
+ for i := 0; i < size; i++ {
+ bldr.Append(uint16TestVal)
+ }
+ return bldr.NewArray()
+ case arrow.INT32:
+ bldr := array.NewInt32Builder(memory.DefaultAllocator)
+ defer bldr.Release()
+ for i := 0; i < size; i++ {
+ bldr.Append(int32TestVal)
+ }
+ return bldr.NewArray()
+ case arrow.UINT32:
+ bldr := array.NewUint32Builder(memory.DefaultAllocator)
+ defer bldr.Release()
+ for i := 0; i < size; i++ {
+ bldr.Append(uint32TestVal)
+ }
+ return bldr.NewArray()
+ case arrow.INT64:
+ bldr := array.NewInt64Builder(memory.DefaultAllocator)
+ defer bldr.Release()
+ for i := 0; i < size; i++ {
+ bldr.Append(int64TestVal)
+ }
+ return bldr.NewArray()
+ case arrow.UINT64:
+ bldr := array.NewUint64Builder(memory.DefaultAllocator)
+ defer bldr.Release()
+ for i := 0; i < size; i++ {
+ bldr.Append(uint64TestVal)
+ }
+ return bldr.NewArray()
+ case arrow.FLOAT32:
+ bldr := array.NewFloat32Builder(memory.DefaultAllocator)
+ defer bldr.Release()
+ for i := 0; i < size; i++ {
+ bldr.Append(floatTestVal)
+ }
+ return bldr.NewArray()
+ case arrow.FLOAT64:
+ bldr := array.NewFloat64Builder(memory.DefaultAllocator)
+ defer bldr.Release()
+ for i := 0; i < size; i++ {
+ bldr.Append(doubleTestVal)
+ }
+ return bldr.NewArray()
+ }
+ // unsupported type: callers must be prepared for a nil array
+ return nil
+}
+
+// makeTestFile writes arr as a single required column into an in-memory
+// parquet file made of numChunks row groups and returns the file contents.
+//
+// Each row group receives arr.Len()/numChunks consecutive values, so arr.Len()
+// should be a multiple of numChunks (any remainder is not written). The
+// returned slice is a private copy and stays valid after the writer's buffer
+// has been released.
+func (ps *ParquetIOTestSuite) makeTestFile(typ arrow.DataType, arr array.Interface, numChunks int) []byte {
+	sc := ps.makeSimpleSchema(typ, parquet.Repetitions.Required)
+	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
+	writer := file.NewParquetWriter(sink, sc)
+
+	ctx := pqarrow.NewArrowWriteContext(context.TODO(), nil)
+	rowGroupSize := arr.Len() / numChunks
+
+	for i := 0; i < numChunks; i++ {
+		rgw := writer.AppendRowGroup()
+		cw, err := rgw.NextColumn()
+		ps.NoError(err)
+
+		start := i * rowGroupSize
+		slc := array.NewSlice(arr, int64(start), int64(start+rowGroupSize))
+		err = pqarrow.WriteArrowToColumn(ctx, cw, slc, nil, nil, false)
+		// the slice is only needed for the duration of the write
+		slc.Release()
+		ps.NoError(err)
+		ps.NoError(cw.Close())
+		ps.NoError(rgw.Close())
+	}
+	ps.NoError(writer.Close())
+
+	buf := sink.Finish()
+	defer buf.Release()
+	// copy out before the deferred Release hands the memory back to the
+	// allocator; returning buf.Bytes() directly would alias released memory
+	return append([]byte(nil), buf.Bytes()...)
+}
+
+// createReader opens data as a parquet file and wraps it in a
+// pqarrow.FileReader using default read properties and the default allocator.
+// The test is aborted if either step fails, so the returned reader is never
+// nil.
+func (ps *ParquetIOTestSuite) createReader(data []byte) *pqarrow.FileReader {
+	rdr, err := file.NewParquetReader(bytes.NewReader(data))
+	// stop here: continuing would pass a failed reader to NewFileReader
+	ps.Require().NoError(err)
+
+	reader, err := pqarrow.NewFileReader(rdr, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
+	ps.Require().NoError(err)
+	return reader
+}
+
+// readTable reads the whole file behind rdr into an arrow table. The test is
+// aborted if reading fails or yields no table, so callers can immediately
+// defer Release on the result.
+func (ps *ParquetIOTestSuite) readTable(rdr *pqarrow.FileReader) array.Table {
+	tbl, err := rdr.ReadTable(context.TODO())
+	ps.Require().NoError(err)
+	ps.Require().NotNil(tbl)
+	return tbl
+}
+
+// checkSingleColumnRequiredTableRead writes smallSize constant values of typ
+// as one required column spread over numChunks row groups, reads the file back
+// as a table and checks that it holds a single column with a single chunk
+// equal to the written values.
+func (ps *ParquetIOTestSuite) checkSingleColumnRequiredTableRead(typ arrow.DataType, numChunks int) {
+	values := ps.makePrimitiveTestCol(smallSize, typ)
+	defer values.Release()
+
+	data := ps.makeTestFile(typ, values, numChunks)
+	reader := ps.createReader(data)
+
+	tbl := ps.readTable(reader)
+	defer tbl.Release()
+
+	ps.EqualValues(1, tbl.NumCols())
+	ps.EqualValues(smallSize, tbl.NumRows())
+
+	// Column.Data() hands back the table's own chunked array without taking a
+	// new reference, so it must not be released here; tbl.Release covers it.
+	chunked := tbl.Column(0).Data()
+	// abort rather than index Chunk(0) on an unexpected chunk count
+	ps.Require().Len(chunked.Chunks(), 1)
+	ps.True(array.ArrayEqual(values, chunked.Chunk(0)))
+}
+
+// checkSingleColumnRead writes smallSize constant values of typ as one
+// required column spread over numChunks row groups, then reads column 0 back
+// through a column reader in one batch and expects a single chunk equal to
+// the written values.
+func (ps *ParquetIOTestSuite) checkSingleColumnRead(typ arrow.DataType, numChunks int) {
+	want := ps.makePrimitiveTestCol(smallSize, typ)
+	defer want.Release()
+
+	fileReader := ps.createReader(ps.makeTestFile(typ, want, numChunks))
+
+	colReader, err := fileReader.GetColumn(context.TODO(), 0)
+	ps.NoError(err)
+
+	got, err := colReader.NextBatch(smallSize)
+	ps.NoError(err)
+	defer got.Release()
+
+	ps.Len(got.Chunks(), 1)
+	ps.True(array.ArrayEqual(want, got.Chunk(0)))
+}
+
+// TestDateTimeTypesReadWriteTable writes the table produced by
+// makeDateTimeTypesTable(false, true) with default writer properties as one
+// row group and expects to read back the table produced by
+// makeDateTimeTypesTable(true, true): same shape, same schema and equal first
+// chunks in every column.
+func (ps *ParquetIOTestSuite) TestDateTimeTypesReadWriteTable() {
+	input := makeDateTimeTypesTable(false, true)
+	defer input.Release()
+	written := writeTableToBuffer(ps.T(), input, input.NumRows(), pqarrow.DefaultWriterProps())
+	defer written.Release()
+
+	got := ps.readTable(ps.createReader(written.Bytes()))
+	defer got.Release()
+
+	want := makeDateTimeTypesTable(true, true)
+	defer want.Release()
+
+	ps.Equal(want.NumCols(), got.NumCols())
+	ps.Equal(want.NumRows(), got.NumRows())
+	ps.Truef(want.Schema().Equal(got.Schema()), "expected schema: %s\ngot schema: %s", want.Schema(), got.Schema())
+
+	ncols := int(want.NumCols())
+	for col := 0; col < ncols; col++ {
+		wantData := want.Column(col).Data()
+		gotData := got.Column(col).Data()
+
+		ps.Equal(len(wantData.Chunks()), len(gotData.Chunks()))
+		ps.Truef(array.ArrayEqual(wantData.Chunk(0), gotData.Chunk(0)), "expected %s\ngot %s", wantData.Chunk(0), gotData.Chunk(0))
+	}
+}
+
+// TestDateTimeTypesWithInt96ReadWriteTable writes the table produced by
+// makeDateTimeTypesTable(false, true) with the deprecated int96 timestamp
+// option enabled and expects the table read back to be identical to what was
+// written: same shape, same schema and equal first chunks in every column.
+func (ps *ParquetIOTestSuite) TestDateTimeTypesWithInt96ReadWriteTable() {
+	want := makeDateTimeTypesTable(false, true)
+	defer want.Release()
+	int96Props := pqarrow.NewArrowWriterProperties(pqarrow.WithDeprecatedInt96Timestamps(true))
+	written := writeTableToBuffer(ps.T(), want, want.NumRows(), int96Props)
+	defer written.Release()
+
+	got := ps.readTable(ps.createReader(written.Bytes()))
+	defer got.Release()
+
+	ps.Equal(want.NumCols(), got.NumCols())
+	ps.Equal(want.NumRows(), got.NumRows())
+	ps.Truef(want.Schema().Equal(got.Schema()), "expected schema: %s\ngot schema: %s", want.Schema(), got.Schema())
+
+	ncols := int(want.NumCols())
+	for col := 0; col < ncols; col++ {
+		wantData := want.Column(col).Data()
+		gotData := got.Column(col).Data()
+
+		ps.Equal(len(wantData.Chunks()), len(gotData.Chunks()))
+		ps.Truef(array.ArrayEqual(wantData.Chunk(0), gotData.Chunk(0)), "expected %s\ngot %s", wantData.Chunk(0), gotData.Chunk(0))
+	}
+}
+
+// TestReadSingleColumnFile runs checkSingleColumnRead as a subtest for every
+// primitive type below, first with files of 1 row group and then with files
+// of 4 row groups.
+func (ps *ParquetIOTestSuite) TestReadSingleColumnFile() {
+	primitives := []arrow.DataType{
+		arrow.FixedWidthTypes.Boolean,
+		arrow.PrimitiveTypes.Uint8,
+		arrow.PrimitiveTypes.Int8,
+		arrow.PrimitiveTypes.Uint16,
+		arrow.PrimitiveTypes.Int16,
+		arrow.PrimitiveTypes.Uint32,
+		arrow.PrimitiveTypes.Int32,
+		arrow.PrimitiveTypes.Uint64,
+		arrow.PrimitiveTypes.Int64,
+		arrow.PrimitiveTypes.Float32,
+		arrow.PrimitiveTypes.Float64,
+	}
+
+	for _, chunkCount := range []int{1, 4} {
+		for _, typ := range primitives {
+			name := fmt.Sprintf("%s %d chunks", typ.Name(), chunkCount)
+			ps.Run(name, func() {
+				ps.checkSingleColumnRead(typ, chunkCount)
+			})
+		}
+	}
+}
+
+// TestSingleColumnRequiredRead runs checkSingleColumnRequiredTableRead as a
+// subtest for every primitive type below, first with files of 1 row group and
+// then with files of 4 row groups.
+func (ps *ParquetIOTestSuite) TestSingleColumnRequiredRead() {
+	primitives := []arrow.DataType{
+		arrow.FixedWidthTypes.Boolean,
+		arrow.PrimitiveTypes.Uint8,
+		arrow.PrimitiveTypes.Int8,
+		arrow.PrimitiveTypes.Uint16,
+		arrow.PrimitiveTypes.Int16,
+		arrow.PrimitiveTypes.Uint32,
+		arrow.PrimitiveTypes.Int32,
+		arrow.PrimitiveTypes.Uint64,
+		arrow.PrimitiveTypes.Int64,
+		arrow.PrimitiveTypes.Float32,
+		arrow.PrimitiveTypes.Float64,
+	}
+
+	for _, chunkCount := range []int{1, 4} {
+		for _, typ := range primitives {
+			name := fmt.Sprintf("%s %d chunks", typ.Name(), chunkCount)
+			ps.Run(name, func() {
+				ps.checkSingleColumnRequiredTableRead(typ, chunkCount)
+			})
+		}
+	}
+}
+
+// TestReadDecimals writes three big-endian byte-array encoded decimals
+// (123456, 987654 and -123456; the last one padded to 16 bytes) into a
+// required BYTE_ARRAY column annotated as decimal(6, 3) and checks that
+// reading the column back yields the matching Decimal128 array.
+func (ps *ParquetIOTestSuite) TestReadDecimals() {
+	bigEndian := []parquet.ByteArray{
+		// 123456
+		[]byte{1, 226, 64},
+		// 987654
+		[]byte{15, 18, 6},
+		// -123456
+		[]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 254, 29, 192},
+	}
+
+	bldr := array.NewDecimal128Builder(memory.DefaultAllocator, &arrow.Decimal128Type{Precision: 6, Scale: 3})
+	defer bldr.Release()
+
+	bldr.Append(decimal128.FromU64(123456))
+	bldr.Append(decimal128.FromU64(987654))
+	bldr.Append(decimal128.FromI64(-123456))
+
+	expected := bldr.NewDecimal128Array()
+	defer expected.Release()
+
+	sc := schema.MustGroup(schema.NewGroupNode("schema", parquet.Repetitions.Required, schema.FieldList{
+		schema.Must(schema.NewPrimitiveNodeLogical("decimals", parquet.Repetitions.Required, schema.NewDecimalLogicalType(6, 3), parquet.Types.ByteArray, -1, -1)),
+	}, -1))
+
+	sink := encoding.NewBufferWriter(0, memory.DefaultAllocator)
+	writer := file.NewParquetWriter(sink, sc)
+
+	rgw := writer.AppendRowGroup()
+	cw, err := rgw.NextColumn()
+	ps.Require().NoError(err)
+	// surface write failures instead of silently reading back a short file
+	_, err = cw.(*file.ByteArrayColumnWriter).WriteBatch(bigEndian, nil, nil)
+	ps.Require().NoError(err)
+	ps.NoError(cw.Close())
+	ps.NoError(rgw.Close())
+	ps.NoError(writer.Close())
+
+	rdr := ps.createReader(sink.Bytes())
+	cr, err := rdr.GetColumn(context.TODO(), 0)
+	ps.NoError(err)
+
+	chunked, err := cr.NextBatch(smallSize)
+	ps.NoError(err)
+	defer chunked.Release()
+
+	ps.Len(chunked.Chunks(), 1)
+	ps.True(array.ArrayEqual(expected, chunked.Chunk(0)))
+}
+
+// writeColumn converts the parquet schema sc to an arrow schema, writes values
+// as the only column of a single row group through a pqarrow.FileWriter with
+// dictionary encoding disabled, and returns the resulting file bytes.
+func (ps *ParquetIOTestSuite) writeColumn(sc *schema.GroupNode, values array.Interface) []byte {
+	arrowSchema, err := pqarrow.FromParquet(schema.NewSchema(sc), nil, nil)
+	ps.NoError(err)
+
+	var out bytes.Buffer
+	noDict := parquet.NewWriterProperties(parquet.WithDictionaryDefault(false))
+	fw, err := pqarrow.NewFileWriter(arrowSchema, &out, noDict, pqarrow.DefaultWriterProps())
+	ps.NoError(err)
+
+	fw.NewRowGroup()
+	ps.NoError(fw.WriteColumnData(values))
+	// Close is called twice and both calls are expected to succeed.
+	ps.NoError(fw.Close())
+	ps.NoError(fw.Close())
+
+	return out.Bytes()
+}
+
+// readAndCheckSingleColumnFile opens data as a parquet file, reads up to
+// smallSize rows of column 0 in one batch and expects exactly one non-nil
+// chunk that equals values.
+func (ps *ParquetIOTestSuite) readAndCheckSingleColumnFile(data []byte, values array.Interface) {
+	colReader, err := ps.createReader(data).GetColumn(context.TODO(), 0)
+	ps.NoError(err)
+	ps.NotNil(colReader)
+
+	got, err := colReader.NextBatch(smallSize)
+	ps.NoError(err)
+	defer got.Release()
+
+	ps.Len(got.Chunks(), 1)
+	first := got.Chunk(0)
+	ps.NotNil(first)
+
+	ps.True(array.ArrayEqual(values, first))
+}
+
+// fullTypeList is the set of arrow types exercised by the single-column and
+// list round-trip tests: bool, all integer widths, date32, floats, string,
+// binary, a 10-byte fixed-size binary, and decimals of increasing precision
+// (1 through 38).
+var fullTypeList = []arrow.DataType{
+ arrow.FixedWidthTypes.Boolean,
+ arrow.PrimitiveTypes.Uint8,
+ arrow.PrimitiveTypes.Int8,
+ arrow.PrimitiveTypes.Uint16,
+ arrow.PrimitiveTypes.Int16,
+ arrow.PrimitiveTypes.Uint32,
+ arrow.PrimitiveTypes.Int32,
+ arrow.PrimitiveTypes.Uint64,
+ arrow.PrimitiveTypes.Int64,
+ arrow.FixedWidthTypes.Date32,
+ arrow.PrimitiveTypes.Float32,
+ arrow.PrimitiveTypes.Float64,
+ arrow.BinaryTypes.String,
+ arrow.BinaryTypes.Binary,
+ &arrow.FixedSizeBinaryType{ByteWidth: 10},
+ &arrow.Decimal128Type{Precision: 1, Scale: 0},
+ &arrow.Decimal128Type{Precision: 5, Scale: 4},
+ &arrow.Decimal128Type{Precision: 10, Scale: 9},
+ &arrow.Decimal128Type{Precision: 19, Scale: 18},
+ &arrow.Decimal128Type{Precision: 23, Scale: 22},
+ &arrow.Decimal128Type{Precision: 27, Scale: 26},
+ &arrow.Decimal128Type{Precision: 38, Scale: 37},
+}
+
+// TestSingleColumnRequiredWrite writes, for every type in fullTypeList,
+// smallSize random non-null values as a single required column and checks
+// that reading the file back yields the same array.
+func (ps *ParquetIOTestSuite) TestSingleColumnRequiredWrite() {
+	for _, dt := range fullTypeList {
+		ps.Run(dt.Name(), func() {
+			values := testutils.RandomNonNull(dt, smallSize)
+			// release the generated array once the subtest is done
+			defer values.Release()
+			sc := ps.makeSimpleSchema(dt, parquet.Repetitions.Required)
+			data := ps.writeColumn(sc, values)
+			ps.readAndCheckSingleColumnFile(data, values)
+		})
+	}
+}
+
+// roundTripTable writes expected to an in-memory parquet file with
+// pqarrow.WriteTable (one row group holding all rows), reads it back as a
+// table and compares the result with the input.
+//
+// When storeSchema is true the writer properties are built with
+// pqarrow.WithStoreSchema; otherwise the defaults are used. Only the first
+// chunk of the first column is compared. Struct columns are compared
+// run-by-run over their valid entries instead of with ArrayEqual (see the
+// comment in the else branch).
+func (ps *ParquetIOTestSuite) roundTripTable(expected array.Table, storeSchema bool) {
+ var buf bytes.Buffer
+ var props pqarrow.ArrowWriterProperties
+ if storeSchema {
+ props = pqarrow.NewArrowWriterProperties(pqarrow.WithStoreSchema())
+ } else {
+ props = pqarrow.DefaultWriterProps()
+ }
+
+ ps.Require().NoError(pqarrow.WriteTable(expected, &buf, expected.NumRows(), nil, props, memory.DefaultAllocator))
+
+ reader := ps.createReader(buf.Bytes())
+ tbl := ps.readTable(reader)
+ defer tbl.Release()
+
+ ps.Equal(expected.NumCols(), tbl.NumCols())
+ ps.Equal(expected.NumRows(), tbl.NumRows())
+
+ exChunk := expected.Column(0).Data()
+ tblChunk := tbl.Column(0).Data()
+
+ ps.Equal(len(exChunk.Chunks()), len(tblChunk.Chunks()))
+ if exChunk.DataType().ID() != arrow.STRUCT {
+ ps.Truef(array.ArrayEqual(exChunk.Chunk(0), tblChunk.Chunk(0)), "expected: %s\ngot: %s", exChunk.Chunk(0), tblChunk.Chunk(0))
+ } else {
+ // current impl of ArrayEquals for structs doesn't correctly handle nulls in the parent
+ // with a non-nullable child when comparing. Since after the round trip, the data in the
+ // child will have the nulls, not the original data.
+ ex := exChunk.Chunk(0)
+ tb := tblChunk.Chunk(0)
+ ps.Equal(ex.NullN(), tb.NullN())
+ if ex.NullN() > 0 {
+ // compare only the bytes of each validity bitmap that cover the array length
+ ps.Equal(ex.NullBitmapBytes()[:int(bitutil.BytesForBits(int64(ex.Len())))], tb.NullBitmapBytes()[:int(bitutil.BytesForBits(int64(tb.Len())))])
+ }
+ ps.Equal(ex.Len(), tb.Len())
+ // only compare the non-null values
+ ps.NoErrorf(utils.VisitSetBitRuns(ex.NullBitmapBytes(), int64(ex.Data().Offset()), int64(ex.Len()), func(pos, length int64) error {
+ if !ps.True(array.ArraySliceEqual(ex, pos, pos+length, tb, pos, pos+length)) {
+ return errors.New("failed")
+ }
+ return nil
+ }), "expected: %s\ngot: %s", ex, tb)
+ }
+}
+
+// makeEmptyListsArray returns a list<float32> array of length size in which
+// every entry is an empty list: the offsets buffer is all zeros, the child
+// array has no elements and there is no validity bitmap.
+func makeEmptyListsArray(size int) array.Interface {
+	// size+1 int32 offsets, all zero
+	zeroOffsets := make([]byte, arrow.Int32Traits.BytesRequired(size+1))
+
+	elems := array.NewData(arrow.PrimitiveTypes.Float32, 0, []*memory.Buffer{nil, nil}, nil, 0, 0)
+	defer elems.Release()
+
+	listData := array.NewData(
+		arrow.ListOf(elems.DataType()), size,
+		[]*memory.Buffer{nil, memory.NewBufferBytes(zeroOffsets)},
+		[]*array.Data{elems}, 0, 0)
+	defer listData.Release()
+
+	return array.MakeFromData(listData)
+}
+
+// makeListArray wraps values in a list array with size entries.
+//
+// Even-indexed entries with i/2 < nullcount are left null. Entry 1 is valid
+// but always empty. Every other entry advances the running offset by
+// values.Len()/(size-nullcount-1) elements; the -1 accounts for the empty
+// entry at index 1. The final offset is pinned to values.Len(), so the last
+// entry ends at the end of values regardless of the integer division above.
+//
+// NOTE(review): size-nullcount-1 must be non-zero, otherwise the division
+// panics.
+func makeListArray(values array.Interface, size, nullcount int) array.Interface {
+ nonNullEntries := size - nullcount - 1
+ lengthPerEntry := values.Len() / nonNullEntries
+
+ // raw offsets buffer plus an int32 view over the same bytes
+ offsets := make([]byte, arrow.Int32Traits.BytesRequired(size+1))
+ offsetsArr := arrow.Int32Traits.CastFromBytes(offsets)
+
+ nullBitmap := make([]byte, int(bitutil.BytesForBits(int64(size))))
+
+ curOffset := 0
+ for i := 0; i < size; i++ {
+ offsetsArr[i] = int32(curOffset)
+ if !(((i % 2) == 0) && ((i / 2) < nullcount)) {
+ // non-null list (list with index 1 is always empty)
+ bitutil.SetBit(nullBitmap, i)
+ if i != 1 {
+ curOffset += lengthPerEntry
+ }
+ }
+ }
+ offsetsArr[size] = int32(values.Len())
+
+ return array.NewListData(array.NewData(arrow.ListOf(values.DataType()), size,
+ []*memory.Buffer{memory.NewBufferBytes(nullBitmap), memory.NewBufferBytes(offsets)},
+ []*array.Data{values.Data()}, nullcount, 0))
+}
+
+// prepareEmptyListsTable builds a table via makeSimpleTable (second argument
+// true) whose single column is one chunk of size empty float32 lists.
+func prepareEmptyListsTable(size int) array.Table {
+	emptyLists := makeEmptyListsArray(size)
+	column := array.NewChunked(emptyLists.DataType(), []array.Interface{emptyLists})
+	return makeSimpleTable(column, true)
+}
+
+// prepareListTable builds a single-column table of lists of dt for the
+// round-trip tests.
+//
+// size*size random values are generated (with nullCount nulls only when
+// nullableElems is set), sliced from index 5 onwards and grouped into size
+// lists by makeListArray (with nullCount null lists only when nullableLists is
+// set). The resulting column is then sliced from row 3 onwards, so both the
+// value array and the list array carry a non-zero offset.
+func prepareListTable(dt arrow.DataType, size int, nullableLists bool, nullableElems bool, nullCount int) array.Table {
+ nc := nullCount
+ if !nullableElems {
+ nc = 0
+ }
+ values := testutils.RandomNullable(dt, size*size, nc)
+ // this defer is bound to the original array; values is reassigned below
+ defer values.Release()
+ // also test that slice offsets are respected
+ values = array.NewSlice(values, 5, int64(values.Len()))
+ // and this one releases the slice
+ defer values.Release()
+
+ if !nullableLists {
+ nullCount = 0
+ }
+ lists := makeListArray(values, size, nullCount)
+ defer lists.Release()
+
+ chunked := array.NewChunked(lists.DataType(), []array.Interface{lists})
+ defer chunked.Release()
+
+ return makeSimpleTable(chunked.NewSlice(3, int64(size)), nullableLists)
+}
+
+// prepareListOfListTable builds a single-column table of lists of lists of dt.
+//
+// size*6 random values are grouped into size*3 inner lists, which are in turn
+// grouped into size outer lists, both via makeListArray. nullCount is applied
+// at each of the three levels only when the corresponding nullable flag
+// (nullableElems, nullableLists, nullableParentLists) is set; otherwise that
+// level gets no nulls.
+func prepareListOfListTable(dt arrow.DataType, size, nullCount int, nullableParentLists, nullableLists, nullableElems bool) array.Table {
+ // null count for the leaf values
+ nc := nullCount
+ if !nullableElems {
+ nc = 0
+ }
+
+ values := testutils.RandomNullable(dt, size*6, nc)
+ defer values.Release()
+
+ // null count for the inner lists
+ if nullableLists {
+ nc = nullCount
+ } else {
+ nc = 0
+ }
+
+ lists := makeListArray(values, size*3, nc)
+ defer lists.Release()
+
+ // null count for the outer lists
+ if !nullableParentLists {
+ nullCount = 0
+ }
+
+ parentLists := makeListArray(lists, size, nullCount)
+ defer parentLists.Release()
+
+ chunked := array.NewChunked(parentLists.DataType(), []array.Interface{parentLists})
+ defer chunked.Release()
+
+ return makeSimpleTable(chunked, nullableParentLists)
+}
+
+// TestSingleEmptyListsColumnReadWrite round-trips a table whose single column
+// holds smallSize empty lists and checks that the shape and the first chunk
+// survive unchanged.
+func (ps *ParquetIOTestSuite) TestSingleEmptyListsColumnReadWrite() {
+	expected := prepareEmptyListsTable(smallSize)
+	// the table is owned by this test, like the other round-trip tests
+	defer expected.Release()
+	buf := writeTableToBuffer(ps.T(), expected, smallSize, pqarrow.DefaultWriterProps())
+	defer buf.Release()
+
+	reader := ps.createReader(buf.Bytes())
+	tbl := ps.readTable(reader)
+	defer tbl.Release()
+
+	ps.EqualValues(expected.NumCols(), tbl.NumCols())
+	ps.EqualValues(expected.NumRows(), tbl.NumRows())
+
+	exChunk := expected.Column(0).Data()
+	tblChunk := tbl.Column(0).Data()
+
+	ps.Equal(len(exChunk.Chunks()), len(tblChunk.Chunks()))
+	ps.True(array.ArrayEqual(exChunk.Chunk(0), tblChunk.Chunk(0)))
+}
+
+// TestSingleColumnOptionalReadWrite writes, for every type in fullTypeList,
+// smallSize random values containing 10 nulls as a single optional column and
+// checks that reading the file back yields the same array.
+func (ps *ParquetIOTestSuite) TestSingleColumnOptionalReadWrite() {
+	for _, dt := range fullTypeList {
+		ps.Run(dt.Name(), func() {
+			values := testutils.RandomNullable(dt, smallSize, 10)
+			// release the generated array once the subtest is done
+			defer values.Release()
+			sc := ps.makeSimpleSchema(dt, parquet.Repetitions.Optional)
+			data := ps.writeColumn(sc, values)
+			ps.readAndCheckSingleColumnFile(data, values)
+		})
+	}
+}
+
+// TestSingleNullableListNullableColumnReadWrite round-trips, for every type in
+// fullTypeList, a table of nullable lists with nullable elements (null count
+// 10).
+func (ps *ParquetIOTestSuite) TestSingleNullableListNullableColumnReadWrite() {
+	for i := range fullTypeList {
+		typ := fullTypeList[i]
+		ps.Run(typ.Name(), func() {
+			tbl := prepareListTable(typ, smallSize, true, true, 10)
+			defer tbl.Release()
+			ps.roundTripTable(tbl, false)
+		})
+	}
+}
+
+// TestSingleRequiredListNullableColumnReadWrite round-trips, for every type in
+// fullTypeList, a table of non-nullable lists with nullable elements (null
+// count 10).
+func (ps *ParquetIOTestSuite) TestSingleRequiredListNullableColumnReadWrite() {
+	for i := range fullTypeList {
+		typ := fullTypeList[i]
+		ps.Run(typ.Name(), func() {
+			tbl := prepareListTable(typ, smallSize, false, true, 10)
+			defer tbl.Release()
+			ps.roundTripTable(tbl, false)
+		})
+	}
+}
+
+// TestSingleNullableListRequiredColumnReadWrite round-trips, for every type in
+// fullTypeList, a table of nullable lists (null count 10) with non-nullable
+// elements.
+func (ps *ParquetIOTestSuite) TestSingleNullableListRequiredColumnReadWrite() {
+	for i := range fullTypeList {
+		typ := fullTypeList[i]
+		ps.Run(typ.Name(), func() {
+			tbl := prepareListTable(typ, smallSize, true, false, 10)
+			defer tbl.Release()
+			ps.roundTripTable(tbl, false)
+		})
+	}
+}
+
+// TestSingleRequiredListRequiredColumnReadWrite round-trips, for every type in
+// fullTypeList, a table of non-nullable lists with non-nullable elements.
+func (ps *ParquetIOTestSuite) TestSingleRequiredListRequiredColumnReadWrite() {
+	for i := range fullTypeList {
+		typ := fullTypeList[i]
+		ps.Run(typ.Name(), func() {
+			tbl := prepareListTable(typ, smallSize, false, false, 0)
+			defer tbl.Release()
+			ps.roundTripTable(tbl, false)
+		})
+	}
+}
+
+// TestSingleNullableListRequiredListRequiredColumnReadWrite round-trips, for
+// every type in fullTypeList, a table of nullable outer lists (null count 2)
+// containing non-nullable inner lists of non-nullable elements.
+func (ps *ParquetIOTestSuite) TestSingleNullableListRequiredListRequiredColumnReadWrite() {
+	for i := range fullTypeList {
+		typ := fullTypeList[i]
+		ps.Run(typ.Name(), func() {
+			tbl := prepareListOfListTable(typ, smallSize, 2, true, false, false)
+			defer tbl.Release()
+			ps.roundTripTable(tbl, false)
+		})
+	}
+}
+
+// TestSimpleStruct round-trips a non-nullable struct column "links" with two
+// nullable int64 fields. Two rows are written: {Backward: null, Forward: 20}
+// and {Backward: 10, Forward: 40}.
+func (ps *ParquetIOTestSuite) TestSimpleStruct() {
+ links := arrow.StructOf(arrow.Field{Name: "Backward", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
+ arrow.Field{Name: "Forward", Type: arrow.PrimitiveTypes.Int64, Nullable: true})
+
+ bldr := array.NewStructBuilder(memory.DefaultAllocator, links)
+ defer bldr.Release()
+
+ backBldr := bldr.FieldBuilder(0).(*array.Int64Builder)
+ forwardBldr := bldr.FieldBuilder(1).(*array.Int64Builder)
+
+ // row 0
+ bldr.Append(true)
+ backBldr.AppendNull()
+ forwardBldr.Append(20)
+
+ // row 1
+ bldr.Append(true)
+ backBldr.Append(10)
+ forwardBldr.Append(40)
+
+ data := bldr.NewArray()
+ defer data.Release()
+
+ tbl := array.NewTable(arrow.NewSchema([]arrow.Field{{Name: "links", Type: links}}, nil),
+ []array.Column{*array.NewColumn(arrow.Field{Name: "links", Type: links}, array.NewChunked(links, []array.Interface{data}))}, -1)
+ defer tbl.Release()
+
+ ps.roundTripTable(tbl, false)
+}
+
+// TestSingleColumnNullableStruct round-trips a nullable struct column "links"
+// with one nullable int64 field. Two rows are built: a struct-level null
+// appended via AppendNull, followed by {Backward: 10}.
+func (ps *ParquetIOTestSuite) TestSingleColumnNullableStruct() {
+ links := arrow.StructOf(arrow.Field{Name: "Backward", Type: arrow.PrimitiveTypes.Int64, Nullable: true})
+ bldr := array.NewStructBuilder(memory.DefaultAllocator, links)
+ defer bldr.Release()
+
+ backBldr := bldr.FieldBuilder(0).(*array.Int64Builder)
+
+ bldr.AppendNull()
+ bldr.Append(true)
+ backBldr.Append(10)
+
+ data := bldr.NewArray()
+ defer data.Release()
+
+ tbl := array.NewTable(arrow.NewSchema([]arrow.Field{{Name: "links", Type: links, Nullable: true}}, nil),
+ []array.Column{*array.NewColumn(arrow.Field{Name: "links", Type: links, Nullable: true}, array.NewChunked(links, []array.Interface{data}))}, -1)
+ defer tbl.Release()
+
+ ps.roundTripTable(tbl, false)
+}
+
+// TestNestedRequiredFieldStruct round-trips a nullable struct column "root"
+// whose only field is a non-nullable int32. The struct array has 8 rows and is
+// assembled directly from a hand-written validity bitmap, so half of the
+// parent rows are null while the child itself has no nulls.
+func (ps *ParquetIOTestSuite) TestNestedRequiredFieldStruct() {
+ intField := arrow.Field{Name: "int_array", Type: arrow.PrimitiveTypes.Int32}
+ intBldr := array.NewInt32Builder(memory.DefaultAllocator)
+ defer intBldr.Release()
+ intBldr.AppendValues([]int32{0, 1, 2, 3, 4, 5, 7, 8}, nil)
+
+ intArr := intBldr.NewArray()
+ defer intArr.Release()
+
+ // 0xCC = 0b11001100: with least-significant-bit-first numbering, rows 2, 3,
+ // 6 and 7 are valid and the other four are null (null count 4 below)
+ validity := memory.NewBufferBytes([]byte{0xCC})
+ defer validity.Release()
+
+ structField := arrow.Field{Name: "root", Type: arrow.StructOf(intField), Nullable: true}
+ stData := array.NewStructData(array.NewData(structField.Type, 8, []*memory.Buffer{validity}, []*array.Data{intArr.Data()}, 4, 0))
+ defer stData.Release()
+
+ tbl := array.NewTable(arrow.NewSchema([]arrow.Field{structField}, nil),
+ []array.Column{*array.NewColumn(structField,
+ array.NewChunked(structField.Type, []array.Interface{stData}))}, -1)
+ defer tbl.Release()
+
+ ps.roundTripTable(tbl, false)
+}
+
+// TestNestedNullableField round-trips a nullable struct column "root" whose
+// only field is a nullable int32. Both levels carry nulls: the child has its
+// own validity pattern and the 8-row parent is assembled directly from a
+// hand-written validity bitmap.
+func (ps *ParquetIOTestSuite) TestNestedNullableField() {
+ intField := arrow.Field{Name: "int_array", Type: arrow.PrimitiveTypes.Int32, Nullable: true}
+ intBldr := array.NewInt32Builder(memory.DefaultAllocator)
+ defer intBldr.Release()
+ intBldr.AppendValues([]int32{0, 1, 2, 3, 4, 5, 7, 8}, []bool{true, false, true, false, true, true, false, true})
+
+ intArr := intBldr.NewArray()
+ defer intArr.Release()
+
+ // 0xCC = 0b11001100: with least-significant-bit-first numbering, rows 2, 3,
+ // 6 and 7 are valid and the other four are null (null count 4 below)
+ validity := memory.NewBufferBytes([]byte{0xCC})
+ defer validity.Release()
+
+ structField := arrow.Field{Name: "root", Type: arrow.StructOf(intField), Nullable: true}
+ stData := array.NewStructData(array.NewData(structField.Type, 8, []*memory.Buffer{validity}, []*array.Data{intArr.Data()}, 4, 0))
+ defer stData.Release()
+
+ tbl := array.NewTable(arrow.NewSchema([]arrow.Field{structField}, nil),
+ []array.Column{*array.NewColumn(structField,
+ array.NewChunked(structField.Type, []array.Interface{stData}))}, -1)
+ defer tbl.Release()
+
+ ps.roundTripTable(tbl, false)
+}
+
+// TestCanonicalNestedRoundTrip round-trips a two-row table with a nested
+// document schema:
+//
+//	DocID int64
+//	Links struct<Backward list<int64>, Forward list<int64>>   (nullable)
+//	Name  list<struct<Language list<struct<Code, Country?>>?, Url?>>
+//
+// Row 0 has DocID 10, Forward links [20 40 60] and three Name entries; row 1
+// has DocID 20, Backward links [10 30], Forward links [80] and one Name entry.
+func (ps *ParquetIOTestSuite) TestCanonicalNestedRoundTrip() {
+	docIdField := arrow.Field{Name: "DocID", Type: arrow.PrimitiveTypes.Int64}
+	linksField := arrow.Field{Name: "Links", Type: arrow.StructOf(
+		arrow.Field{Name: "Backward", Type: arrow.ListOf(arrow.PrimitiveTypes.Int64)},
+		arrow.Field{Name: "Forward", Type: arrow.ListOf(arrow.PrimitiveTypes.Int64)},
+	), Nullable: true}
+
+	nameStruct := arrow.StructOf(
+		arrow.Field{Name: "Language", Nullable: true, Type: arrow.ListOf(
+			arrow.StructOf(arrow.Field{Name: "Code", Type: arrow.BinaryTypes.String},
+				arrow.Field{Name: "Country", Type: arrow.BinaryTypes.String, Nullable: true}))},
+		arrow.Field{Name: "Url", Type: arrow.BinaryTypes.String, Nullable: true})
+
+	nameField := arrow.Field{Name: "Name", Type: arrow.ListOf(nameStruct)}
+	sc := arrow.NewSchema([]arrow.Field{docIdField, linksField, nameField}, nil)
+
+	// DocID column
+	docBldr := array.NewInt64Builder(memory.DefaultAllocator)
+	defer docBldr.Release()
+
+	docBldr.AppendValues([]int64{10, 20}, nil)
+	docIdArr := docBldr.NewArray()
+	defer docIdArr.Release()
+
+	// Links column
+	linkBldr := array.NewStructBuilder(memory.DefaultAllocator, linksField.Type.(*arrow.StructType))
+	defer linkBldr.Release()
+
+	backBldr := linkBldr.FieldBuilder(0).(*array.ListBuilder)
+	forwBldr := linkBldr.FieldBuilder(1).(*array.ListBuilder)
+
+	backVb := backBldr.ValueBuilder().(*array.Int64Builder)
+	forwVb := forwBldr.ValueBuilder().(*array.Int64Builder)
+
+	linkBldr.Append(true)
+	backBldr.Append(true)
+	forwBldr.Append(true)
+	forwVb.AppendValues([]int64{20, 40, 60}, nil)
+
+	linkBldr.Append(true)
+	backBldr.Append(true)
+	backVb.AppendValues([]int64{10, 30}, nil)
+	forwBldr.Append(true)
+	forwVb.AppendValues([]int64{80}, nil)
+
+	linkArr := linkBldr.NewArray()
+	defer linkArr.Release()
+
+	// Name column
+	nameBldr := array.NewBuilder(memory.DefaultAllocator, nameField.Type).(*array.ListBuilder)
+	defer nameBldr.Release()
+	nameStructBldr := nameBldr.ValueBuilder().(*array.StructBuilder)
+	langListBldr := nameStructBldr.FieldBuilder(0).(*array.ListBuilder)
+	urlBldr := nameStructBldr.FieldBuilder(1).(*array.StringBuilder)
+	langStructBldr := langListBldr.ValueBuilder().(*array.StructBuilder)
+	codeBldr := langStructBldr.FieldBuilder(0).(*array.StringBuilder)
+	countryBldr := langStructBldr.FieldBuilder(1).(*array.StringBuilder)
+
+	// row 0: three Name entries
+	nameBldr.Append(true)
+	nameStructBldr.Append(true)
+	langListBldr.Append(true)
+	langStructBldr.Append(true)
+	codeBldr.Append("en_us")
+	countryBldr.Append("us")
+	langStructBldr.Append(true)
+	codeBldr.Append("en_us")
+	countryBldr.AppendNull()
+	urlBldr.Append("http://A")
+
+	nameStructBldr.Append(true)
+	langListBldr.AppendNull()
+	urlBldr.Append("http://B")
+
+	nameStructBldr.Append(true)
+	langListBldr.Append(true)
+	langStructBldr.Append(true)
+	codeBldr.Append("en-gb")
+	countryBldr.Append("gb")
+	urlBldr.AppendNull()
+
+	// row 1: a single Name entry
+	nameBldr.Append(true)
+	nameStructBldr.Append(true)
+	langListBldr.AppendNull()
+	urlBldr.Append("http://C")
+
+	nameArr := nameBldr.NewArray()
+	defer nameArr.Release()
+
+	expected := array.NewTable(sc, []array.Column{
+		*array.NewColumn(docIdField, array.NewChunked(docIdField.Type, []array.Interface{docIdArr})),
+		*array.NewColumn(linksField, array.NewChunked(linksField.Type, []array.Interface{linkArr})),
+		*array.NewColumn(nameField, array.NewChunked(nameField.Type, []array.Interface{nameArr})),
+	}, 2)
+	defer expected.Release()
+
+	ps.roundTripTable(expected, false)
+}
+
+// TestFixedSizeList round-trips a nullable column of three fixed-size lists,
+// each holding 3 int16 values (1..9 in total), with the arrow schema stored in
+// the file.
+func (ps *ParquetIOTestSuite) TestFixedSizeList() {
+	bldr := array.NewFixedSizeListBuilder(memory.DefaultAllocator, 3, arrow.PrimitiveTypes.Int16)
+	defer bldr.Release()
+
+	vb := bldr.ValueBuilder().(*array.Int16Builder)
+
+	bldr.AppendValues([]bool{true, true, true})
+	vb.AppendValues([]int16{1, 2, 3, 4, 5, 6, 7, 8, 9}, nil)
+
+	data := bldr.NewArray()
+	defer data.Release()
+
+	field := arrow.Field{Name: "root", Type: data.DataType(), Nullable: true}
+	expected := array.NewTable(arrow.NewSchema([]arrow.Field{field}, nil),
+		[]array.Column{*array.NewColumn(field, array.NewChunked(field.Type, []array.Interface{data}))}, -1)
+	defer expected.Release()
+
+	ps.roundTripTable(expected, true)
+}
+
+// TestParquetArrowIO is the go test entry point that runs every test method
+// of ParquetIOTestSuite.
+func TestParquetArrowIO(t *testing.T) {
+ suite.Run(t, new(ParquetIOTestSuite))
+}
+
+// TestBufferedRecWrite writes one record in two halves through
+// FileWriter.WriteBuffered and checks that the resulting file consists of a
+// single row group holding all SIZELEN rows and can be read back as a table
+// with the same number of rows.
+func TestBufferedRecWrite(t *testing.T) {
+	sc := arrow.NewSchema([]arrow.Field{
+		{Name: "f32", Type: arrow.PrimitiveTypes.Float32, Nullable: true},
+		{Name: "i32", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
+		{Name: "struct_i64_f64", Type: arrow.StructOf(
+			arrow.Field{Name: "i64", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
+			arrow.Field{Name: "f64", Type: arrow.PrimitiveTypes.Float64, Nullable: true})},
+	}, nil)
+
+	cols := []array.Interface{
+		testutils.RandomNullable(sc.Field(0).Type, SIZELEN, SIZELEN/5),
+		testutils.RandomNullable(sc.Field(1).Type, SIZELEN, SIZELEN/5),
+		array.NewStructData(array.NewData(sc.Field(2).Type, SIZELEN,
+			[]*memory.Buffer{nil, nil},
+			[]*array.Data{testutils.RandomNullable(arrow.PrimitiveTypes.Int64, SIZELEN, 0).Data(), testutils.RandomNullable(arrow.PrimitiveTypes.Float64, SIZELEN, 0).Data()}, 0, 0)),
+	}
+
+	rec := array.NewRecord(sc, cols, SIZELEN)
+	defer rec.Release()
+
+	var (
+		buf bytes.Buffer
+	)
+
+	wr, err := pqarrow.NewFileWriter(sc, &buf,
+		parquet.NewWriterProperties(parquet.WithCompression(compress.Codecs.Snappy), parquet.WithDictionaryDefault(false), parquet.WithDataPageSize(100*1024)),
+		pqarrow.DefaultWriterProps())
+	require.NoError(t, err)
+
+	p1 := rec.NewSlice(0, SIZELEN/2)
+	defer p1.Release()
+	require.NoError(t, wr.WriteBuffered(p1))
+
+	p2 := rec.NewSlice(SIZELEN/2, SIZELEN)
+	defer p2.Release()
+	require.NoError(t, wr.WriteBuffered(p2))
+
+	// closing flushes the buffered row group and writes the footer; a failure
+	// here would make every later check meaningless
+	require.NoError(t, wr.Close())
+
+	rdr, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()))
+	// rdr is dereferenced right below, so this must be fatal
+	require.NoError(t, err)
+
+	assert.EqualValues(t, 1, rdr.NumRowGroups())
+	assert.EqualValues(t, SIZELEN, rdr.NumRows())
+	rdr.Close()
+
+	tbl, err := pqarrow.ReadTable(context.Background(), bytes.NewReader(buf.Bytes()), nil, pqarrow.ArrowReadProperties{}, nil)
+	require.NoError(t, err)
+	defer tbl.Release()
+
+	assert.EqualValues(t, SIZELEN, tbl.NumRows())
+}
+
+// TestArrowMapTypeRoundTrip round-trips a nullable map<string, int32> column
+// "mapped" with the arrow schema stored in the file. Four rows are written: a
+// map with 4 entries, a map with 3 entries, a null map, and a map with 3
+// entries whose first item value is null.
+func (ps *ParquetIOTestSuite) TestArrowMapTypeRoundTrip() {
+ bldr := array.NewMapBuilder(memory.DefaultAllocator, arrow.BinaryTypes.String, arrow.PrimitiveTypes.Int32, false)
+ defer bldr.Release()
+
+ kb := bldr.KeyBuilder().(*array.StringBuilder)
+ ib := bldr.ItemBuilder().(*array.Int32Builder)
+
+ bldr.Append(true)
+ kb.AppendValues([]string{"Fee", "Fi", "Fo", "Fum"}, nil)
+ ib.AppendValues([]int32{1, 2, 3, 4}, nil)
+
+ bldr.Append(true)
+ kb.AppendValues([]string{"Fee", "Fi", "Fo"}, nil)
+ ib.AppendValues([]int32{5, 4, 3}, nil)
+
+ bldr.AppendNull()
+
+ bldr.Append(true)
+ kb.AppendValues([]string{"Fo", "Fi", "Fee"}, nil)
+ // the first item is flagged invalid, so its -1 is only a placeholder
+ ib.AppendValues([]int32{-1, 2, 3}, []bool{false, true, true})
+
+ arr := bldr.NewArray()
+ defer arr.Release()
+
+ fld := arrow.Field{Name: "mapped", Type: arr.DataType(), Nullable: true}
+ tbl := array.NewTable(arrow.NewSchema([]arrow.Field{fld}, nil),
+ []array.Column{*array.NewColumn(fld, array.NewChunked(arr.DataType(), []array.Interface{arr}))}, -1)
+ defer tbl.Release()
+
+ ps.roundTripTable(tbl, true)
+}
diff --git a/go/parquet/pqarrow/file_reader.go b/go/parquet/pqarrow/file_reader.go
new file mode 100644
index 0000000..e479063
--- /dev/null
+++ b/go/parquet/pqarrow/file_reader.go
@@ -0,0 +1,686 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow
+
+import (
+ "context"
+ "io"
+ "sync"
+ "sync/atomic"
+
+ "github.com/apache/arrow/go/arrow"
+ "github.com/apache/arrow/go/arrow/array"
+ "github.com/apache/arrow/go/arrow/arrio"
+ "github.com/apache/arrow/go/arrow/memory"
+ "github.com/apache/arrow/go/parquet"
+ "github.com/apache/arrow/go/parquet/file"
+ "github.com/apache/arrow/go/parquet/schema"
+ "golang.org/x/xerrors"
+)
+
+// itrFactory builds the columnIterator for a leaf column index over the
+// given parquet file reader.
+type itrFactory func(int, *file.Reader) *columnIterator
+
+// readerCtx bundles the state getReader needs while it builds column readers.
+// It is stored in a context.Context under rdrCtxKey and fetched again with
+// readerCtxFromContext.
+type readerCtx struct {
+ // rdr is the parquet file reader that columns are pulled from.
+ rdr *file.Reader
+ // mem is the allocator passed on to the readers being constructed.
+ mem memory.Allocator
+ // colFactory creates the column iterator for each leaf reader.
+ colFactory itrFactory
+ // filterLeaves, when true, makes getReader skip leaves that are not in includedLeaves.
+ filterLeaves bool
+ // includedLeaves is keyed by leaf column index; see includesLeaf.
+ includedLeaves map[int]bool
+}
+
+// includesLeaf reports whether idx is present as a key in includedLeaves.
+// Only key presence is checked; the stored bool value is ignored.
+func (r readerCtx) includesLeaf(idx int) bool {
+	if _, present := r.includedLeaves[idx]; present {
+		return true
+	}
+	return false
+}
+
+// ReadTable reads an entire parquet file into an arrow table in one call.
+//
+// It opens a file.Reader over r using props, wraps it in a pqarrow FileReader
+// configured with arrProps and mem, and returns the result of that reader's
+// ReadTable. The arrow schema is derived from the parquet schema (nested
+// columns, lists, etc. included) in the same fashion as FromParquetSchema.
+// Use this when the whole file is wanted as a table rather than piecemeal.
+func ReadTable(ctx context.Context, r parquet.ReaderAtSeeker, props *parquet.ReaderProperties, arrProps ArrowReadProperties, mem memory.Allocator) (array.Table, error) {
+	pqFile, err := file.NewParquetReader(r, file.WithReadProps(props))
+	if err != nil {
+		return nil, err
+	}
+
+	arrowReader, err := NewFileReader(pqFile, arrProps, mem)
+	if err != nil {
+		return nil, err
+	}
+
+	return arrowReader.ReadTable(ctx)
+}
+
+// FileReader is the base object for reading a parquet file into arrow object
+// types.
+//
+// It provides utility functions for reading record batches, a table, subsets of
+// columns / rowgroups, and so on.
+type FileReader struct {
+ // mem is the allocator supplied to NewFileReader.
+ mem memory.Allocator
+ // rdr is the underlying parquet file reader supplied to NewFileReader.
+ rdr *file.Reader
+
+ // Props holds the arrow read properties supplied to NewFileReader.
+ Props ArrowReadProperties
+ // Manifest is the schema manifest NewFileReader built from the file metadata.
+ Manifest *SchemaManifest
+}
+
+// NewFileReader wraps an already opened parquet file reader so its contents
+// can be converted to Arrow objects.
+//
+// The only failure mode is an error while building the schema manifest from
+// the parquet file metadata.
+func NewFileReader(rdr *file.Reader, props ArrowReadProperties, mem memory.Allocator) (*FileReader, error) {
+	meta := rdr.MetaData()
+	manifest, err := NewSchemaManifest(meta.Schema, meta.KeyValueMetadata(), &props)
+	if err != nil {
+		return nil, err
+	}
+
+	fr := &FileReader{mem: mem, rdr: rdr, Props: props, Manifest: manifest}
+	return fr, nil
+}
+
+// Schema converts the parquet schema and key/value metadata of the underlying
+// file into the equivalent arrow schema, honoring the reader's Props.
+func (fr *FileReader) Schema() (*arrow.Schema, error) {
+	meta := fr.rdr.MetaData()
+	return FromParquet(meta.Schema, &fr.Props, meta.KeyValueMetadata())
+}
+
+// colReaderImpl is the set of operations embedded by ColumnReader.
+// ColumnReader.NextBatch drives an implementation by calling LoadBatch and
+// then BuildArray with the same size.
+type colReaderImpl interface {
+ LoadBatch(nrecs int64) error
+ BuildArray(boundedLen int64) (*array.Chunked, error)
+ GetDefLevels() ([]int16, error)
+ GetRepLevels() ([]int16, error)
+ Field() *arrow.Field
+ IsOrHasRepeatedChild() bool
+ Retain()
+ Release()
+}
+
+// ColumnReader is used for reading batches of data from a specific column
+// across multiple row groups to return a chunked arrow array.
+//
+// It embeds colReaderImpl, so every method of that interface is available
+// directly on a ColumnReader.
+type ColumnReader struct {
+ colReaderImpl
+}
+
+// NextBatch loads `size` values, potentially spanning several row groups, and
+// returns them as a chunked array. A failure while loading is returned as-is
+// with a nil array.
+func (c *ColumnReader) NextBatch(size int64) (*array.Chunked, error) {
+	err := c.LoadBatch(size)
+	if err != nil {
+		return nil, err
+	}
+	return c.BuildArray(size)
+}
+
+// rdrCtxKey is the context key under which a readerCtx is stored.
+type rdrCtxKey struct{}
+
+// readerCtxFromContext returns the readerCtx stored in ctx under rdrCtxKey.
+// It panics when the context carries no such value.
+func readerCtxFromContext(ctx context.Context) readerCtx {
+	if stored := ctx.Value(rdrCtxKey{}); stored != nil {
+		return stored.(readerCtx)
+	}
+	panic("no readerctx")
+}
+
+// ParquetReader returns the underlying parquet file reader that this
+// FileReader was constructed with.
+func (fr *FileReader) ParquetReader() *file.Reader { return fr.rdr }
+
+// GetColumn returns a reader for pulling the data of leaf column index i
+// across all row groups in the file.
+//
+// It delegates to getColumnReader, which returns an error when i is outside
+// the range of columns in the file schema.
+func (fr *FileReader) GetColumn(ctx context.Context, i int) (*ColumnReader, error) {
+ return fr.getColumnReader(ctx, i, fr.allRowGroupFactory())
+}
+
+// rowGroupFactory returns an itrFactory whose iterators are restricted to the
+// given row group indexes. Every iterator it creates starts from the same
+// rowGroups slice and takes its schema from the reader's metadata.
+func rowGroupFactory(rowGroups []int) itrFactory {
+ return func(i int, rdr *file.Reader) *columnIterator {
+ return &columnIterator{
+ index: i,
+ rdr: rdr,
+ schema: rdr.MetaData().Schema,
+ rowGroups: rowGroups,
+ }
+ }
+}
+
+// allRowGroupFactory returns an itrFactory covering every row group in the
+// file, in index order.
+func (fr *FileReader) allRowGroupFactory() itrFactory {
+	total := fr.rdr.NumRowGroups()
+	all := make([]int, 0, total)
+	for rg := 0; rg < total; rg++ {
+		all = append(all, rg)
+	}
+	return rowGroupFactory(all)
+}
+
+// GetFieldReader returns a reader for the entire Field of index i which could potentially include reading
+// multiple columns from the underlying parquet file if that field is a nested field.
+//
+// IncludedLeaves and RowGroups are used to specify precisely which leaf indexes and row groups to read a subset of.
+//
+// The returned reader is nil (with a nil error) when none of the field's leaves
+// are in includedLeaves. The index i is used to index Manifest.Fields without a
+// bounds check here, so an out-of-range value panics.
+func (fr *FileReader) GetFieldReader(ctx context.Context, i int, includedLeaves map[int]bool, rowGroups []int) (*ColumnReader, error) {
+ // getReader fetches this readerCtx back out of the context.
+ ctx = context.WithValue(ctx, rdrCtxKey{}, readerCtx{
+ rdr: fr.rdr,
+ mem: fr.mem,
+ colFactory: rowGroupFactory(rowGroups),
+ filterLeaves: true,
+ includedLeaves: includedLeaves,
+ })
+ return fr.getReader(ctx, &fr.Manifest.Fields[i], *fr.Manifest.Fields[i].Field)
+}
+
+// GetFieldReaders is for retrieving readers for multiple fields at one time for only the list
+// of column indexes and rowgroups requested. It returns a slice of the readers and the corresponding
+// arrow.Schema for those columns.
+func (fr *FileReader) GetFieldReaders(ctx context.Context, colIndices, rowGroups []int) ([]*ColumnReader, *arrow.Schema, error) {
+ fieldIndices, err := fr.Manifest.GetFieldIndices(colIndices)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ includedLeaves := make(map[int]bool)
+ for _, col := range colIndices {
+ includedLeaves[col] = true
+ }
+
+ out := make([]*ColumnReader, len(fieldIndices))
+ outFields := make([]arrow.Field, len(fieldIndices))
+ for idx, fidx := range fieldIndices {
+ rdr, err := fr.GetFieldReader(ctx, fidx, includedLeaves, rowGroups)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ outFields[idx] = *rdr.Field()
+ out[idx] = rdr
+ }
+
+ return out, arrow.NewSchema(outFields, fr.Manifest.SchemaMeta), nil
+}
+
+// RowGroup creates a reader that will *only* read from the requested row group.
+// The index is stored as given; it is not validated here.
+func (fr *FileReader) RowGroup(idx int) RowGroupReader {
+ return RowGroupReader{fr, idx}
+}
+
+// ReadColumn reads data to create a chunked array only from the requested row groups.
+func (fr *FileReader) ReadColumn(rowGroups []int, rdr *ColumnReader) (*array.Chunked, error) {
+ recs := int64(0)
+ for _, rg := range rowGroups {
+ recs += fr.rdr.MetaData().RowGroups[rg].GetNumRows()
+ }
+ return rdr.NextBatch(recs)
+}
+
+// ReadTable reads the entire file into an array.Table
+func (fr *FileReader) ReadTable(ctx context.Context) (array.Table, error) {
+ var (
+ cols = []int{}
+ rgs = []int{}
+ )
+ for i := 0; i < fr.rdr.MetaData().Schema.NumColumns(); i++ {
+ cols = append(cols, i)
+ }
+ for i := 0; i < fr.rdr.NumRowGroups(); i++ {
+ rgs = append(rgs, i)
+ }
+ return fr.ReadRowGroups(ctx, cols, rgs)
+}
+
+// checkCols validates that every index lies within the leaf columns of the
+// file schema and returns an error describing the first one that does not.
+func (fr *FileReader) checkCols(indices []int) (err error) {
+	for _, col := range indices {
+		if col >= 0 && col < fr.rdr.MetaData().Schema.NumColumns() {
+			continue
+		}
+		return xerrors.Errorf("invalid column index specified %d out of %d", col, fr.rdr.MetaData().Schema.NumColumns())
+	}
+	return nil
+}
+
+// checkRowGroups validates that every index refers to a row group present in
+// the file and returns an error describing the first one that does not.
+func (fr *FileReader) checkRowGroups(indices []int) (err error) {
+	for _, rg := range indices {
+		if rg >= 0 && rg < fr.rdr.NumRowGroups() {
+			continue
+		}
+		return xerrors.Errorf("invalid row group specified: %d, file only has %d row groups", rg, fr.rdr.NumRowGroups())
+	}
+	return nil
+}
+
+// readerInfo is the unit of work handed to a ReadRowGroups worker: a column
+// reader plus the position of its result in the output column list.
+type readerInfo struct {
+ rdr *ColumnReader
+ idx int
+}
+
+// resultPair is what a ReadRowGroups worker sends back: the output position,
+// the data read for it, and any error returned by ReadColumn.
+type resultPair struct {
+ idx int
+ data *array.Chunked
+ err error
+}
+
+// ReadRowGroups is for generating an array.Table from the file but filtering to only read the requested
+// columns and row groups rather than the entire file which ReadTable does.
+//
+// indices are leaf column indexes and rowGroups are row group indexes; both are
+// validated before anything is read. When Props.Parallel is set, one goroutine
+// is started per field reader, otherwise a single worker reads the fields one
+// after another. The first read error cancels the remaining work and is
+// returned. If ctx is cancelled before every column was read, ctx.Err() is
+// returned instead of a partially populated table.
+func (fr *FileReader) ReadRowGroups(ctx context.Context, indices, rowGroups []int) (array.Table, error) {
+	if err := fr.checkRowGroups(rowGroups); err != nil {
+		return nil, err
+	}
+	if err := fr.checkCols(indices); err != nil {
+		return nil, err
+	}
+
+	// pre-buffer stuff?
+
+	readers, sc, err := fr.GetFieldReaders(ctx, indices, rowGroups)
+	if err != nil {
+		return nil, err
+	}
+
+	var (
+		np      = 1
+		wg      sync.WaitGroup
+		ch      = make(chan readerInfo, len(readers))
+		results = make(chan resultPair, 2)
+	)
+
+	if fr.Props.Parallel {
+		np = len(readers)
+	}
+
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	wg.Add(np)
+	for i := 0; i < np; i++ {
+		go func() {
+			defer wg.Done()
+			for {
+				select {
+				case r, ok := <-ch:
+					if !ok {
+						return
+					}
+
+					chnked, err := fr.ReadColumn(rowGroups, r.rdr)
+					results <- resultPair{r.idx, chnked, err}
+				case <-ctx.Done():
+					return
+				}
+			}
+		}()
+	}
+
+	// results is closed once every worker has exited so the loops below end.
+	go func() {
+		wg.Wait()
+		close(results)
+	}()
+
+	for idx, r := range readers {
+		ch <- readerInfo{r, idx}
+	}
+	close(ch)
+
+	columns := make([]array.Column, len(sc.Fields()))
+	for data := range results {
+		// Check the error first: a failed read carries no chunk, and
+		// releasing a nil chunk would panic.
+		if data.err != nil {
+			err = data.err
+			if data.data != nil {
+				data.data.Release()
+			}
+			cancel()
+			break
+		}
+		defer data.data.Release()
+
+		columns[data.idx] = *array.NewColumn(sc.Field(data.idx), data.data)
+	}
+
+	if err != nil {
+		// Drain what the workers still produce so none of them stays blocked
+		// on the results channel, releasing any chunk that was read.
+		for data := range results {
+			if data.data != nil {
+				data.data.Release()
+			}
+		}
+		return nil, err
+	}
+
+	// cancel() is only called above on the error path, so a non-nil ctx.Err()
+	// here means the caller's context ended and workers may have stopped
+	// before reading every column.
+	if ctxErr := ctx.Err(); ctxErr != nil {
+		return nil, ctxErr
+	}
+
+	var nrows int
+	if len(columns) > 0 {
+		nrows = columns[0].Len()
+	}
+	return array.NewTable(sc, columns, int64(nrows)), nil
+}
+
+// getColumnReader builds an unfiltered reader for index i, using colFactory to
+// create the column iterators. It returns an error when i is negative or not
+// smaller than the number of columns in the file schema.
+//
+// NOTE(review): i is validated against the schema's column count but is then
+// used to index Manifest.Fields; confirm the two always have the same length
+// for schemas with nested fields.
+func (fr *FileReader) getColumnReader(ctx context.Context, i int, colFactory itrFactory) (*ColumnReader, error) {
+ if i < 0 || i >= fr.rdr.MetaData().Schema.NumColumns() {
+ return nil, xerrors.Errorf("invalid column index chosen %d, there are only %d columns", i, fr.rdr.MetaData().Schema.NumColumns())
+ }
+
+ // filterLeaves is false, so getReader keeps every leaf of the field.
+ ctx = context.WithValue(ctx, rdrCtxKey{}, readerCtx{
+ rdr: fr.rdr,
+ mem: fr.mem,
+ colFactory: colFactory,
+ filterLeaves: false,
+ })
+
+ return fr.getReader(ctx, &fr.Manifest.Fields[i], *fr.Manifest.Fields[i].Field)
+}
+
+// RecordReader is a Record Batch Reader that meets the interfaces for both
+// array.RecordReader and arrio.Reader to allow easy progressive reading
+// of record batches from the parquet file. Ideal for streaming.
+//
+// Instances are obtained from FileReader.GetRecordReader.
+type RecordReader interface {
+ array.RecordReader
+ arrio.Reader
+}
+
+// GetRecordReader returns a record reader that reads only the requested column indexes and row groups.
+//
+// For both cases, if you pass nil for column indexes or rowgroups it will default to reading all of them.
+//
+// An error is returned for an out-of-range column or row group index, when the
+// field readers cannot be built, or when no field reader matches colIndices.
+// The batch size and parallelism of the returned reader come from fr.Props.
+func (fr *FileReader) GetRecordReader(ctx context.Context, colIndices, rowGroups []int) (RecordReader, error) {
+ // validating a nil slice is a no-op, so the check can run before defaulting
+ if err := fr.checkRowGroups(rowGroups); err != nil {
+ return nil, err
+ }
+
+ if rowGroups == nil {
+ rowGroups = make([]int, fr.rdr.NumRowGroups())
+ for idx := range rowGroups {
+ rowGroups[idx] = idx
+ }
+ }
+
+ if err := fr.checkCols(colIndices); err != nil {
+ return nil, err
+ }
+
+ if colIndices == nil {
+ colIndices = make([]int, fr.rdr.MetaData().Schema.NumColumns())
+ for idx := range colIndices {
+ colIndices[idx] = idx
+ }
+ }
+
+ // pre-buffer stuff?
+
+ readers, sc, err := fr.GetFieldReaders(ctx, colIndices, rowGroups)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(readers) == 0 {
+ return nil, xerrors.New("no leaf column readers matched col indices")
+ }
+
+ // total number of rows across the selected row groups
+ nrows := int64(0)
+ for _, rg := range rowGroups {
+ nrows += fr.rdr.MetaData().RowGroup(rg).NumRows()
+ }
+
+ // the reader starts with a reference count of one, owned by the caller
+ return &recordReader{
+ numRows: nrows,
+ batchSize: fr.Props.BatchSize,
+ parallel: fr.Props.Parallel,
+ sc: sc,
+ fieldReaders: readers,
+ refCount: 1,
+ }, nil
+}
+
+// getReader recursively builds the ColumnReader for a schema field.
+//
+// The readerCtx stored in ctx supplies the file reader, allocator, column
+// iterator factory and the optional leaf filter. A field without children
+// becomes a leaf reader. Struct, list, fixed-size list and map fields are
+// built from the readers of their children. The result is (nil, nil) when
+// leaf filtering removes the leaf or every child of a nested field. It is
+// also (nil, nil) when a field has children but its arrow type is none of the
+// cases in the switch below.
+func (fr *FileReader) getReader(ctx context.Context, field *SchemaField, arrowField arrow.Field) (out *ColumnReader, err error) {
+ rctx := readerCtxFromContext(ctx)
+ if len(field.Children) == 0 {
+ if !field.IsLeaf() {
+ return nil, xerrors.New("parquet non-leaf node has no children")
+ }
+ // a filtered-out leaf is reported as "no reader", not as an error
+ if rctx.filterLeaves && !rctx.includesLeaf(field.ColIndex) {
+ return nil, nil
+ }
+
+ out, err = newLeafReader(&rctx, field.Field, rctx.colFactory(field.ColIndex, rctx.rdr), field.LevelInfo)
+ return
+ }
+
+ switch arrowField.Type.ID() {
+ case arrow.EXTENSION:
+ return nil, xerrors.New("extension type not implemented")
+ case arrow.STRUCT:
+ childReaders := make([]*ColumnReader, 0)
+ childFields := make([]arrow.Field, 0)
+ for _, child := range field.Children {
+ reader, err := fr.getReader(ctx, &child, *child.Field)
+ if err != nil {
+ return nil, err
+ }
+ // child was dropped by the leaf filter
+ if reader == nil {
+ continue
+ }
+ childFields = append(childFields, *child.Field)
+ childReaders = append(childReaders, reader)
+ }
+ if len(childFields) == 0 {
+ return nil, nil
+ }
+ // rebuild the struct type from the children that survived filtering
+ filtered := arrow.Field{Name: arrowField.Name, Nullable: arrowField.Nullable,
+ Metadata: arrowField.Metadata, Type: arrow.StructOf(childFields...)}
+ out = newStructReader(&rctx, &filtered, field.LevelInfo, childReaders)
+ case arrow.LIST, arrow.FIXED_SIZE_LIST, arrow.MAP:
+ child := field.Children[0]
+ childReader, err := fr.getReader(ctx, &child, *child.Field)
+ if err != nil {
+ return nil, err
+ }
+ if childReader == nil {
+ return nil, nil
+ }
+ // NOTE(review): presumably the list reader constructors retain childReader; confirm.
+ defer childReader.Release()
+
+ switch arrowField.Type.(type) {
+ case *arrow.MapType:
+ // when the map's child does not have exactly two children, read it as a list of the child's type
+ if len(child.Children) != 2 {
+ arrowField.Type = arrow.ListOf(childReader.Field().Type)
+ }
+ out = newListReader(&rctx, &arrowField, field.LevelInfo, childReader)
+ case *arrow.ListType:
+ out = newListReader(&rctx, &arrowField, field.LevelInfo, childReader)
+ case *arrow.FixedSizeListType:
+ out = newFixedSizeListReader(&rctx, &arrowField, field.LevelInfo, childReader)
+ default:
+ return nil, xerrors.Errorf("unknown list type: %s", field.Field.String())
+ }
+ }
+ return
+}
+
+// RowGroupReader is a reader for getting data only from a single row group of the file
+// rather than having to repeatedly pass the index to functions on the reader.
+type RowGroupReader struct {
+ // impl is the FileReader that performs the actual reads.
+ impl *FileReader
+ // idx is the row group index this reader is bound to.
+ idx int
+}
+
+// ReadTable provides an array.Table consisting only of the columns requested for this rowgroup.
+// It delegates to FileReader.ReadRowGroups with this reader's row group as the only row group.
+func (rgr RowGroupReader) ReadTable(ctx context.Context, colIndices []int) (array.Table, error) {
+ return rgr.impl.ReadRowGroups(ctx, colIndices, []int{rgr.idx})
+}
+
+// Column creates a reader for just the requested column chunk in only this row group.
+// The column index is stored as given; it is validated when Read is called.
+func (rgr RowGroupReader) Column(idx int) ColumnChunkReader {
+ return ColumnChunkReader{rgr.impl, idx, rgr.idx}
+}
+
+// ColumnChunkReader is a reader that reads only a single column chunk from a single
+// column in a single row group
+type ColumnChunkReader struct {
+ // impl is the FileReader that performs the actual reads.
+ impl *FileReader
+ // idx is the column index to read.
+ idx int
+ // rowGroup is the row group index to read from.
+ rowGroup int
+}
+
+// Read returns the data of this reader's column within its row group as a
+// chunked array. An error is returned when the column index is out of range
+// or when reading the column fails.
+func (ccr ColumnChunkReader) Read(ctx context.Context) (*array.Chunked, error) {
+ rdr, err := ccr.impl.getColumnReader(ctx, ccr.idx, rowGroupFactory([]int{ccr.rowGroup}))
+ if err != nil {
+ return nil, err
+ }
+ return ccr.impl.ReadColumn([]int{ccr.rowGroup}, rdr)
+}
+
+// columnIterator walks the page readers of one column across a list of row
+// groups, one row group per NextChunk call.
+type columnIterator struct {
+ // index is the column index handed to GetColumnPageReader and schema.Column.
+ index int
+ rdr *file.Reader
+ schema *schema.Schema
+ // rowGroups holds the row group indexes still to be visited; NextChunk consumes it from the front.
+ rowGroups []int
+}
+
+// NextChunk returns the page reader of this column for the next pending row
+// group and removes that row group from the pending list. It returns
+// (nil, nil) once every row group has been visited.
+func (c *columnIterator) NextChunk() (file.PageReader, error) {
+	if len(c.rowGroups) == 0 {
+		return nil, nil
+	}
+
+	rowGroup := c.rdr.RowGroup(c.rowGroups[0])
+	remaining := c.rowGroups[1:]
+	c.rowGroups = remaining
+	return rowGroup.GetColumnPageReader(c.index)
+}
+
+// Descr returns the schema column descriptor for the iterator's column index.
+func (c *columnIterator) Descr() *schema.Column { return c.schema.Column(c.index) }
+
+// recordReader is the RecordReader implementation returned by
+// FileReader.GetRecordReader.
+type recordReader struct {
+ // numRows is the total row count of the selected row groups.
+ numRows int64
+ // batchSize is the number of values requested from each field reader per batch.
+ batchSize int64
+ // parallel selects whether the fields of a batch are read concurrently.
+ parallel bool
+ sc *arrow.Schema
+ fieldReaders []*ColumnReader
+ // cur is the record produced by the most recent successful read, if any.
+ cur array.Record
+ // err is the error that ended iteration; io.EOF once a field has no more data.
+ err error
+
+ // refCount is modified with atomic operations by Retain and Release.
+ refCount int64
+}
+
+// Retain atomically increases the reference count by one.
+func (r *recordReader) Retain() {
+ atomic.AddInt64(&r.refCount, 1)
+}
+
+// Release atomically decreases the reference count by one. When it reaches
+// zero the current record and all field readers are released and dropped.
+func (r *recordReader) Release() {
+	if atomic.AddInt64(&r.refCount, -1) != 0 {
+		return
+	}
+
+	if r.cur != nil {
+		r.cur.Release()
+		r.cur = nil
+	}
+
+	// ranging over a nil slice is a no-op, so no separate nil guard is needed
+	for _, fieldReader := range r.fieldReaders {
+		fieldReader.Release()
+	}
+	r.fieldReaders = nil
+}
+
+// Schema returns the arrow schema of the records this reader produces.
+func (r *recordReader) Schema() *arrow.Schema { return r.sc }
+
+// next reads one batch of batchSize values from every field reader and
+// assembles them into r.cur. It returns false and sets r.err when any field
+// fails to read; a field that yields no data reports io.EOF.
+//
+// In parallel mode one goroutine per column is started. The first error seen
+// is recorded under a mutex so concurrent workers never touch r.err directly.
+func (r *recordReader) next() bool {
+	cols := make([]array.Interface, len(r.sc.Fields()))
+	// array.NewRecord retains every column it is given, so the references
+	// created below are dropped on return whether or not a record was built.
+	defer func() {
+		for _, col := range cols {
+			if col != nil {
+				col.Release()
+			}
+		}
+	}()
+
+	readField := func(idx int, rdr *ColumnReader) error {
+		data, err := rdr.NextBatch(r.batchSize)
+		if err != nil {
+			return err
+		}
+		defer data.Release()
+
+		if data.Len() == 0 {
+			return io.EOF
+		}
+
+		arrdata, err := chunksToSingle(data)
+		if err != nil {
+			return err
+		}
+		cols[idx] = array.MakeFromData(arrdata)
+		return nil
+	}
+
+	if !r.parallel {
+		for idx, rdr := range r.fieldReaders {
+			if err := readField(idx, rdr); err != nil {
+				r.err = err
+				return false
+			}
+		}
+
+		r.cur = array.NewRecord(r.sc, cols, -1)
+		return true
+	}
+
+	// An error left over from an earlier call ends iteration immediately.
+	if r.err != nil {
+		return false
+	}
+
+	var (
+		wg       sync.WaitGroup
+		mu       sync.Mutex // guards firstErr
+		firstErr error
+		np       = len(cols)
+		ch       = make(chan int, np)
+	)
+
+	failed := func() bool {
+		mu.Lock()
+		defer mu.Unlock()
+		return firstErr != nil
+	}
+
+	wg.Add(np)
+	for i := 0; i < np; i++ {
+		go func() {
+			defer wg.Done()
+			for idx := range ch {
+				if failed() {
+					break
+				}
+				if err := readField(idx, r.fieldReaders[idx]); err != nil {
+					mu.Lock()
+					if firstErr == nil {
+						firstErr = err
+					}
+					mu.Unlock()
+					break
+				}
+			}
+		}()
+	}
+
+	for idx := range r.fieldReaders {
+		ch <- idx
+	}
+	close(ch)
+	wg.Wait()
+
+	// all workers are done, so firstErr can be read without the lock
+	if firstErr != nil {
+		r.err = firstErr
+		return false
+	}
+
+	r.cur = array.NewRecord(r.sc, cols, -1)
+	return true
+}
+
+// Next releases the current record, if any, and reads the next one. It
+// returns false once an error (including io.EOF) has been recorded; the
+// new record is available through Record.
+func (r *recordReader) Next() bool {
+ if r.cur != nil {
+ r.cur.Release()
+ r.cur = nil
+ }
+
+ if r.err != nil {
+ return false
+ }
+
+ return r.next()
+}
+
+// Record returns the record produced by the last successful Next or Read, or
+// nil if there is none. The reader releases it on the following Next or Read.
+func (r *recordReader) Record() array.Record { return r.cur }
+
+// Read releases the current record, if any, reads the next one and returns
+// it. When no record could be read it returns nil together with the recorded
+// error, which is io.EOF once a field has no more data.
+func (r *recordReader) Read() (array.Record, error) {
+ if r.cur != nil {
+ r.cur.Release()
+ r.cur = nil
+ }
+
+ if !r.next() {
+ return nil, r.err
+ }
+
+ return r.cur, nil
+}
diff --git a/go/parquet/pqarrow/file_reader_test.go b/go/parquet/pqarrow/file_reader_test.go
new file mode 100644
index 0000000..0e8adb0
--- /dev/null
+++ b/go/parquet/pqarrow/file_reader_test.go
@@ -0,0 +1,177 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow_test
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/apache/arrow/go/arrow"
+ "github.com/apache/arrow/go/arrow/array"
+ "github.com/apache/arrow/go/arrow/decimal128"
+ "github.com/apache/arrow/go/arrow/memory"
+ "github.com/apache/arrow/go/parquet/file"
+ "github.com/apache/arrow/go/parquet/pqarrow"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// getDataDir returns the test data directory named by the PARQUET_TEST_DATA
+// environment variable and panics when that variable is unset or empty.
+func getDataDir() string {
+	if dir := os.Getenv("PARQUET_TEST_DATA"); dir != "" {
+		return dir
+	}
+	panic("please point PARQUET_TEST_DATA env var to the test data directory")
+}
+
+// TestArrowReaderAdHocReadDecimals reads the decimal sample files from the
+// parquet test data directory and checks that each one yields a single
+// Decimal128 column of the expected type holding the values 100, 200, ...
+// for 24 rows.
+func TestArrowReaderAdHocReadDecimals(t *testing.T) {
+	tests := []struct {
+		file string
+		typ  *arrow.Decimal128Type
+	}{
+		{"int32_decimal", &arrow.Decimal128Type{Precision: 4, Scale: 2}},
+		{"int64_decimal", &arrow.Decimal128Type{Precision: 10, Scale: 2}},
+		{"fixed_length_decimal", &arrow.Decimal128Type{Precision: 25, Scale: 2}},
+		{"fixed_length_decimal_legacy", &arrow.Decimal128Type{Precision: 13, Scale: 2}},
+		{"byte_array_decimal", &arrow.Decimal128Type{Precision: 4, Scale: 2}},
+	}
+
+	dataDir := getDataDir()
+	for _, tt := range tests {
+		t.Run(tt.file, func(t *testing.T) {
+			filename := filepath.Join(dataDir, tt.file+".parquet")
+			require.FileExists(t, filename)
+
+			rdr, err := file.OpenParquetFile(filename, false)
+			require.NoError(t, err)
+			// close the opened file when the subtest ends
+			defer rdr.Close()
+
+			arrowRdr, err := pqarrow.NewFileReader(rdr, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
+			require.NoError(t, err)
+
+			tbl, err := arrowRdr.ReadTable(context.Background())
+			require.NoError(t, err)
+			defer tbl.Release()
+
+			assert.EqualValues(t, 1, tbl.NumCols())
+			// tt.typ is the expectation, the table's field type is what was read
+			assert.Truef(t, arrow.TypeEqual(tbl.Schema().Field(0).Type, tt.typ), "expected: %s\ngot: %s", tt.typ, tbl.Schema().Field(0).Type)
+
+			const expectedLen = 24
+			valCol := tbl.Column(0)
+
+			assert.EqualValues(t, expectedLen, valCol.Len())
+			// require here: the chunk is indexed right below
+			require.Len(t, valCol.Data().Chunks(), 1)
+
+			chunk := valCol.Data().Chunk(0)
+			bldr := array.NewDecimal128Builder(memory.DefaultAllocator, tt.typ)
+			defer bldr.Release()
+			for i := 0; i < expectedLen; i++ {
+				bldr.Append(decimal128.FromI64(int64((i + 1) * 100)))
+			}
+
+			expectedArr := bldr.NewDecimal128Array()
+			defer expectedArr.Release()
+
+			assert.Truef(t, array.ArrayEqual(expectedArr, chunk), "expected: %s\ngot: %s", expectedArr, chunk)
+		})
+	}
+}
+
+// TestRecordReaderParallel writes the date/time sample table to an in-memory
+// parquet file, reads it back through a parallel record reader with a batch
+// size of 3 and compares the first two batches against a table reader over
+// the original table.
+func TestRecordReaderParallel(t *testing.T) {
+	tbl := makeDateTimeTypesTable(true, true)
+	var buf bytes.Buffer
+	require.NoError(t, pqarrow.WriteTable(tbl, &buf, tbl.NumRows(), nil, pqarrow.DefaultWriterProps(), memory.DefaultAllocator))
+
+	pf, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()))
+	require.NoError(t, err)
+
+	reader, err := pqarrow.NewFileReader(pf, pqarrow.ArrowReadProperties{BatchSize: 3, Parallel: true}, memory.DefaultAllocator)
+	require.NoError(t, err)
+
+	sc, err := reader.Schema()
+	require.NoError(t, err)
+	assert.Truef(t, tbl.Schema().Equal(sc), "expected: %s\ngot: %s", tbl.Schema(), sc)
+
+	// rr is dereferenced below, so stop the test instead of panicking on nil
+	rr, err := reader.GetRecordReader(context.Background(), nil, nil)
+	require.NoError(t, err)
+	require.NotNil(t, rr)
+	defer rr.Release()
+
+	records := make([]array.Record, 0)
+	for rr.Next() {
+		rec := rr.Record()
+		// the reader releases its record on the following Next, so keep our
+		// own reference until the test ends
+		rec.Retain()
+		defer rec.Release()
+
+		assert.Truef(t, sc.Equal(rec.Schema()), "expected: %s\ngot: %s", sc, rec.Schema())
+		records = append(records, rec)
+	}
+
+	assert.False(t, rr.Next())
+	require.Truef(t, len(records) >= 2, "expected at least 2 records, got %d", len(records))
+
+	tr := array.NewTableReader(tbl, 3)
+	defer tr.Release()
+
+	require.True(t, tr.Next())
+	assert.Truef(t, array.RecordEqual(tr.Record(), records[0]), "expected: %s\ngot: %s", tr.Record(), records[0])
+	require.True(t, tr.Next())
+	assert.Truef(t, array.RecordEqual(tr.Record(), records[1]), "expected: %s\ngot: %s", tr.Record(), records[1])
+}
+
+// TestRecordReaderSerial writes the date/time sample table to an in-memory
+// parquet file, reads it back through a serial record reader with a batch
+// size of 2 and checks that three batches match a table reader over the
+// original table before Read reports io.EOF.
+func TestRecordReaderSerial(t *testing.T) {
+	tbl := makeDateTimeTypesTable(true, true)
+	var buf bytes.Buffer
+	require.NoError(t, pqarrow.WriteTable(tbl, &buf, tbl.NumRows(), nil, pqarrow.DefaultWriterProps(), memory.DefaultAllocator))
+
+	pf, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()))
+	require.NoError(t, err)
+
+	reader, err := pqarrow.NewFileReader(pf, pqarrow.ArrowReadProperties{BatchSize: 2}, memory.DefaultAllocator)
+	require.NoError(t, err)
+
+	sc, err := reader.Schema()
+	require.NoError(t, err)
+	assert.Truef(t, tbl.Schema().Equal(sc), "expected: %s\ngot: %s", tbl.Schema(), sc)
+
+	// rr is dereferenced below, so stop the test instead of panicking on nil
+	rr, err := reader.GetRecordReader(context.Background(), nil, nil)
+	require.NoError(t, err)
+	require.NotNil(t, rr)
+	defer rr.Release()
+
+	tr := array.NewTableReader(tbl, 2)
+	defer tr.Release()
+
+	const expectedBatches = 3
+	for i := 0; i < expectedBatches; i++ {
+		rec, err := rr.Read()
+		require.NoError(t, err)
+		// tr.Record() is only meaningful after a successful Next
+		require.True(t, tr.Next())
+		assert.Truef(t, array.RecordEqual(tr.Record(), rec), "expected: %s\ngot: %s", tr.Record(), rec)
+	}
+
+	rec, err := rr.Read()
+	assert.Same(t, io.EOF, err)
+	assert.Nil(t, rec)
+}
diff --git a/go/parquet/pqarrow/file_writer.go b/go/parquet/pqarrow/file_writer.go
new file mode 100644
index 0000000..1f0a946
--- /dev/null
+++ b/go/parquet/pqarrow/file_writer.go
@@ -0,0 +1,291 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow
+
+import (
+ "context"
+ "encoding/base64"
+ "io"
+
+ "github.com/apache/arrow/go/arrow"
+ "github.com/apache/arrow/go/arrow/array"
+ "github.com/apache/arrow/go/arrow/flight"
+ "github.com/apache/arrow/go/arrow/memory"
+ "github.com/apache/arrow/go/parquet"
+ "github.com/apache/arrow/go/parquet/file"
+ "github.com/apache/arrow/go/parquet/internal/utils"
+ "github.com/apache/arrow/go/parquet/metadata"
+ "golang.org/x/xerrors"
+)
+
+// WriteTable is a convenience function to create and write a full array.Table to a parquet file. The schema
+// and columns will be determined by the schema of the table, writing the file out to the provided writer.
+// The chunksize will be utilized in order to determine the size of the row groups.
+//
+// The writer is closed after the table has been written successfully. The mem
+// argument is not used by this function.
+func WriteTable(tbl array.Table, w io.Writer, chunkSize int64, props *parquet.WriterProperties, arrprops ArrowWriterProperties, mem memory.Allocator) error {
+ writer, err := NewFileWriter(tbl.Schema(), w, props, arrprops)
+ if err != nil {
+ return err
+ }
+
+ if err := writer.WriteTable(tbl, chunkSize); err != nil {
+ return err
+ }
+
+ return writer.Close()
+}
+
+// FileWriter is an object for writing Arrow directly to a parquet file.
+type FileWriter struct {
+ // wr is the underlying parquet file writer.
+ wr *file.Writer
+ // schema is the arrow schema that written records and tables must equal.
+ schema *arrow.Schema
+ // manifest is built from the parquet schema derived from the arrow schema.
+ manifest *SchemaManifest
+ // rgw is the current row group writer; nil until a row group is started.
+ rgw file.RowGroupWriter
+ arrowProps ArrowWriterProperties
+ // ctx is the arrow write context passed to every column writer.
+ ctx context.Context
+ // colIdx is the leaf column index in the current row group that the next column write starts at.
+ colIdx int
+ // closed records that Close has already been called.
+ closed bool
+}
+
+// NewFileWriter returns a writer for writing Arrow directly to a parquet file. Unlike
+// the ArrowColumnWriter and WriteArrow functions, which write arrow data to an existing
+// file.Writer, this creates a new file.Writer based on the schema provided.
+//
+// A nil props is replaced with parquet.NewWriterProperties(). When arrprops
+// asks for the schema to be stored, the arrow schema metadata and the
+// base64-encoded serialized schema (under the "ARROW:schema" key) are added
+// to the file's key/value metadata.
+func NewFileWriter(arrschema *arrow.Schema, w io.Writer, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (*FileWriter, error) {
+ if props == nil {
+ props = parquet.NewWriterProperties()
+ }
+
+ pqschema, err := ToParquet(arrschema, props, arrprops)
+ if err != nil {
+ return nil, err
+ }
+
+ meta := make(metadata.KeyValueMetadata, 0)
+ if arrprops.storeSchema {
+ for i := 0; i < arrschema.Metadata().Len(); i++ {
+ meta.Append(arrschema.Metadata().Keys()[i], arrschema.Metadata().Values()[i])
+ }
+
+ serializedSchema := flight.SerializeSchema(arrschema, props.Allocator())
+ meta.Append("ARROW:schema", base64.RawStdEncoding.EncodeToString(serializedSchema))
+ }
+
+ schemaNode := pqschema.Root()
+ baseWriter := file.NewParquetWriter(w, schemaNode, file.WithWriterProps(props), file.WithWriteMetadata(meta))
+
+ manifest, err := NewSchemaManifest(pqschema, nil, &ArrowReadProperties{})
+ if err != nil {
+ return nil, err
+ }
+
+ return &FileWriter{wr: baseWriter, schema: arrschema, manifest: manifest, arrowProps: arrprops, ctx: NewArrowWriteContext(context.TODO(), &arrprops)}, nil
+}
+
+// NewRowGroup does what it says on the tin, creates a new row group in the underlying file.
+// Equivalent to `AppendRowGroup` on a file.Writer
+//
+// A row group that is already open is closed first; any error from closing it
+// is discarded. The current column index is reset to 0.
+func (fw *FileWriter) NewRowGroup() {
+ if fw.rgw != nil {
+ fw.rgw.Close()
+ }
+ fw.rgw = fw.wr.AppendRowGroup()
+ fw.colIdx = 0
+}
+
+// NewBufferedRowGroup starts a new memory Buffered Row Group to allow writing columns / records
+// without immediately flushing them to disk. This allows using WriteBuffered to write records
+// and decide where to break your rowgroup based on the TotalBytesWritten rather than on the max
+// row group len. If using Records, this should be paired with WriteBuffered, while
+// Write will always write a new record as a row group in and of itself.
+//
+// A row group that is already open is closed first; any error from closing it
+// is discarded. The current column index is reset to 0.
+func (fw *FileWriter) NewBufferedRowGroup() {
+ if fw.rgw != nil {
+ fw.rgw.Close()
+ }
+ fw.rgw = fw.wr.AppendBufferedRowGroup()
+ fw.colIdx = 0
+}
+
+// RowGroupTotalCompressedBytes reports how many bytes, after compression, have
+// been written to the current row group so far. It is 0 when no row group
+// has been started.
+func (fw *FileWriter) RowGroupTotalCompressedBytes() int64 {
+	if fw.rgw == nil {
+		return 0
+	}
+	return fw.rgw.TotalCompressedBytes()
+}
+
+// RowGroupTotalBytesWritten reports how many bytes have been written and
+// flushed out in the current row group. It is 0 when no row group has been
+// started.
+func (fw *FileWriter) RowGroupTotalBytesWritten() int64 {
+	if fw.rgw == nil {
+		return 0
+	}
+	return fw.rgw.TotalBytesWritten()
+}
+
+// WriteBuffered appends a record to the current row group, starting a
+// buffered row group first when none is open. When the record does not fit in
+// the rows remaining under MaxRowGroupLength, it is sliced: the first slice
+// fills the current row group and each following slice is written to a new
+// buffered row group.
+//
+// An error is returned when the record's schema differs from the writer's or
+// when writing a column fails; in the latter case the writer is closed.
+func (fw *FileWriter) WriteBuffered(rec array.Record) error {
+ if !rec.Schema().Equal(fw.schema) {
+ return xerrors.Errorf("record schema does not match writer's. \nrecord: %s\nwriter: %s", rec.Schema(), fw.schema)
+ }
+
+ var (
+ recList []array.Record
+ maxRows = fw.wr.Properties().MaxRowGroupLength()
+ curRows int64
+ err error
+ )
+ if fw.rgw != nil {
+ if curRows, err = fw.rgw.NumRows(); err != nil {
+ return err
+ }
+ } else {
+ fw.NewBufferedRowGroup()
+ }
+
+ if curRows+rec.NumRows() <= maxRows {
+ recList = []array.Record{rec}
+ } else {
+ // first slice tops up the open row group, the rest are cut at maxRows
+ recList = []array.Record{rec.NewSlice(0, maxRows-curRows)}
+ defer recList[0].Release()
+ for offset := int64(maxRows - curRows); offset < rec.NumRows(); offset += maxRows {
+ s := rec.NewSlice(offset, offset+utils.Min(maxRows, rec.NumRows()-offset))
+ defer s.Release()
+ recList = append(recList, s)
+ }
+ }
+
+ for idx, r := range recList {
+ // every slice after the first goes into its own new row group
+ if idx > 0 {
+ fw.NewBufferedRowGroup()
+ }
+ for i := 0; i < int(r.NumCols()); i++ {
+ if err := fw.WriteColumnData(r.Column(i)); err != nil {
+ fw.Close()
+ return err
+ }
+ }
+ }
+ // the next record starts again at the first column of the row group
+ fw.colIdx = 0
+ return nil
+}
+
+// Write an arrow Record Batch to the file, respecting the MaxRowGroupLength in the writer
+// properties to determine whether or not a new row group is created while writing.
+//
+// Every call starts at least one new row group; a record longer than
+// MaxRowGroupLength is sliced into several. An error is returned when the
+// record's schema differs from the writer's or when writing a column fails;
+// in the latter case the writer is closed.
+func (fw *FileWriter) Write(rec array.Record) error {
+ if !rec.Schema().Equal(fw.schema) {
+ return xerrors.Errorf("record schema does not match writer's. \nrecord: %s\nwriter: %s", rec.Schema(), fw.schema)
+ }
+
+ var recList []array.Record
+ rowgroupLen := fw.wr.Properties().MaxRowGroupLength()
+ if rec.NumRows() > rowgroupLen {
+ recList = make([]array.Record, 0)
+ for offset := int64(0); offset < rec.NumRows(); offset += rowgroupLen {
+ s := rec.NewSlice(offset, offset+utils.Min(rowgroupLen, rec.NumRows()-offset))
+ // slices must stay alive until they have been written below
+ defer s.Release()
+ recList = append(recList, s)
+ }
+ } else {
+ recList = []array.Record{rec}
+ }
+
+ for _, r := range recList {
+ fw.NewRowGroup()
+ for i := 0; i < int(r.NumCols()); i++ {
+ if err := fw.WriteColumnData(r.Column(i)); err != nil {
+ fw.Close()
+ return err
+ }
+ }
+ }
+ fw.colIdx = 0
+ return nil
+}
+
+// WriteTable writes an arrow table to the underlying file using chunkSize to determine
+// the size to break at for making row groups. Writing a table will always create a new
+// row group for each chunk of chunkSize rows in the table. Calling this with 0 rows will
+// still write a 0 length Row Group to the file.
+//
+// chunkSize is capped at the writer's MaxRowGroupLength. An error is returned
+// when chunkSize is not positive for a non-empty table, when the table schema
+// differs from the writer's, or when writing a column fails; in the last case
+// the writer is closed.
+func (fw *FileWriter) WriteTable(tbl array.Table, chunkSize int64) error {
+ if chunkSize <= 0 && tbl.NumRows() > 0 {
+ return xerrors.New("chunk size per row group must be greater than 0")
+ } else if !tbl.Schema().Equal(fw.schema) {
+ return xerrors.Errorf("table schema does not match writer's. \nTable: %s\n writer: %s", tbl.Schema(), fw.schema)
+ } else if chunkSize > fw.wr.Properties().MaxRowGroupLength() {
+ chunkSize = fw.wr.Properties().MaxRowGroupLength()
+ }
+
+ // writeRowGroup starts a new row group and writes rows [offset, offset+size) of every column into it
+ writeRowGroup := func(offset, size int64) error {
+ fw.NewRowGroup()
+ for i := 0; i < int(tbl.NumCols()); i++ {
+ if err := fw.WriteColumnChunked(tbl.Column(i).Data(), offset, size); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+
+ // an empty table still produces one zero-length row group
+ if tbl.NumRows() == 0 {
+ if err := writeRowGroup(0, 0); err != nil {
+ fw.Close()
+ return err
+ }
+ return nil
+ }
+
+ for offset := int64(0); offset < tbl.NumRows(); offset += chunkSize {
+ if err := writeRowGroup(offset, utils.Min(chunkSize, tbl.NumRows()-offset)); err != nil {
+ fw.Close()
+ return err
+ }
+ }
+ return nil
+}
+
+// Close closes the current row group, if any, and then the underlying file
+// writer. Only the first call does any work; later calls return nil. If
+// closing the row group fails, that error is returned and the file writer is
+// left unclosed.
+func (fw *FileWriter) Close() error {
+	if fw.closed {
+		return nil
+	}
+	fw.closed = true
+
+	if fw.rgw != nil {
+		if err := fw.rgw.Close(); err != nil {
+			return err
+		}
+	}
+	return fw.wr.Close()
+}
+
+// WriteColumnChunked will write the data provided to the underlying file, using the provided
+// offset and size to allow writing subsets of data from the chunked column. It uses the current
+// column in the underlying row group writer as the starting point, allowing progressive
+// building of writing columns to a file via arrow data without needing to already have
+// a record or table.
+//
+// The current column index is advanced by the column writer's leaf count
+// before the data is written, so it moves forward even when the write fails.
+func (fw *FileWriter) WriteColumnChunked(data *array.Chunked, offset, size int64) error {
+ acw, err := NewArrowColumnWriter(data, offset, size, fw.manifest, fw.rgw, fw.colIdx)
+ if err != nil {
+ return err
+ }
+ fw.colIdx += acw.leafCount
+ return acw.Write(fw.ctx)
+}
+
+// WriteColumnData writes the entire array to the file as the next columns. Like WriteColumnChunked
+// it is based on the current column of the row group writer allowing progressive building
+// of the file by columns without needing a full record or table to write.
+//
+// The array is wrapped in a temporary single-chunk array.Chunked, which is
+// released again once the write has finished.
+func (fw *FileWriter) WriteColumnData(data array.Interface) error {
+	chunked := array.NewChunked(data.DataType(), []array.Interface{data})
+	// NewChunked takes its own reference on data; drop it when done so the
+	// wrapper does not leak.
+	defer chunked.Release()
+	return fw.WriteColumnChunked(chunked, 0, int64(data.Len()))
+}
diff --git a/go/parquet/pqarrow/path_builder.go b/go/parquet/pqarrow/path_builder.go
new file mode 100644
index 0000000..c590984
--- /dev/null
+++ b/go/parquet/pqarrow/path_builder.go
@@ -0,0 +1,738 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow
+
+import (
+ "sync/atomic"
+ "unsafe"
+
+ "github.com/apache/arrow/go/arrow"
+ "github.com/apache/arrow/go/arrow/array"
+ "github.com/apache/arrow/go/arrow/memory"
+ "github.com/apache/arrow/go/parquet/internal/encoding"
+ "github.com/apache/arrow/go/parquet/internal/utils"
+ "golang.org/x/xerrors"
+)
+
+// iterResult is returned by a path node after it runs; writePath adds it to the
+// current stack position to select the node that runs next.
+type iterResult int8
+
+const (
+ // iterDone moves back to the previous node of the path
+ iterDone iterResult = -1
+ // iterNext moves forward to the next node of the path
+ iterNext iterResult = 1
+)
+
+// elemRange is a span of element indices running from start up to, but not
+// including, end.
+type elemRange struct {
+	start, end int64
+}
+
+// empty reports whether the range covers no elements.
+func (e elemRange) empty() bool {
+	return e.end == e.start
+}
+
+// size returns the number of elements between start and end.
+func (e elemRange) size() int64 {
+	return e.end - e.start
+}
+
+// rangeSelector maps the index of a list entry to the range of child elements it covers.
+type rangeSelector interface {
+ // GetRange returns the child element range for the entry at idx.
+ GetRange(idx int64) elemRange
+}
+
+// varRangeSelector resolves list entries to child ranges through an offsets slice:
+// entry i covers the elements from offsets[i] up to offsets[i+1].
+type varRangeSelector struct {
+	offsets []int32
+}
+
+// GetRange returns the child range bounded by the offsets stored at idx and idx+1.
+func (v varRangeSelector) GetRange(idx int64) elemRange {
+	first, last := v.offsets[idx], v.offsets[idx+1]
+	return elemRange{start: int64(first), end: int64(last)}
+}
+
+// fixedSizeRangeSelector resolves list entries to child ranges when every entry
+// covers exactly listSize consecutive child elements.
+type fixedSizeRangeSelector struct {
+	listSize int32
+}
+
+// GetRange returns the listSize-wide child range that belongs to the entry at idx.
+func (f fixedSizeRangeSelector) GetRange(idx int64) elemRange {
+	width := int64(f.listSize)
+	begin := idx * width
+	return elemRange{start: begin, end: begin + width}
+}
+
+// pathNode is one step of the path that leads from a visited array down to one of its
+// leaf arrays. Every node is able to produce a copy of itself.
+type pathNode interface {
+ clone() pathNode
+}
+
+// allPresentTerminalNode ends a path whose leaf array holds no nulls, so every
+// element receives the same definition level.
+type allPresentTerminalNode struct {
+	defLevel int16
+}
+
+// clone returns a separately allocated copy of the node.
+func (n *allPresentTerminalNode) clone() pathNode {
+	return &allPresentTerminalNode{defLevel: n.defLevel}
+}
+
+// run appends defLevel once for each element of rng.
+func (n *allPresentTerminalNode) run(rng elemRange, ctx *pathWriteCtx) iterResult {
+	count := int(rng.size())
+	return ctx.AppendDefLevels(count, n.defLevel)
+}
+
+// allNullsTerminalNode ends a path at an array in which every entry is null, so each
+// element receives the same definition level and, when set, the same repetition level.
+type allNullsTerminalNode struct {
+	defLevel int16
+	repLevel int16
+}
+
+// clone returns a separately allocated copy of the node.
+func (n *allNullsTerminalNode) clone() pathNode {
+	return &allNullsTerminalNode{defLevel: n.defLevel, repLevel: n.repLevel}
+}
+
+// run fills in the repetition levels for rng and then appends defLevel once per element.
+func (n *allNullsTerminalNode) run(rng elemRange, ctx *pathWriteCtx) iterResult {
+	count := int(rng.size())
+	fillRepLevels(count, n.repLevel, ctx)
+	return ctx.AppendDefLevels(count, n.defLevel)
+}
+
+// nullableTerminalNode ends a path at a leaf array that holds a mix of null and
+// non-null values; the definition level of each element is chosen from its validity bit.
+type nullableTerminalNode struct {
+ // validity bitmap of the leaf array
+ bitmap []byte
+ // offset of the leaf array's first element within bitmap
+ elemOffset int64
+ defLevelIfPresent int16
+ defLevelIfNull int16
+}
+
+// clone returns a separately allocated copy of the node (the bitmap slice is shared).
+func (n *nullableTerminalNode) clone() pathNode {
+ ret := *n
+ return &ret
+}
+
+// run writes one definition level per element of rng: defLevelIfPresent for every set
+// validity bit and defLevelIfNull for every unset one. It always returns iterDone.
+func (n *nullableTerminalNode) run(rng elemRange, ctx *pathWriteCtx) iterResult {
+ elems := rng.size()
+ // reserve the space up front because the writes below are the unchecked variant
+ ctx.ReserveDefLevels(int(elems))
+
+ // two-byte views of both levels in their in-memory layout, so they can be copied as raw bytes
+ var (
+ present = (*(*[2]byte)(unsafe.Pointer(&n.defLevelIfPresent)))[:]
+ null = (*(*[2]byte)(unsafe.Pointer(&n.defLevelIfNull)))[:]
+ )
+ rdr := utils.NewBitRunReader(n.bitmap, n.elemOffset+rng.start, elems)
+ for {
+ run := rdr.NextRun()
+ // a zero length run ends the loop
+ if run.Len == 0 {
+ break
+ }
+ if run.Set {
+ ctx.defLevels.UnsafeWriteCopy(int(run.Len), present)
+ } else {
+ ctx.defLevels.UnsafeWriteCopy(int(run.Len), null)
+ }
+ }
+ return iterDone
+}
+
+// listNode is the path node for a list-like array (list, map or fixed size list).
+type listNode struct {
+ // resolves a list entry to the range of child elements it covers
+ selector rangeSelector
+ // repetition level written for the first element of a list and for skipped empty lists
+ prevRepLevel int16
+ // repetition level written for the remaining elements of a list
+ repLevel int16
+ // definition level written for a list without any elements
+ defLevelIfEmpty int16
+ // set by the fixup pass when repLevel equals the path's max repetition level
+ isLast bool
+}
+
+// clone returns a separately allocated copy of the node.
+func (n *listNode) clone() pathNode {
+ ret := *n
+ return &ret
+}
+
+// run processes the next list entry of rng. Leading empty lists are skipped and recorded
+// with defLevelIfEmpty; the child range of the first non-empty list is stored in childRng
+// and rng is advanced past that list. It returns iterDone once rng is exhausted, otherwise
+// iterNext (or the result of fillForLast when this is the last list node of the path).
+func (n *listNode) run(rng, childRng *elemRange, ctx *pathWriteCtx) iterResult {
+ if rng.empty() {
+ return iterDone
+ }
+
+ // find the first non-empty list (skipping a run of empties)
+ start := rng.start
+ for {
+ // retrieve the range of elements that this list contains
+ *childRng = n.selector.GetRange(rng.start)
+ if !childRng.empty() {
+ break
+ }
+ rng.start++
+ if rng.empty() {
+ break
+ }
+ }
+
+ // loops post-condition:
+ // * rng is either empty (we're done processing this node)
+ // or start corresponds to a non-empty list
+ // * if rng is non-empty, childRng contains the bounds of the non-empty list
+
+ // handle any skipped over empty lists
+ emptyElems := rng.start - start
+ if emptyElems > 0 {
+ fillRepLevels(int(emptyElems), n.prevRepLevel, ctx)
+ ctx.AppendDefLevels(int(emptyElems), n.defLevelIfEmpty)
+ }
+
+ // start of a new list, note that for nested lists adding the element
+ // here effectively suppresses this code until we either encounter null
+ // elements or empty lists between here and the innermost list (since we
+ // make the rep levels repetition and definition levels unequal).
+ // similarly when we are backtracking up the stack, the repetition
+ // and definition levels are again equal so if we encounter an intermediate
+ // list, with more elements, this will detect it as a new list
+ if ctx.equalRepDeflevlsLen() && !rng.empty() {
+ ctx.AppendRepLevel(n.prevRepLevel)
+ }
+
+ if rng.empty() {
+ return iterDone
+ }
+
+ // step past the non-empty list whose bounds are now held in childRng
+ rng.start++
+ if n.isLast {
+ // if this is the last repeated node, we can try
+ // to extend the child range as wide as possible,
+ // before continuing to the next node
+ return n.fillForLast(rng, childRng, ctx)
+ }
+
+ return iterNext
+}
+
+// fillForLast writes the repetition levels for the list held in childRng and then keeps
+// consuming the directly following non-empty lists of rng, widening childRng to cover
+// their elements as well. It stops at the first empty list or when rng is exhausted,
+// records the resulting child range as visited and returns iterNext.
+func (n *listNode) fillForLast(rng, childRng *elemRange, ctx *pathWriteCtx) iterResult {
+ fillRepLevels(int(childRng.size()), n.repLevel, ctx)
+ // once we've reached this point the following preconditions should hold:
+ // 1. there are no more repeated path nodes to deal with
+ // 2. all elements in |range| represent contiguous elements in the child
+ // array (null values would have shortened the range to ensure all
+ // remaining list elements are present, though they may be empty)
+ // 3. no element of range spans a parent list (intermediate list nodes
+ // only handle one list entry at a time)
+ //
+ // given these preconditions, it should be safe to fill runs on non-empty lists
+ // here and expand the range in the child node accordingly
+ for !rng.empty() {
+ sizeCheck := n.selector.GetRange(rng.start)
+ if sizeCheck.empty() {
+ // the empty range will need to be handled after we pass down the accumulated
+ // range because it affects def level placement and we need to get the children
+ // def levels entered first
+ break
+ }
+
+ // this is the start of a new list. we can be sure that it only applies to the
+ // previous list (and doesn't jump to the start of any list further up in nesting
+ // due to the constraints mentioned earlier)
+ ctx.AppendRepLevel(n.prevRepLevel)
+ ctx.AppendRepLevels(int(sizeCheck.size())-1, n.repLevel)
+ childRng.end = sizeCheck.end
+ rng.start++
+ }
+
+ // do book-keeping to track the elements of the arrays that are actually visited
+ // beyond this point. this is necessary to identify "gaps" in values that should
+ // not be processed (written out to parquet)
+ ctx.recordPostListVisit(*childRng)
+ return iterNext
+}
+
+// nullableNode is the path node for a nested array that holds a mix of null and
+// non-null entries. It hands runs of non-null entries on to the following node and
+// writes the levels for runs of null entries itself.
+type nullableNode struct {
+	// validity bitmap of the array and the offset of its first entry within it
+	bitmap      []byte
+	entryOffset int64
+	// levels written for every null entry
+	repLevelIfNull int16
+	defLevelIfNull int16
+
+	// reader over the validity bits of the range currently being processed
+	validBitsReader utils.BitRunReader
+	// true when the next call to run has to start a fresh validBitsReader
+	newRange bool
+}
+
+// clone returns a separately allocated copy of the node, including its reader state.
+func (n *nullableNode) clone() pathNode {
+	dup := new(nullableNode)
+	*dup = *n
+	return dup
+}
+
+// run consumes the next stretch of rng. A run of null entries is written out directly
+// using repLevelIfNull and defLevelIfNull; the run of valid entries that follows is stored
+// in childRng for the next node and rng is advanced past it. It returns iterDone once rng
+// is exhausted and iterNext while there are valid entries to hand on.
+func (n *nullableNode) run(rng, childRng *elemRange, ctx *pathWriteCtx) iterResult {
+ if n.newRange {
+ // first call for this range: start reading its validity bits from the beginning
+ n.validBitsReader = utils.NewBitRunReader(n.bitmap, n.entryOffset+rng.start, rng.size())
+ }
+ childRng.start = rng.start
+ run := n.validBitsReader.NextRun()
+ if !run.Set {
+ // a run of nulls: emit their levels here, skip over them and fetch the following run
+ rng.start += run.Len
+ fillRepLevels(int(run.Len), n.repLevelIfNull, ctx)
+ ctx.AppendDefLevels(int(run.Len), n.defLevelIfNull)
+ run = n.validBitsReader.NextRun()
+ }
+
+ if rng.empty() {
+ // nothing left in this range; the next call belongs to a new range
+ n.newRange = true
+ return iterDone
+ }
+ // the child range is exactly the current run of valid entries
+ childRng.start = rng.start
+ childRng.end = childRng.start
+ childRng.end += run.Len
+ rng.start += childRng.size()
+ n.newRange = false
+ return iterNext
+}
+
+// pathInfo describes how to reach one leaf array: the nodes to run in order, the
+// leaf array itself and the maximum definition and repetition levels of the path.
+type pathInfo struct {
+	path           []pathNode
+	primitiveArr   array.Interface
+	maxDefLevel    int16
+	maxRepLevel    int16
+	hasDict        bool
+	leafIsNullable bool
+}
+
+// clone returns a copy of p whose path holds clones of the original nodes, so the
+// nodes of the copy can be changed without affecting p.
+func (p pathInfo) clone() pathInfo {
+	nodes := make([]pathNode, 0, len(p.path))
+	for _, node := range p.path {
+		nodes = append(nodes, node.clone())
+	}
+	// p is already a copy because of the value receiver; only the path needs replacing
+	p.path = nodes
+	return p
+}
+
+// pathBuilder collects one pathInfo per leaf array while an array is being visited.
+type pathBuilder struct {
+	// path under construction for the array currently being visited
+	info pathInfo
+	// completed paths, one for each leaf found so far
+	paths []pathInfo
+	// whether the array visited next is to be treated as nullable
+	nullableInParent bool
+
+	refCount int64
+}
+
+// Retain adds one to the reference count of the builder.
+func (p *pathBuilder) Retain() {
+	atomic.AddInt64(&p.refCount, 1)
+}
+
+// Release removes one from the reference count of the builder. When the count reaches
+// zero the leaf array of every collected path is released and cleared.
+func (p *pathBuilder) Release() {
+	if atomic.AddInt64(&p.refCount, -1) != 0 {
+		return
+	}
+
+	for i := range p.paths {
+		leaf := &p.paths[i]
+		leaf.primitiveArr.Release()
+		leaf.primitiveArr = nil
+	}
+}
+
+// lazyNullCount returns the null count stored on the array's data without computing it.
+//
+// calling NullN on the arr directly will compute the nulls
+// if we have "UnknownNullCount", calling NullN on the data
+// object directly will just return the value the data has.
+// thus we might get array.UnknownNullCount as the result here.
+func lazyNullCount(arr array.Interface) int64 {
+ return int64(arr.Data().NullN())
+}
+
+// lazyNoNulls reports whether arr can be treated as free of nulls without counting them:
+// either the stored null count is zero, or the count is unknown and the array has no
+// validity bitmap.
+func lazyNoNulls(arr array.Interface) bool {
+	switch lazyNullCount(arr) {
+	case 0:
+		return true
+	case array.UnknownNullCount:
+		return arr.NullBitmapBytes() == nil
+	default:
+		return false
+	}
+}
+
+// fixupVisitor walks the nodes of a completed path in order and fills in the values
+// that depend on the position of a node relative to the list nodes of the path.
+type fixupVisitor struct {
+ maxRepLevel int
+ // repetition level handed to the null nodes visited next; -1 leaves them untouched
+ repLevelIfNull int16
+}
+
+// visit updates n in place: the list node whose repLevel equals maxRepLevel is marked as
+// the last one, and nullable / all-null nodes receive the repetition level of the closest
+// preceding list node, unless that list node was the last one.
+func (f *fixupVisitor) visit(n pathNode) {
+ switch n := n.(type) {
+ case *listNode:
+ if n.repLevel == int16(f.maxRepLevel) {
+ n.isLast = true
+ f.repLevelIfNull = -1
+ } else {
+ f.repLevelIfNull = n.repLevel
+ }
+ // nothing to fix up for the next two terminal nodes
+ case *nullableTerminalNode:
+ case *allPresentTerminalNode:
+ case *allNullsTerminalNode:
+ if f.repLevelIfNull != -1 {
+ n.repLevel = f.repLevelIfNull
+ }
+ case *nullableNode:
+ if f.repLevelIfNull != -1 {
+ n.repLevelIfNull = f.repLevelIfNull
+ }
+ }
+}
+
+// fixup finalizes the nodes of a completed path by running a fixupVisitor over them in
+// order. Paths without repeated elements need no adjustment and are returned untouched.
+// The nodes are modified in place; info itself is returned for convenience.
+func fixup(info pathInfo) pathInfo {
+	// we only need to fixup the path if there were repeated elems
+	if info.maxRepLevel == 0 {
+		return info
+	}
+
+	// maxRepLevel is non-zero from here on, so null nodes that come before the first
+	// list node always start out with repetition level 0
+	visitor := fixupVisitor{maxRepLevel: int(info.maxRepLevel), repLevelIfNull: 0}
+	for _, p := range info.path {
+		visitor.visit(p)
+	}
+	return info
+}
+
+// Visit walks arr recursively and adds one entry to p.paths for every leaf array reachable
+// from it. List, map and fixed size list arrays add a list node and continue with their
+// values, struct arrays continue with each of their fields, and every other supported type
+// ends the path. Dictionary, extension and union arrays result in an error.
+func (p *pathBuilder) Visit(arr array.Interface) error {
+ switch arr.DataType().ID() {
+ case arrow.LIST, arrow.MAP:
+ p.maybeAddNullable(arr)
+ // increment necessary due to empty lists
+ p.info.maxDefLevel++
+ p.info.maxRepLevel++
+ larr, ok := arr.(*array.List)
+ if !ok {
+ // a map array is handled through the list it embeds
+ larr = arr.(*array.Map).List
+ }
+
+ p.info.path = append(p.info.path, &listNode{
+ selector: varRangeSelector{larr.Offsets()[larr.Data().Offset():]},
+ prevRepLevel: p.info.maxRepLevel - 1,
+ repLevel: p.info.maxRepLevel,
+ defLevelIfEmpty: p.info.maxDefLevel - 1,
+ })
+ // the values of a list are treated as nullable, the entries of a map are not
+ p.nullableInParent = ok
+ return p.Visit(larr.ListValues())
+ case arrow.FIXED_SIZE_LIST:
+ p.maybeAddNullable(arr)
+ larr := arr.(*array.FixedSizeList)
+ listSize := larr.DataType().(*arrow.FixedSizeListType).Len()
+ // technically we could encode fixed sized lists with two level encodings
+ // but we always use 3 level encoding, so we increment def levels as well
+ p.info.maxDefLevel++
+ p.info.maxRepLevel++
+ p.info.path = append(p.info.path, &listNode{
+ selector: fixedSizeRangeSelector{listSize},
+ prevRepLevel: p.info.maxRepLevel - 1,
+ repLevel: p.info.maxRepLevel,
+ defLevelIfEmpty: p.info.maxDefLevel,
+ })
+ // if arr.data.offset > 0, slice?
+ return p.Visit(larr.ListValues())
+ case arrow.DICTIONARY:
+ return xerrors.New("dictionary types not implemented yet")
+ case arrow.STRUCT:
+ p.maybeAddNullable(arr)
+ // every field starts from the state reached at the struct itself
+ infoBackup := p.info
+ dt := arr.DataType().(*arrow.StructType)
+ for idx, f := range dt.Fields() {
+ p.nullableInParent = f.Nullable
+ if err := p.Visit(arr.(*array.Struct).Field(idx)); err != nil {
+ return err
+ }
+ p.info = infoBackup
+ }
+ return nil
+ case arrow.EXTENSION:
+ return xerrors.New("extension types not implemented yet")
+ case arrow.UNION:
+ return xerrors.New("union types aren't supported in parquet")
+ default:
+ p.addTerminalInfo(arr)
+ return nil
+ }
+}
+
+// addTerminalInfo completes the current path with a terminal node for the leaf array arr
+// and appends a fixed-up clone of it to p.paths. The kind of terminal node depends on
+// whether arr holds no nulls, only nulls, or a mix of both. arr is retained on behalf of
+// the stored path.
+func (p *pathBuilder) addTerminalInfo(arr array.Interface) {
+ p.info.leafIsNullable = p.nullableInParent
+ if p.nullableInParent {
+ p.info.maxDefLevel++
+ }
+
+ // we don't use null_count because if the null_count isn't known
+ // and the array does in fact contain nulls, we will end up traversing
+ // the null bitmap twice.
+ if lazyNoNulls(arr) {
+ p.info.path = append(p.info.path, &allPresentTerminalNode{p.info.maxDefLevel})
+ p.info.leafIsNullable = false
+ } else if lazyNullCount(arr) == int64(arr.Len()) {
+ // every value is null; the repetition level of -1 stays unset unless the fixup pass assigns one
+ p.info.path = append(p.info.path, &allNullsTerminalNode{p.info.maxDefLevel - 1, -1})
+ } else {
+ p.info.path = append(p.info.path, &nullableTerminalNode{bitmap: arr.NullBitmapBytes(), elemOffset: int64(arr.Data().Offset()), defLevelIfPresent: p.info.maxDefLevel, defLevelIfNull: p.info.maxDefLevel - 1})
+ }
+ arr.Retain()
+ p.info.primitiveArr = arr
+ p.paths = append(p.paths, fixup(p.info.clone()))
+}
+
+// maybeAddNullable accounts for the nullability of the nested array arr. It does nothing
+// unless the array is nullable in its parent; otherwise the max definition level goes up
+// by one and, if arr may hold nulls, a node handling them is appended to the current path.
+func (p *pathBuilder) maybeAddNullable(arr array.Interface) {
+ if !p.nullableInParent {
+ return
+ }
+
+ p.info.maxDefLevel++
+ // without nulls no node is needed, the extra definition level is enough
+ if lazyNoNulls(arr) {
+ return
+ }
+
+ if lazyNullCount(arr) == int64(arr.Len()) {
+ p.info.path = append(p.info.path, &allNullsTerminalNode{p.info.maxDefLevel - 1, -1})
+ return
+ }
+
+ // the repetition level of -1 stays unset unless the fixup pass assigns one
+ p.info.path = append(p.info.path, &nullableNode{
+ bitmap: arr.NullBitmapBytes(), entryOffset: int64(arr.Data().Offset()),
+ defLevelIfNull: p.info.maxDefLevel - 1, repLevelIfNull: -1,
+ newRange: true,
+ })
+}
+
+// multipathLevelBuilder produces the definition and repetition levels for every leaf
+// of a single (possibly nested) array.
+type multipathLevelBuilder struct {
+	// range covering all top level entries of the array
+	rootRange elemRange
+	data      *array.Data
+	builder   pathBuilder
+
+	refCount int64
+}
+
+// Retain adds one to the reference count of the builder.
+func (m *multipathLevelBuilder) Retain() {
+	atomic.AddInt64(&m.refCount, 1)
+}
+
+// Release removes one from the reference count of the builder. When the count reaches
+// zero the array data and the path builder are released and reset.
+func (m *multipathLevelBuilder) Release() {
+	if atomic.AddInt64(&m.refCount, -1) != 0 {
+		return
+	}
+
+	m.data.Release()
+	m.data = nil
+	m.builder.Release()
+	m.builder = pathBuilder{}
+}
+
+// newMultipathLevelBuilder visits arr and returns a builder holding one path for each of
+// its leaf arrays. fieldNullable states whether arr itself is to be treated as nullable.
+// The returned builder starts with a reference count of one and retains the data of arr;
+// call Release when it is no longer needed.
+//
+// If arr contains a type that cannot be visited, the error is returned together with a
+// nil builder, and the references taken on leaves visited before the failure are dropped.
+func newMultipathLevelBuilder(arr array.Interface, fieldNullable bool) (*multipathLevelBuilder, error) {
+	ret := &multipathLevelBuilder{
+		refCount:  1,
+		rootRange: elemRange{int64(0), int64(arr.Data().Len())},
+		data:      arr.Data(),
+		builder:   pathBuilder{nullableInParent: fieldNullable, paths: make([]pathInfo, 0), refCount: 1},
+	}
+	if err := ret.builder.Visit(arr); err != nil {
+		// leaves reached before the failure were already retained by the path builder;
+		// release them here since the builder is never handed to the caller
+		ret.builder.Release()
+		return nil, err
+	}
+	arr.Data().Retain()
+	return ret, nil
+}
+
+// leafCount returns the number of leaf paths collected for the array.
+func (m *multipathLevelBuilder) leafCount() int {
+ return len(m.builder.paths)
+}
+
+// write builds the levels for the leaf at leafIdx over the whole root range.
+func (m *multipathLevelBuilder) write(leafIdx int, ctx *arrowWriteContext) (multipathLevelResult, error) {
+ return writePath(m.rootRange, &m.builder.paths[leafIdx], ctx)
+}
+
+// writeAll builds the levels for every leaf in order and returns one result per leaf.
+// It stops at the first failure and returns that error together with the results
+// slice as filled in up to that point.
+func (m *multipathLevelBuilder) writeAll(ctx *arrowWriteContext) (res []multipathLevelResult, err error) {
+	res = make([]multipathLevelResult, m.leafCount())
+	for idx := 0; idx < len(res); idx++ {
+		if res[idx], err = m.write(idx, ctx); err != nil {
+			return res, err
+		}
+	}
+	return res, nil
+}
+
+// multipathLevelResult holds the levels computed for a single leaf array.
+type multipathLevelResult struct {
+ leafArr array.Interface
+ // defLevels and repLevels are views over the bytes of their respective buffers
+ defLevels []int16
+ defLevelsBuffer encoding.Buffer
+ repLevels []int16
+ repLevelsBuffer encoding.Buffer
+ // contains the element ranges of the required visiting on the descendants of the
+ // final list ancestor for any leaf node.
+ //
+ // the algorithm will attempt to consolidate the visited ranges into the smallest
+ // number of ranges possible.
+ //
+ // this data is necessary to pass along because after producing the def-rep levels for each
+ // leaf array, it is impossible to determine which values have to be sent to parquet when a
+ // null list value in a nullable listarray is non-empty
+ //
+ // this allows for the parquet writing to determine which values ultimately need to be written
+ postListVisitedElems []elemRange
+
+ leafIsNullable bool
+}
+
+// Release releases the leaf array and the level buffers held by the result and clears
+// the level slices that point into those buffers. The repetition level buffer is only
+// released when repetition levels are present.
+func (m *multipathLevelResult) Release() {
+	m.leafArr.Release()
+
+	m.defLevels = nil
+	if buf := m.defLevelsBuffer; buf != nil {
+		buf.Release()
+	}
+
+	if m.repLevels == nil {
+		return
+	}
+	m.repLevels = nil
+	m.repLevelsBuffer.Release()
+}
+
+// pathWriteCtx carries the output being accumulated while the nodes of a path run:
+// the definition levels, the repetition levels and the visited child element ranges.
+type pathWriteCtx struct {
+ mem memory.Allocator
+ defLevels *int16BufferBuilder
+ repLevels *int16BufferBuilder
+ visitedElems []elemRange
+}
+
+// ReserveDefLevels makes room for elems further definition levels. Like the other
+// helpers below it returns iterDone so it can serve directly as the result of a node.
+func (p *pathWriteCtx) ReserveDefLevels(elems int) iterResult {
+ p.defLevels.Reserve(elems)
+ return iterDone
+}
+
+// AppendDefLevel appends a single definition level.
+func (p *pathWriteCtx) AppendDefLevel(lvl int16) iterResult {
+ p.defLevels.Append(lvl)
+ return iterDone
+}
+
+// AppendDefLevels appends count copies of defLevel.
+func (p *pathWriteCtx) AppendDefLevels(count int, defLevel int16) iterResult {
+ p.defLevels.AppendCopies(count, defLevel)
+ return iterDone
+}
+
+// UnsafeAppendDefLevel appends a single definition level without reserving space first.
+func (p *pathWriteCtx) UnsafeAppendDefLevel(v int16) iterResult {
+ p.defLevels.UnsafeAppend(v)
+ return iterDone
+}
+
+// AppendRepLevel appends a single repetition level.
+func (p *pathWriteCtx) AppendRepLevel(lvl int16) iterResult {
+ p.repLevels.Append(lvl)
+ return iterDone
+}
+
+// AppendRepLevels appends count copies of lvl to the repetition levels.
+func (p *pathWriteCtx) AppendRepLevels(count int, lvl int16) iterResult {
+ p.repLevels.AppendCopies(count, lvl)
+ return iterDone
+}
+
+// equalRepDeflevlsLen reports whether as many repetition as definition levels were written.
+func (p *pathWriteCtx) equalRepDeflevlsLen() bool { return p.defLevels.Len() == p.repLevels.Len() }
+
+// recordPostListVisit remembers rng as visited. A range that starts exactly where the
+// most recently recorded one ends is merged into it; any other range is appended.
+func (p *pathWriteCtx) recordPostListVisit(rng elemRange) {
+	if last := len(p.visitedElems) - 1; last >= 0 && p.visitedElems[last].end == rng.start {
+		p.visitedElems[last].end = rng.end
+		return
+	}
+	p.visitedElems = append(p.visitedElems, rng)
+}
+
+// int16BufferBuilder wraps a pooled buffer writer so that it can be used in units of
+// int16 values rather than bytes.
+type int16BufferBuilder struct {
+ *encoding.PooledBufferWriter
+}
+
+// Values returns the bytes written so far reinterpreted as a slice of int16.
+func (b *int16BufferBuilder) Values() []int16 {
+ return arrow.Int16Traits.CastFromBytes(b.PooledBufferWriter.Bytes())
+}
+
+// Value returns the i-th int16 written so far.
+func (b *int16BufferBuilder) Value(i int) int16 {
+ return b.Values()[i]
+}
+
+// Reserve makes room for n further int16 values.
+func (b *int16BufferBuilder) Reserve(n int) {
+ b.PooledBufferWriter.Reserve(n * arrow.Int16SizeBytes)
+}
+
+// Len returns the number of int16 values written so far.
+func (b *int16BufferBuilder) Len() int { return b.PooledBufferWriter.Len() / arrow.Int16SizeBytes }
+
+// AppendCopies reserves room for count values and writes val that many times.
+func (b *int16BufferBuilder) AppendCopies(count int, val int16) {
+ b.Reserve(count)
+ // val is passed on as the two raw bytes of its in-memory representation
+ b.UnsafeWriteCopy(count, (*(*[2]byte)(unsafe.Pointer(&val)))[:])
+}
+
+// UnsafeAppend writes v without reserving space for it first.
+func (b *int16BufferBuilder) UnsafeAppend(v int16) {
+ b.PooledBufferWriter.UnsafeWrite((*(*[2]byte)(unsafe.Pointer(&v)))[:])
+}
+
+// Append reserves room for one value and writes v.
+func (b *int16BufferBuilder) Append(v int16) {
+ b.PooledBufferWriter.Reserve(arrow.Int16SizeBytes)
+ b.PooledBufferWriter.Write((*(*[2]byte)(unsafe.Pointer(&v)))[:])
+}
+
+// fillRepLevels appends repLvl to the repetition levels for count elements. A repLvl
+// of -1 means no repetition level is set, in which case nothing is written. When a
+// repetition level has already been written for the first of the elements (there are
+// more repetition than definition levels or vice versa), one fewer copy is appended.
+// Nothing is written when no copies remain to be appended.
+func fillRepLevels(count int, repLvl int16, ctx *pathWriteCtx) {
+	if repLvl == -1 {
+		return
+	}
+
+	fillCount := count
+	// this condition occurs (rep and def levels equals), in one of a few cases:
+	// 1. before any list is encountered
+	// 2. after rep-level has been filled in due to null/empty values above
+	// 3. after finishing a list
+	if !ctx.equalRepDeflevlsLen() {
+		fillCount--
+	}
+
+	// a count of zero combined with unequal lengths would otherwise ask for -1 copies
+	if fillCount <= 0 {
+		return
+	}
+	ctx.AppendRepLevels(fillCount, repLvl)
+}
+
+// writePath runs the nodes of info over rootRange and returns the definition levels,
+// repetition levels and visited element ranges for the leaf array of that path.
+//
+// A path with a max definition level of zero produces no levels at all, only a single
+// visited range covering the whole leaf array. Any level buffers still held by arrCtx
+// are released before the new levels are built.
+func writePath(rootRange elemRange, info *pathInfo, arrCtx *arrowWriteContext) (multipathLevelResult, error) {
+ // one range per node: the entry at position i is the range node i still has to process
+ stack := make([]elemRange, len(info.path))
+ buildResult := multipathLevelResult{
+ leafArr: info.primitiveArr,
+ leafIsNullable: info.leafIsNullable,
+ }
+
+ if info.maxDefLevel == 0 {
+ // this case only occurs when there are no nullable or repeated columns in the path from the root to the leaf
+ leafLen := buildResult.leafArr.Len()
+ buildResult.postListVisitedElems = []elemRange{{0, int64(leafLen)}}
+ return buildResult, nil
+ }
+
+ stack[0] = rootRange
+ if arrCtx.defLevelsBuffer != nil {
+ arrCtx.defLevelsBuffer.Release()
+ arrCtx.defLevelsBuffer = nil
+ }
+ if arrCtx.repLevelsBuffer != nil {
+ arrCtx.repLevelsBuffer.Release()
+ arrCtx.repLevelsBuffer = nil
+ }
+
+ ctx := pathWriteCtx{arrCtx.props.mem,
+ &int16BufferBuilder{encoding.NewPooledBufferWriter(0)},
+ &int16BufferBuilder{encoding.NewPooledBufferWriter(0)},
+ make([]elemRange, 0)}
+
+ ctx.defLevels.Reserve(int(rootRange.size()))
+ if info.maxRepLevel > 0 {
+ ctx.repLevels.Reserve(int(rootRange.size()))
+ }
+
+ // each node returns iterNext (+1) to descend to the following node or iterDone (-1)
+ // to go back to the previous one; the walk ends when the first node is done
+ stackBase := 0
+ stackPos := stackBase
+ for stackPos >= stackBase {
+ var res iterResult
+ switch n := info.path[stackPos].(type) {
+ case *nullableNode:
+ res = n.run(&stack[stackPos], &stack[stackPos+1], &ctx)
+ case *listNode:
+ res = n.run(&stack[stackPos], &stack[stackPos+1], &ctx)
+ case *nullableTerminalNode:
+ res = n.run(stack[stackPos], &ctx)
+ case *allPresentTerminalNode:
+ res = n.run(stack[stackPos], &ctx)
+ case *allNullsTerminalNode:
+ res = n.run(stack[stackPos], &ctx)
+ }
+ stackPos += int(res)
+ }
+
+ if ctx.repLevels.Len() > 0 {
+ // this case only occurs when there was a repeated element somewhere
+ buildResult.repLevels = ctx.repLevels.Values()
+ buildResult.repLevelsBuffer = ctx.repLevels.Finish()
+
+ buildResult.postListVisitedElems, ctx.visitedElems = ctx.visitedElems, buildResult.postListVisitedElems
+ // it is possible when processing lists that all lists were empty. in this
+ // case, no elements would have been added to the postListVisitedElements. by
+ // adding an empty element, we avoid special casing later
+ if len(buildResult.postListVisitedElems) == 0 {
+ buildResult.postListVisitedElems = append(buildResult.postListVisitedElems, elemRange{0, 0})
+ }
+ } else {
+ // no repetition levels were written, so the whole leaf array counts as visited
+ buildResult.postListVisitedElems = append(buildResult.postListVisitedElems, elemRange{0, int64(buildResult.leafArr.Len())})
+ buildResult.repLevels = nil
+ }
+
+ buildResult.defLevels = ctx.defLevels.Values()
+ buildResult.defLevelsBuffer = ctx.defLevels.Finish()
+ return buildResult, nil
+}
diff --git a/go/parquet/pqarrow/path_builder_test.go b/go/parquet/pqarrow/path_builder_test.go
new file mode 100644
index 0000000..1848f79
--- /dev/null
+++ b/go/parquet/pqarrow/path_builder_test.go
@@ -0,0 +1,628 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow
+
+import (
+ "context"
+ "testing"
+
+ "github.com/apache/arrow/go/arrow"
+ "github.com/apache/arrow/go/arrow/array"
+ "github.com/apache/arrow/go/arrow/memory"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestNonNullableSingleList checks the levels produced for a non-nullable list of
+// int64 in which every list has at least one element and no element is null.
+func TestNonNullableSingleList(t *testing.T) {
+ // the list itself is passed in as non-nullable, but the path builder treats the
+ // elements of a list as nullable, so the levels correspond to this parquet schema:
+ // required group bag {
+ // repeated group [unseen] (List) {
+ // optional int64 Entries;
+ // }
+ // }
+ // So:
+ // def level 0: an empty list
+ // def level 1: a null entry
+ // def level 2: a non-null entry
+ bldr := array.NewListBuilder(memory.DefaultAllocator, arrow.PrimitiveTypes.Int64)
+ defer bldr.Release()
+
+ vb := bldr.ValueBuilder().(*array.Int64Builder)
+
+ // produce: [[1], [2, 3], [4, 5, 6]]
+ bldr.Append(true)
+ vb.Append(1)
+
+ bldr.Append(true)
+ vb.Append(2)
+ vb.Append(3)
+
+ bldr.Append(true)
+ vb.Append(4)
+ vb.Append(5)
+ vb.Append(6)
+
+ arr := bldr.NewListArray()
+ defer arr.Release()
+
+ mp, err := newMultipathLevelBuilder(arr, false)
+ require.NoError(t, err)
+ defer mp.Release()
+
+ ctx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+ result, err := mp.write(0, ctx)
+ require.NoError(t, err)
+
+ assert.Equal(t, []int16{2, 2, 2, 2, 2, 2}, result.defLevels)
+ assert.Equal(t, []int16{0, 0, 1, 0, 1, 1}, result.repLevels)
+ assert.Len(t, result.postListVisitedElems, 1)
+ assert.EqualValues(t, 0, result.postListVisitedElems[0].start)
+ assert.EqualValues(t, 6, result.postListVisitedElems[0].end)
+}
+
+// next group of tests translates to the following parquet schema:
+// optional group bag {
+// repeated group [unseen] (List) {
+// optional int64 Entries;
+// }
+// }
+// So:
+// def level 0: a null list
+// def level 1: an empty list
+// def level 2: a null entry
+// def level 3: a non-null entry
+
+// TestNullableSingleListAllNulls checks that a nullable list column holding only null
+// lists produces definition and repetition level 0 for every entry.
+func TestNullableSingleListAllNulls(t *testing.T) {
+ bldr := array.NewListBuilder(memory.DefaultAllocator, arrow.PrimitiveTypes.Int64)
+ defer bldr.Release()
+
+ // produce: [null, null, null, null]
+ bldr.AppendNull()
+ bldr.AppendNull()
+ bldr.AppendNull()
+ bldr.AppendNull()
+
+ arr := bldr.NewListArray()
+ defer arr.Release()
+
+ mp, err := newMultipathLevelBuilder(arr, true)
+ require.NoError(t, err)
+ defer mp.Release()
+
+ ctx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+ result, err := mp.write(0, ctx)
+ require.NoError(t, err)
+
+ assert.Equal(t, []int16{0, 0, 0, 0}, result.defLevels)
+ assert.Equal(t, []int16{0, 0, 0, 0}, result.repLevels)
+}
+
+// TestNullableSingleListAllEmpty checks that a nullable list column holding only empty
+// lists produces definition level 1 and repetition level 0 for every entry.
+func TestNullableSingleListAllEmpty(t *testing.T) {
+ bldr := array.NewListBuilder(memory.DefaultAllocator, arrow.PrimitiveTypes.Int64)
+ defer bldr.Release()
+
+ // produce: [[], [], [], []]
+ bldr.Append(true)
+ bldr.Append(true)
+ bldr.Append(true)
+ bldr.Append(true)
+
+ arr := bldr.NewListArray()
+ defer arr.Release()
+
+ mp, err := newMultipathLevelBuilder(arr, true)
+ require.NoError(t, err)
+ defer mp.Release()
+
+ ctx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+ result, err := mp.write(0, ctx)
+ require.NoError(t, err)
+
+ assert.Equal(t, []int16{1, 1, 1, 1}, result.defLevels)
+ assert.Equal(t, []int16{0, 0, 0, 0}, result.repLevels)
+}
+
+// TestNullableSingleListAllNullEntries checks the levels for a nullable list column in
+// which every list holds exactly one null entry.
+func TestNullableSingleListAllNullEntries(t *testing.T) {
+ bldr := array.NewListBuilder(memory.DefaultAllocator, arrow.PrimitiveTypes.Int64)
+ defer bldr.Release()
+
+ vb := bldr.ValueBuilder().(*array.Int64Builder)
+
+ // produce: [[null], [null], [null], [null]]
+ bldr.Append(true)
+ vb.AppendNull()
+ bldr.Append(true)
+ vb.AppendNull()
+ bldr.Append(true)
+ vb.AppendNull()
+ bldr.Append(true)
+ vb.AppendNull()
+
+ arr := bldr.NewListArray()
+ defer arr.Release()
+
+ mp, err := newMultipathLevelBuilder(arr, true)
+ require.NoError(t, err)
+ defer mp.Release()
+
+ ctx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+ result, err := mp.write(0, ctx)
+ require.NoError(t, err)
+
+ assert.Equal(t, []int16{2, 2, 2, 2}, result.defLevels)
+ assert.Equal(t, []int16{0, 0, 0, 0}, result.repLevels)
+ assert.Len(t, result.postListVisitedElems, 1)
+ assert.EqualValues(t, 0, result.postListVisitedElems[0].start)
+ assert.EqualValues(t, 4, result.postListVisitedElems[0].end)
+}
+
+// TestNullableSingleListAllPresentEntries checks the levels for a nullable list column
+// that mixes empty lists with lists whose entries are all non-null.
+func TestNullableSingleListAllPresentEntries(t *testing.T) {
+ bldr := array.NewListBuilder(memory.DefaultAllocator, arrow.PrimitiveTypes.Int64)
+ defer bldr.Release()
+
+ vb := bldr.ValueBuilder().(*array.Int64Builder)
+
+ // produce: [[], [], [1], [], [2, 3]]
+ bldr.Append(true)
+ bldr.Append(true)
+ bldr.Append(true)
+ vb.Append(1)
+ bldr.Append(true)
+ bldr.Append(true)
+ vb.Append(2)
+ vb.Append(3)
+
+ arr := bldr.NewListArray()
+ defer arr.Release()
+
+ mp, err := newMultipathLevelBuilder(arr, true)
+ require.NoError(t, err)
+ defer mp.Release()
+
+ ctx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+ result, err := mp.write(0, ctx)
+ require.NoError(t, err)
+
+ assert.Equal(t, []int16{1, 1, 3, 1, 3, 3}, result.defLevels)
+ assert.Equal(t, []int16{0, 0, 0, 0, 0, 1}, result.repLevels)
+ assert.Len(t, result.postListVisitedElems, 1)
+ assert.EqualValues(t, 0, result.postListVisitedElems[0].start)
+ assert.EqualValues(t, 3, result.postListVisitedElems[0].end)
+}
+
+// TestNullableSingleListSomeNullEntriesSomeNullLists checks the levels for a nullable
+// list column that mixes null lists, empty lists, populated lists and a null entry.
+func TestNullableSingleListSomeNullEntriesSomeNullLists(t *testing.T) {
+ bldr := array.NewListBuilder(memory.DefaultAllocator, arrow.PrimitiveTypes.Int64)
+ defer bldr.Release()
+
+ vb := bldr.ValueBuilder().(*array.Int64Builder)
+
+ // produce: [null, [1, 2, 3], [], [], null, null, [4, 5], [null]]
+ bldr.Append(false)
+ bldr.Append(true)
+ vb.AppendValues([]int64{1, 2, 3}, nil)
+ bldr.Append(true)
+ bldr.Append(true)
+ bldr.AppendNull()
+ bldr.AppendNull()
+ bldr.Append(true)
+ vb.AppendValues([]int64{4, 5}, nil)
+ bldr.Append(true)
+ vb.AppendNull()
+
+ arr := bldr.NewListArray()
+ defer arr.Release()
+
+ mp, err := newMultipathLevelBuilder(arr, true)
+ require.NoError(t, err)
+ defer mp.Release()
+
+ ctx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+ result, err := mp.write(0, ctx)
+ require.NoError(t, err)
+
+ assert.Equal(t, []int16{0, 3, 3, 3, 1, 1, 0, 0, 3, 3, 2}, result.defLevels)
+ assert.Equal(t, []int16{0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0}, result.repLevels)
+}
+
+// next group of tests translate to the following parquet schema:
+//
+// optional group bag {
+// repeated group outer_list (List) {
+// optional group nullable {
+// repeated group inner_list (List) {
+// optional int64 Entries;
+// }
+// }
+// }
+// }
+// So:
+// def level 0: null outer list
+// def level 1: empty outer list
+// def level 2: null inner list
+// def level 3: empty inner list
+// def level 4: null entry
+// def level 5: non-null entry
+
+// TestNestedListsWithSomeEntries checks the levels for a list of lists that contains a
+// null outer list, populated inner lists, empty inner lists and an empty outer list.
+func TestNestedListsWithSomeEntries(t *testing.T) {
+ listType := arrow.ListOf(arrow.PrimitiveTypes.Int64)
+ bldr := array.NewListBuilder(memory.DefaultAllocator, listType)
+ defer bldr.Release()
+
+ nestedBldr := bldr.ValueBuilder().(*array.ListBuilder)
+ vb := nestedBldr.ValueBuilder().(*array.Int64Builder)
+
+ // produce: [null, [[1, 2, 3], [4, 5]], [[], [], []], []]
+
+ bldr.AppendNull()
+ bldr.Append(true)
+ nestedBldr.Append(true)
+ vb.AppendValues([]int64{1, 2, 3}, nil)
+ nestedBldr.Append(true)
+ vb.AppendValues([]int64{4, 5}, nil)
+
+ bldr.Append(true)
+ nestedBldr.Append(true)
+ nestedBldr.Append(true)
+ nestedBldr.Append(true)
+ bldr.Append(true)
+
+ arr := bldr.NewListArray()
+ defer arr.Release()
+
+ mp, err := newMultipathLevelBuilder(arr, true)
+ require.NoError(t, err)
+ defer mp.Release()
+
+ ctx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+ result, err := mp.write(0, ctx)
+ require.NoError(t, err)
+
+ assert.Equal(t, []int16{0, 5, 5, 5, 5, 5, 3, 3, 3, 1}, result.defLevels)
+ assert.Equal(t, []int16{0, 0, 2, 2, 1, 2, 0, 1, 1, 0}, result.repLevels)
+}
+
+// TestNestedListsWithSomeNulls checks the levels for a list of lists that contains a
+// null outer list, null inner lists and a null entry.
+func TestNestedListsWithSomeNulls(t *testing.T) {
+ listType := arrow.ListOf(arrow.PrimitiveTypes.Int64)
+ bldr := array.NewListBuilder(memory.DefaultAllocator, listType)
+ defer bldr.Release()
+
+ nestedBldr := bldr.ValueBuilder().(*array.ListBuilder)
+ vb := nestedBldr.ValueBuilder().(*array.Int64Builder)
+
+ // produce: [null, [[1, null, 3], null, null], [[4, 5]]]
+
+ bldr.AppendNull()
+ bldr.Append(true)
+ nestedBldr.Append(true)
+ vb.AppendValues([]int64{1, 0, 3}, []bool{true, false, true})
+ nestedBldr.AppendNull()
+ nestedBldr.AppendNull()
+ bldr.Append(true)
+ nestedBldr.Append(true)
+ vb.AppendValues([]int64{4, 5}, nil)
+
+ arr := bldr.NewListArray()
+ defer arr.Release()
+
+ mp, err := newMultipathLevelBuilder(arr, true)
+ require.NoError(t, err)
+ defer mp.Release()
+
+ ctx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+ result, err := mp.write(0, ctx)
+ require.NoError(t, err)
+
+ assert.Equal(t, []int16{0, 5, 4, 5, 2, 2, 5, 5}, result.defLevels)
+ assert.Equal(t, []int16{0, 0, 2, 2, 1, 1, 0, 2}, result.repLevels)
+}
+
+// TestNestedListsSomeNullsSomeEmpty checks the levels for a list of lists that contains
+// a null outer list, empty inner lists and a null entry.
+func TestNestedListsSomeNullsSomeEmpty(t *testing.T) {
+ listType := arrow.ListOf(arrow.PrimitiveTypes.Int64)
+ bldr := array.NewListBuilder(memory.DefaultAllocator, listType)
+ defer bldr.Release()
+
+ nestedBldr := bldr.ValueBuilder().(*array.ListBuilder)
+ vb := nestedBldr.ValueBuilder().(*array.Int64Builder)
+
+ // produce: [null, [[1, null, 3], [], []], [[4, 5]]]
+
+ bldr.AppendNull()
+ bldr.Append(true)
+ nestedBldr.Append(true)
+ vb.AppendValues([]int64{1, 0, 3}, []bool{true, false, true})
+ nestedBldr.Append(true)
+ nestedBldr.Append(true)
+ bldr.Append(true)
+ nestedBldr.Append(true)
+ vb.AppendValues([]int64{4, 5}, nil)
+
+ arr := bldr.NewListArray()
+ defer arr.Release()
+
+ mp, err := newMultipathLevelBuilder(arr, true)
+ require.NoError(t, err)
+ defer mp.Release()
+
+ ctx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+ result, err := mp.write(0, ctx)
+ require.NoError(t, err)
+
+ assert.Equal(t, []int16{0, 5, 4, 5, 3, 3, 5, 5}, result.defLevels)
+ assert.Equal(t, []int16{0, 0, 2, 2, 1, 1, 0, 2}, result.repLevels)
+}
+
+// triplenested translates to parquet:
+//
+// optional group bag {
+// repeated group outer_list (List) {
+// optional group nullable {
+// repeated group middle_list (List) {
+// optional group nullable {
+// repeated group inner_list (List) {
+// optional int64 Entries;
+// }
+// }
+// }
+// }
+// }
+// }
+// So:
+// def level 0: a outer list
+// def level 1: an empty outer list
+// def level 2: a null middle list
+// def level 3: an empty middle list
+// def level 4: an null inner list
+// def level 5: an empty inner list
+// def level 6: a null entry
+// def level 7: a non-null entry
+
+// TestTripleNestedAllPresent checks the levels produced for a
+// list<list<list<int64>>> array in which every list and every value is
+// present.
+func TestTripleNestedAllPresent(t *testing.T) {
+	middleType := arrow.ListOf(arrow.ListOf(arrow.PrimitiveTypes.Int64))
+	outer := array.NewListBuilder(memory.DefaultAllocator, middleType)
+	defer outer.Release()
+
+	middle := outer.ValueBuilder().(*array.ListBuilder)
+	inner := middle.ValueBuilder().(*array.ListBuilder)
+	leaves := inner.ValueBuilder().(*array.Int64Builder)
+
+	// produce: [ [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]] ]
+	outer.Append(true)
+
+	// first middle list: [[1, 2, 3], [4, 5, 6]]
+	middle.Append(true)
+	inner.Append(true)
+	leaves.AppendValues([]int64{1, 2, 3}, nil)
+	inner.Append(true)
+	leaves.AppendValues([]int64{4, 5, 6}, nil)
+
+	// second middle list: [[7, 8, 9]]
+	middle.Append(true)
+	inner.Append(true)
+	leaves.AppendValues([]int64{7, 8, 9}, nil)
+
+	listArr := outer.NewListArray()
+	defer listArr.Release()
+
+	levelBldr, err := newMultipathLevelBuilder(listArr, true)
+	require.NoError(t, err)
+	defer levelBldr.Release()
+
+	writeCtx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+	got, err := levelBldr.write(0, writeCtx)
+	require.NoError(t, err)
+
+	wantDef := []int16{7, 7, 7, 7, 7, 7, 7, 7, 7}
+	wantRep := []int16{0, 3, 3, 2, 3, 3, 1, 3, 3}
+	assert.Equal(t, wantDef, got.defLevels)
+	assert.Equal(t, wantRep, got.repLevels)
+}
+
+// TestTripleNestedSomeNullsSomeEmpty checks the levels produced for a
+// list<list<list<int64>>> array containing null lists, empty lists and a
+// null leaf value at several nesting depths.
+func TestTripleNestedSomeNullsSomeEmpty(t *testing.T) {
+	middleType := arrow.ListOf(arrow.ListOf(arrow.PrimitiveTypes.Int64))
+	outer := array.NewListBuilder(memory.DefaultAllocator, middleType)
+	defer outer.Release()
+
+	middle := outer.ValueBuilder().(*array.ListBuilder)
+	inner := middle.ValueBuilder().(*array.ListBuilder)
+	leaves := inner.ValueBuilder().(*array.Int64Builder)
+
+	// produce: [
+	//   [null, [[1, null, 3], []], []],     first row
+	//   [[[]], [[], [1, 2]], null, [[3]]],  second row
+	//   null,                               third row
+	//   []                                  fourth row
+	// ]
+
+	// first row
+	outer.Append(true)
+	middle.AppendNull()
+	middle.Append(true)
+	inner.Append(true)
+	leaves.AppendValues([]int64{1, 0, 3}, []bool{true, false, true})
+	inner.Append(true)
+	middle.Append(true)
+
+	// second row
+	outer.Append(true)
+	middle.Append(true)
+	inner.Append(true)
+	middle.Append(true)
+	inner.Append(true)
+	inner.Append(true)
+	leaves.AppendValues([]int64{1, 2}, nil)
+	middle.AppendNull()
+	middle.Append(true)
+	inner.Append(true)
+	leaves.Append(3)
+
+	// third row
+	outer.AppendNull()
+
+	// fourth row
+	outer.Append(true)
+
+	listArr := outer.NewListArray()
+	defer listArr.Release()
+
+	levelBldr, err := newMultipathLevelBuilder(listArr, true)
+	require.NoError(t, err)
+	defer levelBldr.Release()
+
+	writeCtx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+	got, err := levelBldr.write(0, writeCtx)
+	require.NoError(t, err)
+
+	wantDef := []int16{
+		2, 7, 6, 7, 5, 3, // first row
+		5, 5, 7, 7, 2, 7, // second row
+		0, // third row
+		1, // fourth row
+	}
+	wantRep := []int16{
+		0, 1, 3, 3, 2, 1, // first row
+		0, 1, 2, 3, 1, 1, // second row
+		0, 0, // third and fourth rows
+	}
+	assert.Equal(t, wantDef, got.defLevels)
+	assert.Equal(t, wantRep, got.repLevels)
+}
+
+// TestStruct checks the levels produced for each leaf of a nullable struct
+// with a nullable list<int64> field ("list") and a nullable int64 field
+// ("Entries"), where the last struct value is null.
+func TestStruct(t *testing.T) {
+	listField := arrow.Field{Name: "list", Type: arrow.ListOf(arrow.PrimitiveTypes.Int64), Nullable: true}
+	entriesField := arrow.Field{Name: "Entries", Type: arrow.PrimitiveTypes.Int64, Nullable: true}
+
+	structBldr := array.NewStructBuilder(memory.DefaultAllocator, arrow.StructOf(listField, entriesField))
+	defer structBldr.Release()
+
+	entries := structBldr.FieldBuilder(1).(*array.Int64Builder)
+	lists := structBldr.FieldBuilder(0).(*array.ListBuilder)
+	listValues := lists.ValueBuilder().(*array.Int64Builder)
+
+	// produce: [ {"Entries": 1, "list": [2, 3]}, {"Entries": 4, "list": [5, 6]}, null]
+
+	structBldr.Append(true)
+	entries.Append(1)
+	lists.Append(true)
+	listValues.AppendValues([]int64{2, 3}, nil)
+
+	structBldr.Append(true)
+	entries.Append(4)
+	lists.Append(true)
+	listValues.AppendValues([]int64{5, 6}, nil)
+
+	structBldr.AppendNull()
+
+	structArr := structBldr.NewArray()
+	defer structArr.Release()
+
+	levelBldr, err := newMultipathLevelBuilder(structArr, true)
+	require.NoError(t, err)
+	defer levelBldr.Release()
+
+	writeCtx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+	leafResults, err := levelBldr.writeAll(writeCtx)
+	require.NoError(t, err)
+
+	assert.Len(t, leafResults, 2)
+
+	// first leaf: the values of the "list" field
+	assert.Equal(t, []int16{4, 4, 4, 4, 0}, leafResults[0].defLevels)
+	assert.Equal(t, []int16{0, 1, 0, 1, 0}, leafResults[0].repLevels)
+
+	// second leaf: the "Entries" field, which has no repetition levels
+	assert.Equal(t, []int16{2, 2, 0}, leafResults[1].defLevels)
+	assert.Nil(t, leafResults[1].repLevels)
+}
+
+// TestFixedSizeListNullableElems checks the levels and the visited element
+// range produced for a fixed size list (width 2) whose first and last slots
+// are null.
+func TestFixedSizeListNullableElems(t *testing.T) {
+	fslBldr := array.NewFixedSizeListBuilder(memory.DefaultAllocator, 2, arrow.PrimitiveTypes.Int64)
+	defer fslBldr.Release()
+
+	// produce: [null, [2, 3], [4, 5], null]
+	elems := fslBldr.ValueBuilder().(*array.Int64Builder)
+	fslBldr.AppendValues([]bool{false, true, true, false})
+	elems.AppendValues([]int64{2, 3, 4, 5}, nil)
+
+	fslArr := fslBldr.NewArray()
+	defer fslArr.Release()
+
+	levelBldr, err := newMultipathLevelBuilder(fslArr, true)
+	require.NoError(t, err)
+	defer levelBldr.Release()
+
+	writeCtx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+	leafResults, err := levelBldr.writeAll(writeCtx)
+	require.NoError(t, err)
+
+	assert.Len(t, leafResults, 1)
+	leaf := leafResults[0]
+	assert.Equal(t, []int16{0, 3, 3, 3, 3, 0}, leaf.defLevels)
+	assert.Equal(t, []int16{0, 0, 1, 0, 1, 0}, leaf.repLevels)
+
+	// null slots take up space in a fixed size list (they can in variable
+	// size lists as well) but the actual written values are only the middle
+	// elements
+	assert.Len(t, leaf.postListVisitedElems, 1)
+	visited := leaf.postListVisitedElems[0]
+	assert.EqualValues(t, 2, visited.start)
+	assert.EqualValues(t, 6, visited.end)
+}
+
+// TestFixedSizeListMissingMiddleTwoVisitedRanges checks that a null slot in
+// the middle of a fixed size list (width 2) splits the visited elements into
+// two separate ranges.
+func TestFixedSizeListMissingMiddleTwoVisitedRanges(t *testing.T) {
+	fslBldr := array.NewFixedSizeListBuilder(memory.DefaultAllocator, 2, arrow.PrimitiveTypes.Int64)
+	defer fslBldr.Release()
+
+	// produce: [[0, 1], null, [2, 3]]
+	elems := fslBldr.ValueBuilder().(*array.Int64Builder)
+	fslBldr.AppendValues([]bool{true, false, true})
+	elems.AppendValues([]int64{0, 1, 2, 3}, nil)
+
+	fslArr := fslBldr.NewArray()
+	defer fslArr.Release()
+
+	levelBldr, err := newMultipathLevelBuilder(fslArr, true)
+	require.NoError(t, err)
+	defer levelBldr.Release()
+
+	writeCtx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+	leafResults, err := levelBldr.writeAll(writeCtx)
+	require.NoError(t, err)
+
+	assert.Len(t, leafResults, 1)
+	leaf := leafResults[0]
+	assert.Equal(t, []int16{3, 3, 0, 3, 3}, leaf.defLevels)
+	assert.Equal(t, []int16{0, 1, 0, 0, 1}, leaf.repLevels)
+
+	// the null slot in the middle is skipped, so two visited ranges are
+	// expected: one before it and one after it
+	assert.Len(t, leaf.postListVisitedElems, 2)
+
+	before := leaf.postListVisitedElems[0]
+	assert.EqualValues(t, 0, before.start)
+	assert.EqualValues(t, 2, before.end)
+
+	after := leaf.postListVisitedElems[1]
+	assert.EqualValues(t, 4, after.start)
+	assert.EqualValues(t, 6, after.end)
+}
+
+// TestPrimitiveNonNullable checks that a non-nullable int64 array produces
+// neither definition nor repetition levels and a single visited range
+// covering all four values.
+func TestPrimitiveNonNullable(t *testing.T) {
+	intBldr := array.NewInt64Builder(memory.DefaultAllocator)
+	defer intBldr.Release()
+
+	intBldr.AppendValues([]int64{1, 2, 3, 4}, nil)
+
+	intArr := intBldr.NewArray()
+	defer intArr.Release()
+
+	levelBldr, err := newMultipathLevelBuilder(intArr, false)
+	require.NoError(t, err)
+	defer levelBldr.Release()
+
+	writeCtx := arrowCtxFromContext(NewArrowWriteContext(context.Background(), nil))
+	got, err := levelBldr.write(0, writeCtx)
+	require.NoError(t, err)
+
+	assert.Nil(t, got.defLevels)
+	assert.Nil(t, got.repLevels)
+
+	assert.Len(t, got.postListVisitedElems, 1)
+	visited := got.postListVisitedElems[0]
+	assert.EqualValues(t, 0, visited.start)
+	assert.EqualValues(t, 4, visited.end)
+}
diff --git a/go/parquet/pqarrow/properties.go b/go/parquet/pqarrow/properties.go
new file mode 100644
index 0000000..fbdc79f
--- /dev/null
+++ b/go/parquet/pqarrow/properties.go
@@ -0,0 +1,171 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow
+
+import (
+ "context"
+
+ "github.com/apache/arrow/go/arrow"
+ "github.com/apache/arrow/go/arrow/memory"
+ "github.com/apache/arrow/go/parquet/internal/encoding"
+)
+
+// ArrowWriterProperties are used to determine how to manipulate the arrow data
+// when writing it to a parquet file.
+//
+// All fields are unexported; build a value with NewArrowWriterProperties and
+// the With* options, or start from DefaultWriterProps.
+type ArrowWriterProperties struct {
+ mem memory.Allocator // allocator for buffers and memory; see WithAllocator
+ timestampAsInt96 bool // write timestamps as int96 columns; see WithDeprecatedInt96Timestamps
+ coerceTimestamps bool // true once WithCoerceTimestamps has been applied
+ coerceTimestampUnit arrow.TimeUnit // target unit used when coerceTimestamps is set
+ allowTruncatedTimestamps bool // suppress the error on lossy timestamp coercion; see WithTruncatedTimestamps
+ storeSchema bool // store the serialized arrow schema in the file metadata; see WithStoreSchema
+ noMapLogicalType bool // set by WithNoMapLogicalType
+ // compliantNestedTypes bool
+}
+
+// DefaultWriterProps returns the baseline properties for the arrow writer:
+// memory comes from memory.DefaultAllocator and coerceTimestampUnit is
+// arrow.Second. Every other setting is left at its zero value.
+func DefaultWriterProps() ArrowWriterProperties {
+	var props ArrowWriterProperties
+	props.mem = memory.DefaultAllocator
+	props.coerceTimestampUnit = arrow.Second
+	return props
+}
+
+// config is the mutable holder that WriterOption functions modify while
+// NewArrowWriterProperties assembles a set of properties.
+type config struct {
+ props ArrowWriterProperties
+}
+
+// WriterOption is a convenience for building up arrow writer properties. Each
+// option mutates the config passed to it by NewArrowWriterProperties.
+type WriterOption func(*config)
+
+// NewArrowWriterProperties builds a writer properties object by starting from
+// DefaultWriterProps and applying the given options in order. Once created,
+// an individual instance of ArrowWriterProperties is immutable.
+func NewArrowWriterProperties(opts ...WriterOption) ArrowWriterProperties {
+	cfg := &config{props: DefaultWriterProps()}
+	for i := range opts {
+		opts[i](cfg)
+	}
+	return cfg.props
+}
+
+// WithAllocator sets the allocator the writer uses whenever it allocates
+// buffers and memory.
+func WithAllocator(mem memory.Allocator) WriterOption {
+	setAllocator := func(cfg *config) { cfg.props.mem = mem }
+	return setAllocator
+}
+
+// WithDeprecatedInt96Timestamps enables or disables conversion of arrow
+// timestamps to int96 columns when constructing the schema. Since int96 is
+// the impala standard, it's technically deprecated in terms of parquet files
+// but is sometimes needed.
+func WithDeprecatedInt96Timestamps(enabled bool) WriterOption {
+	setInt96 := func(cfg *config) { cfg.props.timestampAsInt96 = enabled }
+	return setInt96
+}
+
+// WithCoerceTimestamps turns on coercion of timestamp units to the given
+// unit when constructing the schema and writing data, so that regardless of
+// the unit used by the datatypes being written, they will be converted to
+// the desired time unit.
+func WithCoerceTimestamps(unit arrow.TimeUnit) WriterOption {
+	setCoercion := func(cfg *config) {
+		props := &cfg.props
+		props.coerceTimestamps = true
+		props.coerceTimestampUnit = unit
+	}
+	return setCoercion
+}
+
+// WithTruncatedTimestamps, when called with true, turns off the error that
+// would otherwise be returned if coercing a timestamp unit loses data, such
+// as converting from nanoseconds to seconds.
+func WithTruncatedTimestamps(allow bool) WriterOption {
+	setTruncation := func(cfg *config) { cfg.props.allowTruncatedTimestamps = allow }
+	return setTruncation
+}
+
+// WithStoreSchema enables writing a binary serialized arrow schema to the
+// file metadata, which allows certain read options (like "read_dictionary")
+// to be set automatically.
+//
+// If called, the arrow schema is serialized and base64 encoded before being
+// added to the metadata of the parquet file with the key "ARROW:schema". If
+// the key exists when opening a file for read with pqarrow.FileReader, the
+// schema will be used to choose types and options when constructing the
+// arrow schema of the resulting data.
+func WithStoreSchema() WriterOption {
+	enableStoreSchema := func(cfg *config) { cfg.props.storeSchema = true }
+	return enableStoreSchema
+}
+
+// WithNoMapLogicalType sets the noMapLogicalType writer property to true.
+//
+// NOTE(review): presumably this makes the writer omit the MAP logical type
+// annotation for arrow map columns; confirm against the schema conversion
+// code that reads the property.
+func WithNoMapLogicalType() WriterOption {
+	disableMapLogicalType := func(cfg *config) { cfg.props.noMapLogicalType = true }
+	return disableMapLogicalType
+}
+
+// func WithCompliantNestedTypes(enabled bool) WriterOption {
+// return func(c *config) {
+// c.props.compliantNestedTypes = enabled
+// }
+// }
+
+// arrowWriteContext is the value stored in a context.Context by
+// NewArrowWriteContext. It carries the writer properties together with
+// buffers intended for re-use across writes; NewArrowWriteContext only
+// populates props and leaves the buffers at their zero values.
+type arrowWriteContext struct {
+ props ArrowWriterProperties
+ dataBuffer *memory.Buffer
+ defLevelsBuffer encoding.Buffer
+ repLevelsBuffer encoding.Buffer
+}
+
+// arrowCtxKey is the unexported context key under which the
+// *arrowWriteContext is stored.
+type arrowCtxKey struct{}
+
+// NewArrowWriteContext derives a context from ctx that carries a re-usable
+// write context object holding the writer properties and other re-usable
+// buffers for writing. The resulting context should not be used to write
+// multiple columns concurrently. If props is nil, DefaultWriterProps is used.
+func NewArrowWriteContext(ctx context.Context, props *ArrowWriterProperties) context.Context {
+	awc := &arrowWriteContext{}
+	if props != nil {
+		awc.props = *props
+	} else {
+		awc.props = DefaultWriterProps()
+	}
+	return context.WithValue(ctx, arrowCtxKey{}, awc)
+}
+
+// arrowCtxFromContext returns the *arrowWriteContext stored in ctx by
+// NewArrowWriteContext. When ctx carries none, a fresh write context using
+// DefaultWriterProps is returned instead.
+func arrowCtxFromContext(ctx context.Context) *arrowWriteContext {
+	stored := ctx.Value(arrowCtxKey{})
+	if stored == nil {
+		return &arrowWriteContext{props: DefaultWriterProps()}
+	}
+	return stored.(*arrowWriteContext)
+}
+
+// ArrowReadProperties is the properties to define how to read a parquet file
+// into arrow arrays.
+type ArrowReadProperties struct {
+ // If Parallel is true, then functions which read multiple columns will read
+ // those columns in parallel from the file with a number of readers equal
+ // to the number of columns. Otherwise columns are read serially.
+ Parallel bool
+ // BatchSize is the size used for calls to NextBatch when reading whole columns
+ BatchSize int64
+
+ readDict map[int]bool // NOTE(review): presumably keyed by column index, marking columns to read as dictionaries; confirm
+ preBuffer bool // NOTE(review): not referenced in this file; meaning to be confirmed against the reader
+}
diff --git a/go/parquet/pqarrow/reader_writer_test.go b/go/parquet/pqarrow/reader_writer_test.go
new file mode 100644
index 0000000..9c9cd2e
--- /dev/null
+++ b/go/parquet/pqarrow/reader_writer_test.go
@@ -0,0 +1,335 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow_test
+
+import (
+ "bytes"
+ "context"
+ "testing"
+ "unsafe"
+
+ "github.com/apache/arrow/go/arrow"
+ "github.com/apache/arrow/go/arrow/array"
+ "github.com/apache/arrow/go/arrow/memory"
+ "github.com/apache/arrow/go/parquet"
+ "github.com/apache/arrow/go/parquet/file"
+ "github.com/apache/arrow/go/parquet/pqarrow"
+ "golang.org/x/exp/rand"
+ "gonum.org/v1/gonum/stat/distuv"
+)
+
+// alternateOrNA is a sentinel percentage. The random* helpers produce an
+// alternating 0, 1, 0, 1, ... pattern when given it, and tableFromVec
+// requires it as the null percentage for non-nullable columns.
+const alternateOrNA = -1
+// SIZELEN is the number of values generated for every benchmark column.
+const SIZELEN = 10 * 1024 * 1024
+
+// randomUint8 returns size values. When truePct is alternateOrNA the result
+// alternates 0, 1, 0, 1, ... and sampleVals is ignored. Otherwise each
+// element is picked from sampleVals using a Bernoulli draw with probability
+// truePct/100 seeded by seed: a draw of 1 selects sampleVals[1], a draw of 0
+// selects sampleVals[0].
+func randomUint8(size, truePct int, sampleVals [2]uint8, seed uint64) []uint8 {
+	out := make([]uint8, size)
+	if truePct != alternateOrNA {
+		bern := distuv.Bernoulli{P: float64(truePct) / 100.0, Src: rand.NewSource(seed)}
+		for i := 0; i < len(out); i++ {
+			out[i] = sampleVals[int(bern.Rand())]
+		}
+		return out
+	}
+
+	for i := 0; i < len(out); i++ {
+		out[i] = uint8(i % 2)
+	}
+	return out
+}
+
+// randomInt32 returns size values. When truePct is alternateOrNA the result
+// alternates 0, 1, 0, 1, ... and sampleVals is ignored. Otherwise each
+// element is picked from sampleVals using a Bernoulli draw with probability
+// truePct/100 seeded by seed: a draw of 1 selects sampleVals[1], a draw of 0
+// selects sampleVals[0].
+func randomInt32(size, truePct int, sampleVals [2]int32, seed uint64) []int32 {
+	out := make([]int32, size)
+	if truePct != alternateOrNA {
+		bern := distuv.Bernoulli{P: float64(truePct) / 100.0, Src: rand.NewSource(seed)}
+		for i := 0; i < len(out); i++ {
+			out[i] = sampleVals[int(bern.Rand())]
+		}
+		return out
+	}
+
+	for i := 0; i < len(out); i++ {
+		out[i] = int32(i % 2)
+	}
+	return out
+}
+
+// randomInt64 returns size values. When truePct is alternateOrNA the result
+// alternates 0, 1, 0, 1, ... and sampleVals is ignored. Otherwise each
+// element is picked from sampleVals using a Bernoulli draw with probability
+// truePct/100 seeded by seed: a draw of 1 selects sampleVals[1], a draw of 0
+// selects sampleVals[0].
+func randomInt64(size, truePct int, sampleVals [2]int64, seed uint64) []int64 {
+	out := make([]int64, size)
+	if truePct != alternateOrNA {
+		bern := distuv.Bernoulli{P: float64(truePct) / 100.0, Src: rand.NewSource(seed)}
+		for i := 0; i < len(out); i++ {
+			out[i] = sampleVals[int(bern.Rand())]
+		}
+		return out
+	}
+
+	for i := 0; i < len(out); i++ {
+		out[i] = int64(i % 2)
+	}
+	return out
+}
+
+// randomFloat32 returns size values. When truePct is alternateOrNA the
+// result alternates 0, 1, 0, 1, ... and sampleVals is ignored. Otherwise
+// each element is picked from sampleVals using a Bernoulli draw with
+// probability truePct/100 seeded by seed: a draw of 1 selects sampleVals[1],
+// a draw of 0 selects sampleVals[0].
+func randomFloat32(size, truePct int, sampleVals [2]float32, seed uint64) []float32 {
+	out := make([]float32, size)
+	if truePct != alternateOrNA {
+		bern := distuv.Bernoulli{P: float64(truePct) / 100.0, Src: rand.NewSource(seed)}
+		for i := 0; i < len(out); i++ {
+			out[i] = sampleVals[int(bern.Rand())]
+		}
+		return out
+	}
+
+	for i := 0; i < len(out); i++ {
+		out[i] = float32(i % 2)
+	}
+	return out
+}
+
+// randomFloat64 returns size values. When truePct is alternateOrNA the
+// result alternates 0, 1, 0, 1, ... and sampleVals is ignored. Otherwise
+// each element is picked from sampleVals using a Bernoulli draw with
+// probability truePct/100 seeded by seed: a draw of 1 selects sampleVals[1],
+// a draw of 0 selects sampleVals[0].
+func randomFloat64(size, truePct int, sampleVals [2]float64, seed uint64) []float64 {
+	out := make([]float64, size)
+	if truePct != alternateOrNA {
+		bern := distuv.Bernoulli{P: float64(truePct) / 100.0, Src: rand.NewSource(seed)}
+		for i := 0; i < len(out); i++ {
+			out[i] = sampleVals[int(bern.Rand())]
+		}
+		return out
+	}
+
+	for i := 0; i < len(out); i++ {
+		out[i] = float64(i % 2)
+	}
+	return out
+}
+
+// tableFromVec builds a single-column table (column name "column") of type
+// dt from data, which must be a []int32, []int64, []float32 or []float64
+// matching dt; any other type panics.
+//
+// When nullable is true a validity mask of length size is generated from
+// nullPct (alternateOrNA yields an alternating mask). Non-nullable columns
+// must pass alternateOrNA as nullPct, anything else panics.
+//
+// The intermediate array, chunked array and column are released before
+// returning, so the returned table holds the only remaining references. The
+// caller owns the table and must Release it.
+func tableFromVec(dt arrow.DataType, size int, data interface{}, nullable bool, nullPct int) array.Table {
+	if !nullable && nullPct != alternateOrNA {
+		panic("bad check")
+	}
+
+	var valid []bool
+	if nullable {
+		// true values select index 1 of sample values, so a Bernoulli "hit"
+		// marks the slot as null (0) and a miss marks it as valid (1)
+		validBytes := randomUint8(size, nullPct, [2]uint8{1, 0}, 500)
+		// reinterpret the 0/1 bytes as bools without copying
+		valid = *(*[]bool)(unsafe.Pointer(&validBytes))
+	}
+
+	bldr := array.NewBuilder(memory.DefaultAllocator, dt)
+	defer bldr.Release()
+
+	switch v := data.(type) {
+	case []int32:
+		bldr.(*array.Int32Builder).AppendValues(v, valid)
+	case []int64:
+		bldr.(*array.Int64Builder).AppendValues(v, valid)
+	case []float32:
+		bldr.(*array.Float32Builder).AppendValues(v, valid)
+	case []float64:
+		bldr.(*array.Float64Builder).AppendValues(v, valid)
+	default:
+		// without this the table would claim size rows over an empty column
+		panic("tableFromVec: unsupported data type")
+	}
+
+	arr := bldr.NewArray()
+	defer arr.Release()
+
+	field := arrow.Field{Name: "column", Type: dt, Nullable: nullable}
+	sc := arrow.NewSchema([]arrow.Field{field}, nil)
+
+	// NewChunked, NewColumn and NewTable each retain what they are given, so
+	// the local references are dropped once the table has been constructed
+	chunked := array.NewChunked(dt, []array.Interface{arr})
+	defer chunked.Release()
+	col := array.NewColumn(field, chunked)
+	defer col.Release()
+
+	return array.NewTable(sc, []array.Column{*col}, int64(size))
+}
+
+// BenchmarkWriteColumn measures pqarrow.WriteTable for a single column of
+// SIZELEN constant values (128) of each supported primitive type, both as a
+// required column and as a nullable column with an alternating validity
+// mask. Dictionary encoding is disabled for every run.
+func BenchmarkWriteColumn(b *testing.B) {
+	var (
+		int32Values   = make([]int32, SIZELEN)
+		int64Values   = make([]int64, SIZELEN)
+		float32Values = make([]float32, SIZELEN)
+		float64Values = make([]float64, SIZELEN)
+	)
+	for i := range int32Values {
+		int32Values[i] = 128
+		int64Values[i] = 128
+		float32Values[i] = 128
+		float64Values[i] = 128
+	}
+
+	type benchCase struct {
+		name     string
+		dt       arrow.DataType
+		values   interface{}
+		nullable bool
+		nbytes   int64
+	}
+	cases := []benchCase{
+		{"int32 not nullable", arrow.PrimitiveTypes.Int32, int32Values, false, int64(arrow.Int32Traits.BytesRequired(SIZELEN))},
+		{"int32 nullable", arrow.PrimitiveTypes.Int32, int32Values, true, int64(arrow.Int32Traits.BytesRequired(SIZELEN))},
+		{"int64 not nullable", arrow.PrimitiveTypes.Int64, int64Values, false, int64(arrow.Int64Traits.BytesRequired(SIZELEN))},
+		{"int64 nullable", arrow.PrimitiveTypes.Int64, int64Values, true, int64(arrow.Int64Traits.BytesRequired(SIZELEN))},
+		{"float32 not nullable", arrow.PrimitiveTypes.Float32, float32Values, false, int64(arrow.Float32Traits.BytesRequired(SIZELEN))},
+		{"float32 nullable", arrow.PrimitiveTypes.Float32, float32Values, true, int64(arrow.Float32Traits.BytesRequired(SIZELEN))},
+		{"float64 not nullable", arrow.PrimitiveTypes.Float64, float64Values, false, int64(arrow.Float64Traits.BytesRequired(SIZELEN))},
+		{"float64 nullable", arrow.PrimitiveTypes.Float64, float64Values, true, int64(arrow.Float64Traits.BytesRequired(SIZELEN))},
+	}
+
+	writerProps := parquet.NewWriterProperties(parquet.WithDictionaryDefault(false))
+	arrowProps := pqarrow.DefaultWriterProps()
+
+	for _, bc := range cases {
+		b.Run(bc.name, func(b *testing.B) {
+			tbl := tableFromVec(bc.dt, SIZELEN, bc.values, bc.nullable, alternateOrNA)
+			b.Cleanup(func() { tbl.Release() })
+
+			var out bytes.Buffer
+			out.Grow(int(bc.nbytes))
+			// exclude the table construction above from the measurement
+			b.ResetTimer()
+			b.SetBytes(bc.nbytes)
+
+			for n := 0; n < b.N; n++ {
+				out.Reset()
+				if err := pqarrow.WriteTable(tbl, &out, SIZELEN, writerProps, arrowProps, memory.DefaultAllocator); err != nil {
+					b.Error(err)
+				}
+			}
+		})
+	}
+}
+
+// benchReadTable writes tbl to an in-memory parquet file once (dictionary
+// encoding disabled) and then runs a sub-benchmark called name that
+// repeatedly opens that file and reads it back into an arrow table.
+//
+// nbytes is reported through b.SetBytes for throughput calculation. The
+// input table is not released here; it remains owned by the caller.
+func benchReadTable(b *testing.B, name string, tbl array.Table, nbytes int64) {
+	props := parquet.NewWriterProperties(parquet.WithDictionaryDefault(false))
+	arrProps := pqarrow.DefaultWriterProps()
+
+	var buf bytes.Buffer
+	if err := pqarrow.WriteTable(tbl, &buf, SIZELEN, props, arrProps, memory.DefaultAllocator); err != nil {
+		// nothing useful can be measured without the file contents
+		b.Fatal(err)
+	}
+	ctx := context.Background()
+
+	b.ResetTimer()
+	b.Run(name, func(b *testing.B) {
+		b.SetBytes(nbytes)
+
+		for i := 0; i < b.N; i++ {
+			// fail fast: continuing after an error would dereference a nil
+			// reader or table
+			pf, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()), nil, nil)
+			if err != nil {
+				b.Fatal(err)
+			}
+
+			reader, err := pqarrow.NewFileReader(pf, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
+			if err != nil {
+				b.Fatal(err)
+			}
+
+			result, err := reader.ReadTable(ctx)
+			if err != nil {
+				b.Fatal(err)
+			}
+			// release right away; a defer here would keep every table read
+			// so far alive until all b.N iterations have finished
+			result.Release()
+		}
+	})
+}
+
+// BenchmarkReadColumnInt32 benchmarks reading a single int32 column of
+// SIZELEN values. nullPct controls the percentage of nulls (alternateOrNA
+// gives an alternating mask for nullable columns) and fvPct the percentage
+// of values equal to the second sample value (128 rather than 127).
+func BenchmarkReadColumnInt32(b *testing.B) {
+	tests := []struct {
+		name     string
+		nullable bool
+		nullPct  int
+		fvPct    int
+	}{
+		{"int32 not null 1pct", false, alternateOrNA, 1},
+		{"int32 not null 10pct", false, alternateOrNA, 10},
+		{"int32 not null 50pct", false, alternateOrNA, 50},
+		{"int32 nullable alt", true, alternateOrNA, 0},
+		{"int32 nullable 1pct 1pct", true, 1, 1},
+		{"int32 nullable 10pct 10pct", true, 10, 10},
+		{"int32 nullable 25pct 5pct", true, 25, 5},
+		{"int32 nullable 50pct 50pct", true, 50, 50},
+		{"int32 nullable 50pct 0pct", true, 50, 0},
+		{"int32 nullable 99pct 50pct", true, 99, 50},
+		{"int32 nullable 99pct 0pct", true, 99, 0},
+	}
+
+	for _, tt := range tests {
+		values := randomInt32(SIZELEN, tt.fvPct, [2]int32{127, 128}, 500)
+		tbl := tableFromVec(arrow.PrimitiveTypes.Int32, SIZELEN, values, tt.nullable, tt.nullPct)
+		benchReadTable(b, tt.name, tbl, int64(arrow.Int32Traits.BytesRequired(SIZELEN)))
+		// the sub-benchmark has finished by the time benchReadTable returns,
+		// so the source table can be dropped before building the next one
+		tbl.Release()
+	}
+}
+
+// BenchmarkReadColumnInt64 benchmarks reading a single int64 column of
+// SIZELEN values. nullPct controls the percentage of nulls (alternateOrNA
+// gives an alternating mask for nullable columns) and fvPct the percentage
+// of values equal to the second sample value (128 rather than 127).
+func BenchmarkReadColumnInt64(b *testing.B) {
+	tests := []struct {
+		name     string
+		nullable bool
+		nullPct  int
+		fvPct    int
+	}{
+		{"int64 not null 1pct", false, alternateOrNA, 1},
+		{"int64 not null 10pct", false, alternateOrNA, 10},
+		{"int64 not null 50pct", false, alternateOrNA, 50},
+		{"int64 nullable alt", true, alternateOrNA, 0},
+		{"int64 nullable 1pct 1pct", true, 1, 1},
+		{"int64 nullable 5pct 5pct", true, 5, 5},
+		{"int64 nullable 10pct 5pct", true, 10, 5},
+		{"int64 nullable 25pct 10pct", true, 25, 10},
+		{"int64 nullable 30pct 10pct", true, 30, 10},
+		{"int64 nullable 35pct 10pct", true, 35, 10},
+		{"int64 nullable 45pct 25pct", true, 45, 25},
+		{"int64 nullable 50pct 50pct", true, 50, 50},
+		{"int64 nullable 50pct 1pct", true, 50, 1},
+		{"int64 nullable 75pct 1pct", true, 75, 1},
+		{"int64 nullable 99pct 50pct", true, 99, 50},
+		{"int64 nullable 99pct 0pct", true, 99, 0},
+	}
+
+	for _, tt := range tests {
+		// generate real int64 data so the benchmark measures the column type
+		// it is named after
+		values := randomInt64(SIZELEN, tt.fvPct, [2]int64{127, 128}, 500)
+		tbl := tableFromVec(arrow.PrimitiveTypes.Int64, SIZELEN, values, tt.nullable, tt.nullPct)
+		benchReadTable(b, tt.name, tbl, int64(arrow.Int64Traits.BytesRequired(SIZELEN)))
+		// the sub-benchmark has finished by the time benchReadTable returns,
+		// so the source table can be dropped before building the next one
+		tbl.Release()
+	}
+}
+
+// BenchmarkReadColumnFloat64 benchmarks reading a single float64 column of
+// SIZELEN values. nullPct controls the percentage of nulls (alternateOrNA
+// gives an alternating mask for nullable columns) and fvPct the percentage
+// of values equal to the second sample value (128 rather than 127).
+func BenchmarkReadColumnFloat64(b *testing.B) {
+	tests := []struct {
+		name     string
+		nullable bool
+		nullPct  int
+		fvPct    int
+	}{
+		{"double not null 1pct", false, alternateOrNA, 0},
+		{"double not null 20pct", false, alternateOrNA, 20},
+		{"double nullable alt", true, alternateOrNA, 0},
+		{"double nullable 10pct 50pct", true, 10, 50},
+		{"double nullable 25pct 25pct", true, 25, 25},
+	}
+
+	for _, tt := range tests {
+		// generate real float64 data so the benchmark measures the column
+		// type it is named after
+		values := randomFloat64(SIZELEN, tt.fvPct, [2]float64{127, 128}, 500)
+		tbl := tableFromVec(arrow.PrimitiveTypes.Float64, SIZELEN, values, tt.nullable, tt.nullPct)
+		benchReadTable(b, tt.name, tbl, int64(arrow.Float64Traits.BytesRequired(SIZELEN)))
+		// the sub-benchmark has finished by the time benchReadTable returns,
+		// so the source table can be dropped before building the next one
+		tbl.Release()
+	}
+}
diff --git a/go/parquet/pqarrow/schema.go b/go/parquet/pqarrow/schema.go
new file mode 100644
index 0000000..5c72a41
--- /dev/null
+++ b/go/parquet/pqarrow/schema.go
@@ -0,0 +1,1072 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow
+
+import (
+ "encoding/base64"
+ "math"
+ "strconv"
+ "strings"
+
+ "github.com/apache/arrow/go/arrow"
+ "github.com/apache/arrow/go/arrow/flight"
+ "github.com/apache/arrow/go/arrow/memory"
+ "github.com/apache/arrow/go/parquet"
+ "github.com/apache/arrow/go/parquet/file"
+ "github.com/apache/arrow/go/parquet/metadata"
+ "github.com/apache/arrow/go/parquet/schema"
+ "golang.org/x/xerrors"
+)
+
+// SchemaField is a holder that defines a specific logical field in the schema
+// which could potentially refer to multiple physical columns in the underlying
+// parquet file if it is a nested type.
+//
+// ColIndex is only populated (not -1) when it is a leaf column.
+type SchemaField struct {
+ Field *arrow.Field // the arrow field this entry describes
+ Children []SchemaField // nested fields; presumably empty for leaf columns (confirm)
+ ColIndex int // parquet column index, or -1 when this is not a leaf
+ LevelInfo file.LevelInfo // NOTE(review): presumably the def/rep level information for this field; confirm
+}
+
+// IsLeaf reports whether the SchemaField is a leaf column, which is the case
+// exactly when ColIndex is not -1.
+func (s *SchemaField) IsLeaf() bool {
+	const notALeaf = -1
+	return s.ColIndex != notALeaf
+}
+
+// SchemaManifest represents a full manifest for mapping a Parquet schema
+// to an arrow Schema.
+type SchemaManifest struct {
+ descr *schema.Schema // parquet schema descriptor; used by GetFieldIndices
+ OriginSchema *arrow.Schema // NOTE(review): presumably the arrow schema stored in the file metadata, if any; confirm
+ SchemaMeta *arrow.Metadata
+
+ ColIndexToField map[int]*SchemaField // lookup used by GetColumnField
+ ChildToParent map[*SchemaField]*SchemaField // lookup used by GetParent
+ Fields []SchemaField
+}
+
+// GetColumnField looks up the SchemaField registered for the given column
+// index. An error is returned when the manifest has no entry for that index.
+func (sm *SchemaManifest) GetColumnField(index int) (*SchemaField, error) {
+	field, found := sm.ColIndexToField[index]
+	if !found {
+		return nil, xerrors.Errorf("Column Index %d not found in schema manifest", index)
+	}
+	return field, nil
+}
+
+// GetParent returns the parent field recorded for the given field when it is
+// a nested column, or nil when no parent is recorded.
+func (sm *SchemaManifest) GetParent(field *SchemaField) *SchemaField {
+	// a missing key yields the zero value of the map's element type, which
+	// is a nil *SchemaField
+	return sm.ChildToParent[field]
+}
+
+// GetFieldIndices maps a list of leaf column indices to the de-duplicated
+// list of top-level field indices (relative to the equivalent arrow schema)
+// that contain them. A top-level field, or column root, is the first node
+// below the parquet schema's root group. The result keeps the order in which
+// each root is first encountered.
+//
+// For example, for leaves `a.b.c`, `a.b.d.e`, and `i.j.k` (indices=[0,1,3])
+// the roots are `a` and `i` (return=[0,2]).
+//
+// root
+// -- a <------
+// -- -- b    | |
+// -- -- -- c   |
+// -- -- -- d   |
+// -- -- -- -- e
+// -- f
+// -- -- g
+// -- -- -- h
+// -- i <---
+// -- -- j  |
+// -- -- -- k
+//
+// An error is returned for any index that is out of range or whose column
+// root cannot be found among the root group's fields.
+func (sm *SchemaManifest) GetFieldIndices(indices []int) ([]int, error) {
+	seen := make(map[int]bool)
+	roots := make([]int, 0)
+
+	for _, colIdx := range indices {
+		if !(colIdx >= 0 && colIdx < sm.descr.NumColumns()) {
+			return nil, xerrors.Errorf("column index %d is not valid", colIdx)
+		}
+
+		rootNode := sm.descr.ColumnRoot(colIdx)
+		rootIdx := sm.descr.Root().FieldIndexByField(rootNode)
+		if rootIdx == -1 {
+			return nil, xerrors.Errorf("column index %d is not valid", colIdx)
+		}
+
+		if seen[rootIdx] {
+			continue
+		}
+		seen[rootIdx] = true
+		roots = append(roots, rootIdx)
+	}
+	return roots, nil
+}
+
+// arrowTimestampToLogical returns the parquet logical type for an arrow
+// timestamp type written with the given unit. The timestamp is flagged as
+// UTC-adjusted when the arrow type has no time zone or the time zone "UTC".
+// Seconds have no parquet equivalent and yield schema.NoLogicalType.
+func arrowTimestampToLogical(typ *arrow.TimestampType, unit arrow.TimeUnit) schema.LogicalType {
+	isUTC := typ.TimeZone == "" || typ.TimeZone == "UTC"
+
+	var parquetUnit schema.TimeUnitType
+	switch unit {
+	case arrow.Second:
+		// no equivalent in parquet
+		return schema.NoLogicalType{}
+	case arrow.Nanosecond:
+		parquetUnit = schema.TimeUnitNanos
+	case arrow.Microsecond:
+		parquetUnit = schema.TimeUnitMicros
+	case arrow.Millisecond:
+		parquetUnit = schema.TimeUnitMillis
+	}
+
+	// for forward compatibility reasons, and because there's no other way
+	// to signal to old readers that values are timestamps, we force
+	// the convertedtype field to be set to the corresponding TIMESTAMP_* value.
+	// this does cause some ambiguity as parquet readers have not been consistent
+	// about the interpretation of TIMESTAMP_* values as being utc-normalized
+	// see ARROW-5878
+	return schema.NewTimestampLogicalTypeForce(isUTC, parquetUnit)
+}
+
+// getTimestampMeta determines the parquet physical type and logical type to
+// use when writing an arrow timestamp column.
+//
+// The target unit is the arrow type's own unit unless coercion was requested
+// through the arrow writer properties. Int96 is used (with no logical type)
+// only when it was explicitly requested and the target unit is nanoseconds;
+// otherwise the physical type is Int64. An error is returned when coercion
+// to a unit that the configured parquet version cannot represent is
+// requested.
+func getTimestampMeta(typ *arrow.TimestampType, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (parquet.Type, schema.LogicalType, error) {
+	coerce := arrprops.coerceTimestamps
+	target := typ.Unit
+	if coerce {
+		target = arrprops.coerceTimestampUnit
+	}
+
+	// user is explicitly asking for int96, no logical type
+	if arrprops.timestampAsInt96 && target == arrow.Nanosecond {
+		return parquet.Types.Int96, schema.NoLogicalType{}, nil
+	}
+
+	physical := parquet.Types.Int64
+	logicalType := arrowTimestampToLogical(typ, target)
+
+	// versions 1.0 and 2.4 rely on the converted type field, which cannot
+	// express nanosecond timestamps
+	legacyVersion := props.Version() == parquet.V1_0 || props.Version() == parquet.V2_4
+
+	// user is explicitly asking for timestamp data to be converted to the specified
+	// units (target) via coercion
+	if coerce {
+		if legacyVersion {
+			switch target {
+			case arrow.Millisecond, arrow.Microsecond:
+			case arrow.Nanosecond, arrow.Second:
+				return physical, nil, xerrors.Errorf("parquet version %s files can only coerce arrow timestamps to millis or micros", props.Version())
+			}
+		} else if target == arrow.Second {
+			return physical, nil, xerrors.Errorf("parquet version %s files can only coerce arrow timestamps to millis, micros or nanos", props.Version())
+		}
+		return physical, logicalType, nil
+	}
+
+	// the user implicitly wants timestamp data to retain its original time units
+	// however the converted type field used to indicate logical types for parquet
+	// version <=2.4 fields, does not allow for nanosecond time units and so nanos
+	// must be coerced to micros
+	if legacyVersion && typ.Unit == arrow.Nanosecond {
+		logicalType = arrowTimestampToLogical(typ, arrow.Microsecond)
+		return physical, logicalType, nil
+	}
+
+	// the user implicitly wants timestamp data to retain its original time units,
+	// however the arrow seconds time unit cannot be represented in parquet, so must
+	// be coerced to milliseconds
+	if typ.Unit == arrow.Second {
+		logicalType = arrowTimestampToLogical(typ, arrow.Millisecond)
+	}
+
+	return physical, logicalType, nil
+}
+
+// DecimalSize returns the minimum number of bytes necessary to represent a
+// decimal with the requested precision as a two's complement integer.
+//
+// The lookup table is taken from the Apache Impala codebase and covers
+// precisions 1 through 76. Larger precisions are computed with the same
+// formula the table was generated from: ceil((precision*log2(10) + 1) / 8),
+// i.e. the bits needed for the digits plus one sign bit, rounded up to whole
+// bytes.
+//
+// DecimalSize panics if precision is less than 1.
+func DecimalSize(precision int32) int32 {
+	if precision < 1 {
+		panic("precision must be >= 1")
+	}
+
+	// generated in python with:
+	// >>> decimal_size = lambda prec: int(math.ceil((prec * math.log2(10) + 1) / 8))
+	// >>> [-1] + [decimal_size(i) for i in range(1, 77)]
+	var byteblock = [...]int32{
+		-1, 1, 1, 2, 2, 3, 3, 4, 4, 4, 5, 5, 6, 6, 6, 7, 7, 8, 8, 9,
+		9, 9, 10, 10, 11, 11, 11, 12, 12, 13, 13, 13, 14, 14, 15, 15, 16, 16, 16, 17,
+		17, 18, 18, 18, 19, 19, 20, 20, 21, 21, 21, 22, 22, 23, 23, 23, 24, 24, 25, 25,
+		26, 26, 26, 27, 27, 28, 28, 28, 29, 29, 30, 30, 31, 31, 31, 32, 32,
+	}
+
+	if precision <= 76 {
+		return byteblock[precision]
+	}
+	// same formula as the table above, so results stay continuous past 76
+	return int32(math.Ceil((float64(precision)*math.Log2(10) + 1) / 8))
+}
+
+ // repFromNullable maps a nullability flag onto a parquet repetition:
+ // Optional when isnullable is true, Required otherwise.
+ func repFromNullable(isnullable bool) parquet.Repetition {
+ 	if !isnullable {
+ 		return parquet.Repetitions.Required
+ 	}
+ 	return parquet.Repetitions.Optional
+ }
+
+ // structToNode builds a parquet group node called name from an arrow struct
+ // type, converting every struct field with fieldToNode. The group is Optional
+ // when nullable is true and Required otherwise, and carries a field id of -1.
+ //
+ // An error is returned for a struct with no fields, or when any child field
+ // fails to convert.
+ func structToNode(typ *arrow.StructType, name string, nullable bool, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) {
+ 	fields := typ.Fields()
+ 	if len(fields) == 0 {
+ 		return nil, xerrors.Errorf("cannot write struct type '%s' with no children field to parquet. Consider adding a dummy child", name)
+ 	}
+
+ 	children := make(schema.FieldList, len(fields))
+ 	for idx := range fields {
+ 		node, err := fieldToNode(fields[idx].Name, fields[idx], props, arrprops)
+ 		if err != nil {
+ 			return nil, err
+ 		}
+ 		children[idx] = node
+ 	}
+
+ 	return schema.NewGroupNode(name, repFromNullable(nullable), children, -1)
+ }
+
+ // fieldToNode converts a single arrow field into the parquet schema node that
+ // represents it, using name as the node name.
+ //
+ // Primitive arrow types select a parquet physical type and logical type in the
+ // switch below and are turned into a primitive node at the end of the function,
+ // with the field id read from the field's metadata via fieldIDFromMeta. Nested
+ // types return early from their case: structs go through structToNode, lists
+ // and fixed size lists through schema.ListOf, and maps through schema.MapOf (or
+ // a plain group when arrprops.noMapLogicalType is set).
+ //
+ // An error is returned for a non-nullable null field, for dictionary and
+ // extension types, for any type not listed in the switch, and whenever a nested
+ // conversion fails.
+ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) {
+ var (
+ logicalType schema.LogicalType = schema.NoLogicalType{}
+ typ parquet.Type
+ repType = repFromNullable(field.Nullable)
+ length = -1
+ precision = -1
+ scale = -1
+ err error
+ )
+
+ switch field.Type.ID() {
+ case arrow.NULL:
+ // null columns are written as int32 annotated with the null logical type
+ typ = parquet.Types.Int32
+ logicalType = &schema.NullLogicalType{}
+ if repType != parquet.Repetitions.Optional {
+ return nil, xerrors.New("nulltype arrow field must be nullable")
+ }
+ case arrow.BOOL:
+ typ = parquet.Types.Boolean
+ // integers of 32 bits or fewer use the int32 physical type, 64 bit integers
+ // use int64; the int logical type records the bit width and signedness
+ case arrow.UINT8:
+ typ = parquet.Types.Int32
+ logicalType = schema.NewIntLogicalType(8, false)
+ case arrow.INT8:
+ typ = parquet.Types.Int32
+ logicalType = schema.NewIntLogicalType(8, true)
+ case arrow.UINT16:
+ typ = parquet.Types.Int32
+ logicalType = schema.NewIntLogicalType(16, false)
+ case arrow.INT16:
+ typ = parquet.Types.Int32
+ logicalType = schema.NewIntLogicalType(16, true)
+ case arrow.UINT32:
+ typ = parquet.Types.Int32
+ logicalType = schema.NewIntLogicalType(32, false)
+ case arrow.INT32:
+ typ = parquet.Types.Int32
+ logicalType = schema.NewIntLogicalType(32, true)
+ case arrow.UINT64:
+ typ = parquet.Types.Int64
+ logicalType = schema.NewIntLogicalType(64, false)
+ case arrow.INT64:
+ typ = parquet.Types.Int64
+ logicalType = schema.NewIntLogicalType(64, true)
+ case arrow.FLOAT32:
+ typ = parquet.Types.Float
+ case arrow.FLOAT64:
+ typ = parquet.Types.Double
+ case arrow.STRING:
+ // strings share the byte array physical type with binary, so only the
+ // logical type is set here before falling through
+ logicalType = schema.StringLogicalType{}
+ fallthrough
+ case arrow.BINARY:
+ typ = parquet.Types.ByteArray
+ case arrow.FIXED_SIZE_BINARY:
+ typ = parquet.Types.FixedLenByteArray
+ length = field.Type.(*arrow.FixedSizeBinaryType).ByteWidth
+ case arrow.DECIMAL:
+ // decimals are written as fixed length byte arrays whose length is
+ // derived from the precision by DecimalSize
+ typ = parquet.Types.FixedLenByteArray
+ dectype := field.Type.(*arrow.Decimal128Type)
+ precision = int(dectype.Precision)
+ scale = int(dectype.Scale)
+ length = int(DecimalSize(int32(precision)))
+ logicalType = schema.NewDecimalLogicalType(int32(precision), int32(scale))
+ case arrow.DATE32:
+ typ = parquet.Types.Int32
+ logicalType = schema.DateLogicalType{}
+ case arrow.DATE64:
+ // date64 is written as an int64 annotated as a millisecond timestamp
+ typ = parquet.Types.Int64
+ logicalType = schema.NewTimestampLogicalType(true, schema.TimeUnitMillis)
+ case arrow.TIMESTAMP:
+ typ, logicalType, err = getTimestampMeta(field.Type.(*arrow.TimestampType), props, arrprops)
+ if err != nil {
+ return nil, err
+ }
+ case arrow.TIME32:
+ // the arrow unit is not inspected here: every time32 is annotated as millis
+ typ = parquet.Types.Int32
+ logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitMillis)
+ case arrow.TIME64:
+ // nanosecond time64 keeps its unit, any other unit is annotated as micros
+ typ = parquet.Types.Int64
+ timeType := field.Type.(*arrow.Time64Type)
+ if timeType.Unit == arrow.Nanosecond {
+ logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitNanos)
+ } else {
+ logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitMicros)
+ }
+ case arrow.STRUCT:
+ return structToNode(field.Type.(*arrow.StructType), field.Name, field.Nullable, props, arrprops)
+ case arrow.FIXED_SIZE_LIST, arrow.LIST:
+ var elem arrow.DataType
+ if lt, ok := field.Type.(*arrow.ListType); ok {
+ elem = lt.Elem()
+ } else {
+ elem = field.Type.(*arrow.FixedSizeListType).Elem()
+ }
+
+ // the element is converted as a nullable field that reuses the list's name
+ child, err := fieldToNode(name, arrow.Field{Name: name, Type: elem, Nullable: true}, props, arrprops)
+ if err != nil {
+ return nil, err
+ }
+
+ return schema.ListOf(child, repFromNullable(field.Nullable), -1)
+ case arrow.DICTIONARY:
+ // parquet has no dictionary type, dictionary is encoding, not schema level
+ return nil, xerrors.New("not implemented yet")
+ case arrow.EXTENSION:
+ return nil, xerrors.New("not implemented yet")
+ case arrow.MAP:
+ mapType := field.Type.(*arrow.MapType)
+ keyNode, err := fieldToNode("key", mapType.KeyField(), props, arrprops)
+ if err != nil {
+ return nil, err
+ }
+
+ valueNode, err := fieldToNode("value", mapType.ItemField(), props, arrprops)
+ if err != nil {
+ return nil, err
+ }
+
+ // when the MAP annotation is disabled, build the same group shape by hand:
+ // an outer group holding a single repeated "key_value" group
+ if arrprops.noMapLogicalType {
+ keyval := schema.FieldList{keyNode, valueNode}
+ keyvalNode, err := schema.NewGroupNode("key_value", parquet.Repetitions.Repeated, keyval, -1)
+ if err != nil {
+ return nil, err
+ }
+ return schema.NewGroupNode(field.Name, repFromNullable(field.Nullable), schema.FieldList{
+ keyvalNode,
+ }, -1)
+ }
+ return schema.MapOf(field.Name, keyNode, valueNode, repFromNullable(field.Nullable), -1)
+ default:
+ return nil, xerrors.New("not implemented yet")
+ }
+
+ // only primitive types reach this point; nested types returned from their case
+ return schema.NewPrimitiveNodeLogical(name, repType, logicalType, typ, length, fieldIDFromMeta(field.Metadata))
+ }
+
+ // fieldIDKey is the arrow field metadata key under which a parquet field id is stored.
+ const fieldIDKey = "PARQUET:field_id"
+
+ // fieldIDFromMeta reads the parquet field id stored under fieldIDKey in m.
+ //
+ // It returns -1 when the metadata is empty, the key is absent, the value is
+ // not a base-10 integer that fits in 32 bits, or the value is negative.
+ func fieldIDFromMeta(m arrow.Metadata) int32 {
+ 	const noFieldID int32 = -1
+
+ 	if m.Len() == 0 {
+ 		return noFieldID
+ 	}
+
+ 	pos := m.FindKey(fieldIDKey)
+ 	if pos < 0 {
+ 		return noFieldID
+ 	}
+
+ 	parsed, err := strconv.ParseInt(m.Values()[pos], 10, 32)
+ 	if err != nil || parsed < 0 {
+ 		return noFieldID
+ 	}
+ 	return int32(parsed)
+ }
+
+ // ToParquet generates a Parquet Schema from an arrow Schema using the given properties to make
+ // decisions when determining the logical/physical types of the columns.
+ //
+ // If props is nil, the default writer properties are used. Each arrow field is
+ // converted with fieldToNode and the resulting nodes are placed under a repeated
+ // root group named "schema". A nil schema and a non-nil error are returned if
+ // any field cannot be converted or the root group cannot be constructed.
+ func ToParquet(sc *arrow.Schema, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (*schema.Schema, error) {
+ 	if props == nil {
+ 		props = parquet.NewWriterProperties()
+ 	}
+
+ 	nodes := make(schema.FieldList, 0, len(sc.Fields()))
+ 	for _, f := range sc.Fields() {
+ 		n, err := fieldToNode(f.Name, f, props, arrprops)
+ 		if err != nil {
+ 			return nil, err
+ 		}
+ 		nodes = append(nodes, n)
+ 	}
+
+ 	root, err := schema.NewGroupNode("schema", parquet.Repetitions.Repeated, nodes, -1)
+ 	if err != nil {
+ 		// do not hand a failed root to NewSchema, and do not return a schema
+ 		// alongside an error
+ 		return nil, err
+ 	}
+ 	return schema.NewSchema(root), nil
+ }
+
+ // schemaTree carries the state shared while walking a parquet schema to build
+ // a SchemaManifest: the manifest being filled in, the parquet schema being
+ // walked and the arrow read properties.
+ type schemaTree struct {
+ 	manifest *SchemaManifest
+
+ 	schema *schema.Schema
+ 	props  *ArrowReadProperties
+ }
+
+ // LinkParent records parent as the parent of child in the manifest.
+ func (t schemaTree) LinkParent(child, parent *SchemaField) {
+ 	links := t.manifest.ChildToParent
+ 	links[child] = parent
+ }
+
+ // RecordLeaf registers leaf in the manifest under its column index.
+ func (t schemaTree) RecordLeaf(leaf *SchemaField) {
+ 	byColumn := t.manifest.ColIndexToField
+ 	byColumn[leaf.ColIndex] = leaf
+ }
+
+ // arrowInt returns the arrow integer type matching the bit width and
+ // signedness of an int logical type. Bit widths other than 8, 16, 32 and 64
+ // produce an error.
+ func arrowInt(log *schema.IntLogicalType) (arrow.DataType, error) {
+ 	// pick selects between the signed and unsigned variant of one bit width
+ 	pick := func(signed, unsigned arrow.DataType) (arrow.DataType, error) {
+ 		if log.IsSigned() {
+ 			return signed, nil
+ 		}
+ 		return unsigned, nil
+ 	}
+
+ 	switch width := log.BitWidth(); width {
+ 	case 8:
+ 		return pick(arrow.PrimitiveTypes.Int8, arrow.PrimitiveTypes.Uint8)
+ 	case 16:
+ 		return pick(arrow.PrimitiveTypes.Int16, arrow.PrimitiveTypes.Uint16)
+ 	case 32:
+ 		return pick(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Uint32)
+ 	case 64:
+ 		return pick(arrow.PrimitiveTypes.Int64, arrow.PrimitiveTypes.Uint64)
+ 	}
+ 	return nil, xerrors.New("invalid logical type for int32")
+ }
+
+ // arrowTime32 returns the arrow millisecond time32 type for a time logical
+ // type whose unit is milliseconds; any other unit is an error.
+ func arrowTime32(logical *schema.TimeLogicalType) (arrow.DataType, error) {
+ 	if unit := logical.TimeUnit(); unit != schema.TimeUnitMillis {
+ 		return nil, xerrors.New(logical.String() + " cannot annotate a time32")
+ 	}
+ 	return arrow.FixedWidthTypes.Time32ms, nil
+ }
+
+ // arrowTime64 returns the arrow time64 type for a time logical type whose
+ // unit is microseconds or nanoseconds; any other unit is an error.
+ func arrowTime64(logical *schema.TimeLogicalType) (arrow.DataType, error) {
+ 	unit := logical.TimeUnit()
+ 	if unit == schema.TimeUnitMicros {
+ 		return arrow.FixedWidthTypes.Time64us, nil
+ 	}
+ 	if unit == schema.TimeUnitNanos {
+ 		return arrow.FixedWidthTypes.Time64ns, nil
+ 	}
+ 	return nil, xerrors.New(logical.String() + " cannot annotate int64")
+ }
+
+ // arrowTimestamp returns the arrow timestamp type for a timestamp logical
+ // type, carrying over its millisecond, microsecond or nanosecond unit.
+ //
+ // The time zone is "UTC" unless the logical type was derived from a converted
+ // type, in which case it is left empty. Any other unit is an error.
+ func arrowTimestamp(logical *schema.TimestampLogicalType) (arrow.DataType, error) {
+ 	result := &arrow.TimestampType{TimeZone: "UTC"}
+ 	if logical.IsFromConvertedType() {
+ 		result.TimeZone = ""
+ 	}
+
+ 	switch logical.TimeUnit() {
+ 	case schema.TimeUnitMillis:
+ 		result.Unit = arrow.Millisecond
+ 	case schema.TimeUnitMicros:
+ 		result.Unit = arrow.Microsecond
+ 	case schema.TimeUnitNanos:
+ 		result.Unit = arrow.Nanosecond
+ 	default:
+ 		return nil, xerrors.New("Unrecognized unit in timestamp logical type " + logical.String())
+ 	}
+ 	return result, nil
+ }
+
+ // arrowFromInt32 returns the arrow type for an int32 parquet column with the
+ // given logical type. Logical types other than none, int, time, decimal and
+ // date are rejected with an error.
+ func arrowFromInt32(logical schema.LogicalType) (arrow.DataType, error) {
+ 	switch lt := logical.(type) {
+ 	case *schema.IntLogicalType:
+ 		return arrowInt(lt)
+ 	case *schema.TimeLogicalType:
+ 		return arrowTime32(lt)
+ 	case *schema.DecimalLogicalType:
+ 		dec := &arrow.Decimal128Type{Precision: lt.Precision(), Scale: lt.Scale()}
+ 		return dec, nil
+ 	case schema.DateLogicalType:
+ 		return arrow.FixedWidthTypes.Date32, nil
+ 	case schema.NoLogicalType:
+ 		return arrow.PrimitiveTypes.Int32, nil
+ 	}
+ 	return nil, xerrors.New(logical.String() + " cannot annotate int32")
+ }
+
+ // arrowFromInt64 returns the arrow type for an int64 parquet column with the
+ // given logical type. A logical type reporting IsNone maps to plain int64;
+ // int, decimal, time and timestamp logical types map to their arrow
+ // counterparts, and anything else is an error.
+ func arrowFromInt64(logical schema.LogicalType) (arrow.DataType, error) {
+ 	if logical.IsNone() {
+ 		return arrow.PrimitiveTypes.Int64, nil
+ 	}
+
+ 	switch lt := logical.(type) {
+ 	case *schema.TimestampLogicalType:
+ 		return arrowTimestamp(lt)
+ 	case *schema.TimeLogicalType:
+ 		return arrowTime64(lt)
+ 	case *schema.IntLogicalType:
+ 		return arrowInt(lt)
+ 	case *schema.DecimalLogicalType:
+ 		dec := &arrow.Decimal128Type{Precision: lt.Precision(), Scale: lt.Scale()}
+ 		return dec, nil
+ 	}
+ 	return nil, xerrors.New(logical.String() + " cannot annotate int64")
+ }
+
+ // arrowFromByteArray returns the arrow type for a byte array parquet column:
+ // string for the string logical type, decimal128 for decimals, and binary for
+ // columns with no logical type or an enum, JSON or BSON one. Any other
+ // logical type is an error.
+ func arrowFromByteArray(logical schema.LogicalType) (arrow.DataType, error) {
+ 	switch lt := logical.(type) {
+ 	case schema.NoLogicalType, schema.EnumLogicalType, schema.JSONLogicalType, schema.BSONLogicalType:
+ 		return arrow.BinaryTypes.Binary, nil
+ 	case schema.StringLogicalType:
+ 		return arrow.BinaryTypes.String, nil
+ 	case *schema.DecimalLogicalType:
+ 		dec := &arrow.Decimal128Type{Precision: lt.Precision(), Scale: lt.Scale()}
+ 		return dec, nil
+ 	}
+ 	return nil, xerrors.New("unhandled logicaltype " + logical.String() + " for byte_array")
+ }
+
+ // arrowFromFLBA returns the arrow type for a fixed length byte array parquet
+ // column: decimal128 for the decimal logical type, and a fixed size binary of
+ // length bytes for columns with no logical type or an interval or UUID one.
+ // Any other logical type is an error.
+ func arrowFromFLBA(logical schema.LogicalType, length int) (arrow.DataType, error) {
+ 	switch lt := logical.(type) {
+ 	case schema.NoLogicalType, schema.IntervalLogicalType, schema.UUIDLogicalType:
+ 		fixed := &arrow.FixedSizeBinaryType{ByteWidth: length}
+ 		return fixed, nil
+ 	case *schema.DecimalLogicalType:
+ 		dec := &arrow.Decimal128Type{Precision: lt.Precision(), Scale: lt.Scale()}
+ 		return dec, nil
+ 	}
+ 	return nil, xerrors.New("unhandled logical type " + logical.String() + " for fixed-length byte array")
+ }
+
+func getArrowType(physical parquet.Type, logical schema.LogicalType, typeLen int) (arrow.DataType, error) {
+ if !logical.IsValid() || logical.Equals(schema.NullLogicalType{}) {
+ return arrow.Null, nil
+ }
+
+ switch physical {
+ case parquet.Types.Boolean:
+ return arrow.FixedWidthTypes.Boolean, nil
+ case parquet.Types.Int32:
+ return arrowFromInt32(logical)
+ case parquet.Types.Int64:
+ return arrowFromInt64(logical)
+ case parquet.Types.Int96:
+ return arrow.FixedWidthTypes.Timestamp_ns, nil
+ case parquet.Types.Float:
+ return arrow.PrimitiveTypes.Float32, nil
+ case parquet.Types.Double:
+ return arrow.PrimitiveTypes.Float64, nil
+ case parquet.Types.ByteArray:
+ return arrowFromByteArray(logical)
+ case parquet.Types.FixedLenByteArray:
+ return arrowFromFLBA(logical, typeLen)
+ default:
+ return nil, xerrors.New("invalid physical column type")
+ }
+}
+
+ // populateLeaf fills out as a leaf column: it stores the arrow field, column
+ // index and level info on out, then registers out in ctx both as the child of
+ // parent and as the leaf for colIndex.
+ func populateLeaf(colIndex int, field *arrow.Field, currentLevels file.LevelInfo, ctx *schemaTree, parent *SchemaField, out *SchemaField) {
+ 	out.Field, out.ColIndex, out.LevelInfo = field, colIndex, currentLevels
+ 	ctx.LinkParent(out, parent)
+ 	ctx.RecordLeaf(out)
+ }
+
+ // listToSchemaField converts a LIST annotated parquet group into an arrow list
+ // field, filling in out and its single child and recording the parent links
+ // in ctx.
+ //
+ // Both the 3-level encoding (a repeated group under the LIST group) and the
+ // 2-level encoding (a repeated primitive directly under the LIST group) are
+ // handled. An error is returned if the group does not have exactly one child,
+ // if the group itself is repeated, if its child is not repeated, or if
+ // converting the element fails.
+ func listToSchemaField(n *schema.GroupNode, currentLevels file.LevelInfo, ctx *schemaTree, parent, out *SchemaField) error {
+ 	if n.NumFields() != 1 {
+ 		return xerrors.New("LIST groups must have only 1 child")
+ 	}
+
+ 	if n.RepetitionType() == parquet.Repetitions.Repeated {
+ 		return xerrors.New("LIST groups must not be repeated")
+ 	}
+
+ 	currentLevels.Increment(n)
+
+ 	out.Children = make([]SchemaField, n.NumFields())
+ 	ctx.LinkParent(out, parent)
+ 	ctx.LinkParent(&out.Children[0], out)
+
+ 	listNode := n.Field(0)
+ 	if listNode.RepetitionType() != parquet.Repetitions.Repeated {
+ 		return xerrors.New("non-repeated nodes in a list group are not supported")
+ 	}
+
+ 	repeatedAncestorDef := currentLevels.IncrementRepeated()
+ 	if listNode.Type() == schema.Group {
+ 		// Resolve 3-level encoding
+ 		//
+ 		// required/optional group name=whatever {
+ 		//   repeated group name=list {
+ 		//     required/optional TYPE item;
+ 		//   }
+ 		// }
+ 		//
+ 		// yields list<item: TYPE ?nullable> ?nullable
+ 		//
+ 		// We distinguish the special case that we have
+ 		//
+ 		// required/optional group name=whatever {
+ 		//   repeated group name=array or $SOMETHING_tuple {
+ 		//     required/optional TYPE item;
+ 		//   }
+ 		// }
+ 		//
+ 		// In this latter case, the inner type of the list should be a struct
+ 		// rather than a primitive value
+ 		//
+ 		// yields list<item: struct<item: TYPE ?nullable> not null> ?nullable
+ 		// Special case mentioned in the format spec:
+ 		// If the name is array or ends in _tuple, this should be a list of struct
+ 		// even for single child elements.
+ 		listGroup := listNode.(*schema.GroupNode)
+ 		hasStructListName := listGroup.Name() == "array" || strings.HasSuffix(listGroup.Name(), "_tuple")
+ 		if listGroup.NumFields() == 1 && !hasStructListName {
+ 			// the repeated group only wraps a single element: unwrap it
+ 			if err := nodeToSchemaField(listGroup.Field(0), currentLevels, ctx, out, &out.Children[0]); err != nil {
+ 				return err
+ 			}
+ 		} else {
+ 			// list of struct: either the group is named like a struct list, or
+ 			// it does not have exactly one field, in which case every field
+ 			// must be kept rather than only the first one
+ 			if err := groupToStructField(listGroup, currentLevels, ctx, out, &out.Children[0]); err != nil {
+ 				return err
+ 			}
+ 		}
+ 	} else {
+ 		// Two-level list encoding
+ 		//
+ 		// required/optional group LIST {
+ 		//   repeated TYPE;
+ 		// }
+ 		primitiveNode := listNode.(*schema.PrimitiveNode)
+ 		colIndex := ctx.schema.ColumnIndexByNode(primitiveNode)
+ 		arrowType, err := getArrowType(primitiveNode.PhysicalType(), primitiveNode.LogicalType(), primitiveNode.TypeLength())
+ 		if err != nil {
+ 			return err
+ 		}
+
+ 		itemField := arrow.Field{Name: listNode.Name(), Type: arrowType, Nullable: false, Metadata: createFieldMeta(int(listNode.FieldID()))}
+ 		populateLeaf(colIndex, &itemField, currentLevels, ctx, out, &out.Children[0])
+ 	}
+
+ 	out.Field = &arrow.Field{Name: n.Name(), Type: arrow.ListOf(out.Children[0].Field.Type),
+ 		Nullable: n.RepetitionType() == parquet.Repetitions.Optional, Metadata: createFieldMeta(int(n.FieldID()))}
+ 	out.LevelInfo = currentLevels
+ 	// At this point current levels contains the def level for this list,
+ 	// we need to reset to the prior parent.
+ 	out.LevelInfo.RepeatedAncestorDefLevel = repeatedAncestorDef
+ 	return nil
+ }
+
+ // groupToStructField converts a parquet group into an arrow struct field.
+ //
+ // Every child node is converted with nodeToSchemaField into the matching entry
+ // of out.Children, with out passed as its parent. out.Field then becomes a
+ // struct of the converted child fields, nullable when the group is optional,
+ // and out.LevelInfo is set to currentLevels. The first child conversion error
+ // is returned.
+ func groupToStructField(n *schema.GroupNode, currentLevels file.LevelInfo, ctx *schemaTree, parent, out *SchemaField) error {
+ 	numFields := n.NumFields()
+ 	out.Children = make([]SchemaField, numFields)
+ 	structFields := make([]arrow.Field, numFields)
+
+ 	for idx := range out.Children {
+ 		child := &out.Children[idx]
+ 		if err := nodeToSchemaField(n.Field(idx), currentLevels, ctx, out, child); err != nil {
+ 			return err
+ 		}
+ 		structFields[idx] = *child.Field
+ 	}
+
+ 	nullable := n.RepetitionType() == parquet.Repetitions.Optional
+ 	out.Field = &arrow.Field{
+ 		Name:     n.Name(),
+ 		Type:     arrow.StructOf(structFields...),
+ 		Nullable: nullable,
+ 		Metadata: createFieldMeta(int(n.FieldID())),
+ 	}
+ 	out.LevelInfo = currentLevels
+ 	return nil
+ }
+
+ // mapToSchemaField converts a MAP annotated parquet group into an arrow map
+ // field.
+ //
+ // out receives the map field and a single child for the key/value struct,
+ // which in turn has one child for the key and one for the value; the parent
+ // links for all of them are recorded in ctx.
+ //
+ // An error is returned when the group does not have exactly one child, when
+ // the group is repeated, when its child is not a repeated group, when that
+ // group does not have 1 or 2 children, when the key is not required, or when
+ // converting the key or value fails. A key/value group that only holds a key
+ // is converted as a list instead.
+ func mapToSchemaField(n *schema.GroupNode, currentLevels file.LevelInfo, ctx *schemaTree, parent, out *SchemaField) error {
+ if n.NumFields() != 1 {
+ return xerrors.New("MAP group must have exactly 1 child")
+ }
+ if n.RepetitionType() == parquet.Repetitions.Repeated {
+ return xerrors.New("MAP groups must not be repeated")
+ }
+
+ keyvalueNode := n.Field(0)
+ if keyvalueNode.RepetitionType() != parquet.Repetitions.Repeated {
+ return xerrors.New("Non-repeated keyvalue group in MAP group is not supported")
+ }
+
+ if keyvalueNode.Type() != schema.Group {
+ return xerrors.New("keyvalue node must be a group")
+ }
+
+ kvgroup := keyvalueNode.(*schema.GroupNode)
+ if kvgroup.NumFields() != 1 && kvgroup.NumFields() != 2 {
+ return xerrors.Errorf("keyvalue node group must have exactly 1 or 2 child elements, Found %d", kvgroup.NumFields())
+ }
+
+ keyNode := kvgroup.Field(0)
+ if keyNode.RepetitionType() != parquet.Repetitions.Required {
+ return xerrors.New("MAP keys must be required")
+ }
+
+ // Arrow doesn't support 1 column maps (i.e. Sets). The options are to either
+ // make the values column nullable, or process the map as a list. We choose the latter
+ // as it is simpler.
+ if kvgroup.NumFields() == 1 {
+ return listToSchemaField(n, currentLevels, ctx, parent, out)
+ }
+
+ // account for the map group itself, then for the repeated key/value group
+ currentLevels.Increment(n)
+ repeatedAncestorDef := currentLevels.IncrementRepeated()
+ out.Children = make([]SchemaField, 1)
+
+ kvfield := &out.Children[0]
+ kvfield.Children = make([]SchemaField, 2)
+
+ keyField := &kvfield.Children[0]
+ valueField := &kvfield.Children[1]
+
+ ctx.LinkParent(out, parent)
+ ctx.LinkParent(kvfield, out)
+ ctx.LinkParent(keyField, kvfield)
+ ctx.LinkParent(valueField, kvfield)
+
+ // required/optional group name=whatever {
+ // repeated group name=key_values{
+ // required TYPE key;
+ // required/optional TYPE value;
+ // }
+ // }
+ //
+
+ if err := nodeToSchemaField(keyNode, currentLevels, ctx, kvfield, keyField); err != nil {
+ return err
+ }
+ if err := nodeToSchemaField(kvgroup.Field(1), currentLevels, ctx, kvfield, valueField); err != nil {
+ return err
+ }
+
+ // the key/value struct takes the map group's name and is never nullable; its
+ // field id comes from the key/value group
+ kvfield.Field = &arrow.Field{Name: n.Name(), Type: arrow.StructOf(*keyField.Field, *valueField.Field),
+ Nullable: false, Metadata: createFieldMeta(int(kvgroup.FieldID()))}
+
+ kvfield.LevelInfo = currentLevels
+ out.Field = &arrow.Field{Name: n.Name(), Type: arrow.MapOf(keyField.Field.Type, valueField.Field.Type),
+ Nullable: n.RepetitionType() == parquet.Repetitions.Optional,
+ Metadata: createFieldMeta(int(n.FieldID()))}
+ out.LevelInfo = currentLevels
+ // At this point current levels contains the def level for this map,
+ // we need to reset to the prior parent.
+ out.LevelInfo.RepeatedAncestorDefLevel = repeatedAncestorDef
+ return nil
+ }
+
+ // groupToSchemaField converts a parquet group node into a SchemaField.
+ //
+ // LIST and MAP annotated groups are delegated to listToSchemaField and
+ // mapToSchemaField. A repeated group without such an annotation becomes a
+ // non-nullable list whose single child is the struct built from the group.
+ // Any other group becomes a plain struct.
+ func groupToSchemaField(n *schema.GroupNode, currentLevels file.LevelInfo, ctx *schemaTree, parent, out *SchemaField) error {
+ 	switch logical := n.LogicalType(); {
+ 	case logical.Equals(schema.NewListLogicalType()):
+ 		return listToSchemaField(n, currentLevels, ctx, parent, out)
+ 	case logical.Equals(schema.MapLogicalType{}):
+ 		return mapToSchemaField(n, currentLevels, ctx, parent, out)
+ 	}
+
+ 	if n.RepetitionType() != parquet.Repetitions.Repeated {
+ 		// plain (required/optional) struct
+ 		currentLevels.Increment(n)
+ 		return groupToStructField(n, currentLevels, ctx, parent, out)
+ 	}
+
+ 	// Simple repeated struct
+ 	//
+ 	// repeated group $NAME {
+ 	//   r/o TYPE[0] f0
+ 	//   r/o TYPE[1] f1
+ 	// }
+ 	out.Children = make([]SchemaField, 1)
+ 	elem := &out.Children[0]
+ 	repeatedAncestorDef := currentLevels.IncrementRepeated()
+ 	if err := groupToStructField(n, currentLevels, ctx, out, elem); err != nil {
+ 		return err
+ 	}
+
+ 	out.Field = &arrow.Field{
+ 		Name:     n.Name(),
+ 		Type:     arrow.ListOf(elem.Field.Type),
+ 		Nullable: false,
+ 		Metadata: createFieldMeta(int(n.FieldID())),
+ 	}
+ 	ctx.LinkParent(elem, out)
+ 	out.LevelInfo = currentLevels
+ 	out.LevelInfo.RepeatedAncestorDefLevel = repeatedAncestorDef
+ 	return nil
+ }
+
+ // createFieldMeta builds arrow field metadata holding fieldID, formatted in
+ // base 10, under fieldIDKey, the same key fieldIDFromMeta reads back.
+ func createFieldMeta(fieldID int) arrow.Metadata {
+ 	return arrow.NewMetadata([]string{fieldIDKey}, []string{strconv.Itoa(fieldID)})
+ }
+
+ // nodeToSchemaField converts a parquet schema node into a SchemaField, filling
+ // in out and recording parent as its parent in ctx.
+ //
+ // Group nodes are delegated to groupToSchemaField. For primitive nodes the
+ // arrow type is inferred with getArrowType; a repeated primitive becomes a
+ // non-nullable list with a single leaf child, while any other primitive is
+ // recorded directly as a leaf that is nullable when the node is optional.
+ // An error is returned when the arrow type cannot be inferred or a nested
+ // conversion fails.
+ func nodeToSchemaField(n schema.Node, currentLevels file.LevelInfo, ctx *schemaTree, parent, out *SchemaField) error {
+ ctx.LinkParent(out, parent)
+
+ if n.Type() == schema.Group {
+ return groupToSchemaField(n.(*schema.GroupNode), currentLevels, ctx, parent, out)
+ }
+
+ // Either a normal flat primitive type, or a list type encoded with 1-level
+ // list encoding. Note that the 3-level encoding is the form recommended by
+ // the parquet specification, but technically we can have either
+ //
+ // required/optional $TYPE $FIELD_NAME
+ //
+ // or
+ //
+ // repeated $TYPE $FIELD_NAME
+
+ primitive := n.(*schema.PrimitiveNode)
+ colIndex := ctx.schema.ColumnIndexByNode(primitive)
+ arrowType, err := getArrowType(primitive.PhysicalType(), primitive.LogicalType(), primitive.TypeLength())
+ if err != nil {
+ return err
+ }
+
+ if primitive.RepetitionType() == parquet.Repetitions.Repeated {
+ // one-level list encoding e.g. a: repeated int32;
+ repeatedAncestorDefLevel := currentLevels.IncrementRepeated()
+ out.Children = make([]SchemaField, 1)
+ // the element is the leaf column; the list field wrapping it takes the same name
+ child := arrow.Field{Name: primitive.Name(), Type: arrowType, Nullable: false}
+ populateLeaf(colIndex, &child, currentLevels, ctx, out, &out.Children[0])
+ out.Field = &arrow.Field{Name: primitive.Name(), Type: arrow.ListOf(child.Type), Nullable: false,
+ Metadata: createFieldMeta(int(primitive.FieldID()))}
+ out.LevelInfo = currentLevels
+ // as with the other list encodings, point the list itself back at the
+ // level returned by IncrementRepeated
+ out.LevelInfo.RepeatedAncestorDefLevel = repeatedAncestorDefLevel
+ return nil
+ }
+
+ currentLevels.Increment(n)
+ populateLeaf(colIndex, &arrow.Field{Name: n.Name(), Type: arrowType,
+ Nullable: n.RepetitionType() == parquet.Repetitions.Optional,
+ Metadata: createFieldMeta(int(n.FieldID()))},
+ currentLevels, ctx, parent, out)
+ return nil
+ }
+
+ // getOriginSchema extracts the arrow schema stored in the file level key/value
+ // metadata under the "ARROW:schema" key.
+ //
+ // It returns (nil, nil) when meta is nil or does not contain the key.
+ // Otherwise the value is base64 decoded and deserialized with
+ // flight.DeserializeSchema using mem; a decoding or deserialization failure is
+ // returned as the error.
+ func getOriginSchema(meta metadata.KeyValueMetadata, mem memory.Allocator) (*arrow.Schema, error) {
+ 	if meta == nil {
+ 		return nil, nil
+ 	}
+
+ 	const arrowSchemaKey = "ARROW:schema"
+ 	serialized := meta.FindValue(arrowSchemaKey)
+ 	if serialized == nil {
+ 		return nil, nil
+ 	}
+
+ 	// The stored value may or may not carry trailing '=' padding depending on
+ 	// the base64 encoder that produced it. The raw (unpadded) decoder rejects
+ 	// padding characters, so strip them first and accept both forms.
+ 	decoded, err := base64.RawStdEncoding.DecodeString(strings.TrimRight(*serialized, "="))
+ 	if err != nil {
+ 		return nil, err
+ 	}
+
+ 	return flight.DeserializeSchema(decoded, mem)
+ }
+
+ // getNestedFactory returns a function that rebuilds a nested arrow type from a
+ // list of child fields, chosen from the pairing of the original and inferred
+ // types:
+ //
+ //   - struct / struct: a struct of all the given fields
+ //   - list / list: a list of the first field's type
+ //   - fixed size list / list: a fixed size list with the original length
+ //   - map / map: a map built from the two fields of the first field's struct type
+ //
+ // nil is returned for any other pairing.
+ func getNestedFactory(origin, inferred arrow.DataType) func(fieldList []arrow.Field) arrow.DataType {
+ 	switch inferredID := inferred.ID(); {
+ 	case inferredID == arrow.STRUCT && origin.ID() == arrow.STRUCT:
+ 		return func(children []arrow.Field) arrow.DataType {
+ 			return arrow.StructOf(children...)
+ 		}
+ 	case inferredID == arrow.LIST && origin.ID() == arrow.LIST:
+ 		return func(children []arrow.Field) arrow.DataType {
+ 			return arrow.ListOf(children[0].Type)
+ 		}
+ 	case inferredID == arrow.LIST && origin.ID() == arrow.FIXED_SIZE_LIST:
+ 		listLen := origin.(*arrow.FixedSizeListType).Len()
+ 		return func(children []arrow.Field) arrow.DataType {
+ 			return arrow.FixedSizeListOf(listLen, children[0].Type)
+ 		}
+ 	case inferredID == arrow.MAP && origin.ID() == arrow.MAP:
+ 		return func(children []arrow.Field) arrow.DataType {
+ 			keyValue := children[0].Type.(*arrow.StructType)
+ 			return arrow.MapOf(keyValue.Field(0).Type, keyValue.Field(1).Type)
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // applyOriginalStorageMetadata adjusts the inferred schema field using the
+ // original arrow field, and reports whether the inferred field was modified.
+ //
+ // Nested types (struct, list, fixed size list, map) recurse into their
+ // children and, if anything changed, rebuild the inferred type with the
+ // factory from getNestedFactory. They are left untouched when the child
+ // counts do not line up or no factory exists for the pairing. Timestamps take
+ // over the original type when the units match, the inferred zone is "UTC" and
+ // the original has a time zone. Extension, union and dictionary types set an
+ // "unimplemented type" error.
+ //
+ // Finally, if the original field has metadata it is merged into the inferred
+ // field's metadata; for keys present in both, the inferred value is kept.
+ func applyOriginalStorageMetadata(origin arrow.Field, inferred *SchemaField) (modified bool, err error) {
+ nchildren := len(inferred.Children)
+ switch origin.Type.ID() {
+ case arrow.EXTENSION, arrow.UNION, arrow.DICTIONARY:
+ err = xerrors.New("unimplemented type")
+ case arrow.STRUCT:
+ typ := origin.Type.(*arrow.StructType)
+ // a differing number of children means the original type does not line up
+ // with what was inferred, so nothing is applied
+ if nchildren != len(typ.Fields()) {
+ return
+ }
+
+ factory := getNestedFactory(typ, inferred.Field.Type)
+ if factory == nil {
+ return
+ }
+
+ modified = typ.ID() != inferred.Field.Type.ID()
+ for idx := range inferred.Children {
+ // note: this err shadows the named result, hence the explicit return values
+ childMod, err := applyOriginalMetadata(typ.Field(idx), &inferred.Children[idx])
+ if err != nil {
+ return false, err
+ }
+ modified = modified || childMod
+ }
+ if modified {
+ modifiedChildren := make([]arrow.Field, len(inferred.Children))
+ for idx, child := range inferred.Children {
+ modifiedChildren[idx] = *child.Field
+ }
+ inferred.Field.Type = factory(modifiedChildren)
+ }
+ case arrow.FIXED_SIZE_LIST, arrow.LIST, arrow.MAP:
+ if nchildren != 1 {
+ return
+ }
+ factory := getNestedFactory(origin.Type, inferred.Field.Type)
+ if factory == nil {
+ return
+ }
+
+ modified = origin.Type.ID() != inferred.Field.Type.ID()
+ var childModified bool
+ // recurse into the single child using the original element (or map value) type
+ switch typ := origin.Type.(type) {
+ case *arrow.FixedSizeListType:
+ childModified, err = applyOriginalMetadata(arrow.Field{Type: typ.Elem()}, &inferred.Children[0])
+ case *arrow.ListType:
+ childModified, err = applyOriginalMetadata(arrow.Field{Type: typ.Elem()}, &inferred.Children[0])
+ case *arrow.MapType:
+ childModified, err = applyOriginalMetadata(arrow.Field{Type: typ.ValueType()}, &inferred.Children[0])
+ }
+ if err != nil {
+ return
+ }
+ modified = modified || childModified
+ if modified {
+ inferred.Field.Type = factory([]arrow.Field{*inferred.Children[0].Field})
+ }
+ case arrow.TIMESTAMP:
+ if inferred.Field.Type.ID() != arrow.TIMESTAMP {
+ return
+ }
+
+ tsOtype := origin.Type.(*arrow.TimestampType)
+ tsInfType := inferred.Field.Type.(*arrow.TimestampType)
+
+ // if the unit is the same and the data is tz-aware, then set the original time zone
+ // since parquet has no native storage of timezones
+ if tsOtype.Unit == tsInfType.Unit && tsInfType.TimeZone == "UTC" && tsOtype.TimeZone != "" {
+ inferred.Field.Type = origin.Type
+ }
+ // reported as modified whether or not the type was actually replaced above
+ modified = true
+ }
+
+ if origin.HasMetadata() {
+ meta := origin.Metadata
+ if inferred.Field.HasMetadata() {
+ // merge both sets of metadata; inferred entries are written last and so
+ // win over original entries that share a key
+ final := make(map[string]string)
+ for idx, k := range meta.Keys() {
+ final[k] = meta.Values()[idx]
+ }
+ for idx, k := range inferred.Field.Metadata.Keys() {
+ final[k] = inferred.Field.Metadata.Values()[idx]
+ }
+ inferred.Field.Metadata = arrow.MetadataFrom(final)
+ } else {
+ inferred.Field.Metadata = meta
+ }
+ modified = true
+ }
+
+ return
+ }
+
+ // applyOriginalMetadata applies the original arrow field to the inferred
+ // schema field via applyOriginalStorageMetadata and reports whether the
+ // inferred field was modified. Extension types are rejected with an error.
+ func applyOriginalMetadata(origin arrow.Field, inferred *SchemaField) (bool, error) {
+ 	if origin.Type.ID() != arrow.EXTENSION {
+ 		return applyOriginalStorageMetadata(origin, inferred)
+ 	}
+ 	return false, xerrors.New("extension types not implemented yet")
+ }
+
+// NewSchemaManifest creates a manifest for mapping a parquet schema to a given arrow schema.
+//
+// The metadata passed in should be the file level key value metadata from the parquet file or nil.
+// If the ARROW:schema was in the metadata, then it is utilized to determine types.
+func NewSchemaManifest(sc *schema.Schema, meta metadata.KeyValueMetadata, props *ArrowReadProperties) (*SchemaManifest, error) {
+ var ctx schemaTree
+ ctx.manifest = &SchemaManifest{
+ ColIndexToField: make(map[int]*SchemaField),
+ ChildToParent: make(map[*SchemaField]*SchemaField),
+ descr: sc,
+ Fields: make([]SchemaField, sc.Root().NumFields()),
+ }
+ ctx.props = props
+ ctx.schema = sc
+
+ var err error
+ ctx.manifest.OriginSchema, err = getOriginSchema(meta, memory.DefaultAllocator)
+ if err != nil {
+ return nil, err
+ }
+
+ // if original schema is not compatible with the parquet schema, ignore it
+ if ctx.manifest.OriginSchema != nil && len(ctx.manifest.OriginSchema.Fields()) != sc.Root().NumFields() {
+ ctx.manifest.OriginSchema = nil
+ }
+
+ for idx := range ctx.manifest.Fields {
+ field := &ctx.manifest.Fields[idx]
+ if err := nodeToSchemaField(sc.Root().Field(idx), file.LevelInfo{NullSlotUsage: 1}, &ctx, nil, field); err != nil {
+ return nil, err
+ }
+
+ if ctx.manifest.OriginSchema != nil {
+ if _, err := applyOriginalMetadata(ctx.manifest.OriginSchema.Field(idx), field); err != nil {
+ return nil, err
+ }
+ }
+ }
+ return ctx.manifest, nil
+}
+
+ // FromParquet generates an arrow Schema from a provided Parquet Schema
+ //
+ // The fields come from the manifest built by NewSchemaManifest. The schema
+ // metadata is taken from the original arrow schema when the manifest has one,
+ // and from the manifest's own schema metadata otherwise.
+ func FromParquet(sc *schema.Schema, props *ArrowReadProperties, kv metadata.KeyValueMetadata) (*arrow.Schema, error) {
+ 	manifest, err := NewSchemaManifest(sc, kv, props)
+ 	if err != nil {
+ 		return nil, err
+ 	}
+
+ 	fields := make([]arrow.Field, 0, len(manifest.Fields))
+ 	for idx := range manifest.Fields {
+ 		fields = append(fields, *manifest.Fields[idx].Field)
+ 	}
+
+ 	schemaMeta := manifest.SchemaMeta
+ 	if origin := manifest.OriginSchema; origin != nil {
+ 		originMeta := origin.Metadata()
+ 		schemaMeta = &originMeta
+ 	}
+ 	return arrow.NewSchema(fields, schemaMeta), nil
+ }
diff --git a/go/parquet/pqarrow/schema_test.go b/go/parquet/pqarrow/schema_test.go
new file mode 100644
index 0000000..9e5359d
--- /dev/null
+++ b/go/parquet/pqarrow/schema_test.go
@@ -0,0 +1,245 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow_test
+
+import (
+ "testing"
+
+ "github.com/apache/arrow/go/arrow"
+ "github.com/apache/arrow/go/parquet"
+ "github.com/apache/arrow/go/parquet/pqarrow"
+ "github.com/apache/arrow/go/parquet/schema"
+ "github.com/stretchr/testify/assert"
+)
+
+// TestConvertArrowFlatPrimitives builds matching lists of parquet nodes and
+// arrow fields for non-nested types, converts the arrow schema with
+// pqarrow.ToParquet (deprecated Int96 timestamps enabled) and checks that the
+// result equals the hand-built parquet schema, both as a whole and per column.
+func TestConvertArrowFlatPrimitives(t *testing.T) {
+	parquetFields := make(schema.FieldList, 0)
+	arrowFields := make([]arrow.Field, 0)
+
+	parquetFields = append(parquetFields, schema.NewBooleanNode("boolean", parquet.Repetitions.Required, -1))
+	arrowFields = append(arrowFields, arrow.Field{Name: "boolean", Type: arrow.FixedWidthTypes.Boolean, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("int8", parquet.Repetitions.Required,
+		schema.NewIntLogicalType(8, true), parquet.Types.Int32, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "int8", Type: arrow.PrimitiveTypes.Int8, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("uint8", parquet.Repetitions.Required,
+		schema.NewIntLogicalType(8, false), parquet.Types.Int32, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "uint8", Type: arrow.PrimitiveTypes.Uint8, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("int16", parquet.Repetitions.Required,
+		schema.NewIntLogicalType(16, true), parquet.Types.Int32, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "int16", Type: arrow.PrimitiveTypes.Int16, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("uint16", parquet.Repetitions.Required,
+		schema.NewIntLogicalType(16, false), parquet.Types.Int32, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "uint16", Type: arrow.PrimitiveTypes.Uint16, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("int32", parquet.Repetitions.Required,
+		schema.NewIntLogicalType(32, true), parquet.Types.Int32, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "int32", Type: arrow.PrimitiveTypes.Int32, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("uint32", parquet.Repetitions.Required,
+		schema.NewIntLogicalType(32, false), parquet.Types.Int32, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "uint32", Type: arrow.PrimitiveTypes.Uint32, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("int64", parquet.Repetitions.Required,
+		schema.NewIntLogicalType(64, true), parquet.Types.Int64, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "int64", Type: arrow.PrimitiveTypes.Int64, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("uint64", parquet.Repetitions.Required,
+		schema.NewIntLogicalType(64, false), parquet.Types.Int64, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "uint64", Type: arrow.PrimitiveTypes.Uint64, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeConverted("timestamp", parquet.Repetitions.Required,
+		parquet.Types.Int64, schema.ConvertedTypes.TimestampMillis, 0, 0, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "timestamp", Type: arrow.FixedWidthTypes.Timestamp_ms, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeConverted("timestamp[us]", parquet.Repetitions.Required,
+		parquet.Types.Int64, schema.ConvertedTypes.TimestampMicros, 0, 0, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "timestamp[us]", Type: arrow.FixedWidthTypes.Timestamp_us, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("date", parquet.Repetitions.Required,
+		schema.DateLogicalType{}, parquet.Types.Int32, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "date", Type: arrow.FixedWidthTypes.Date32, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("date64", parquet.Repetitions.Required,
+		schema.NewTimestampLogicalType(true, schema.TimeUnitMillis), parquet.Types.Int64, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "date64", Type: arrow.FixedWidthTypes.Date64, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("time32", parquet.Repetitions.Required,
+		schema.NewTimeLogicalType(true, schema.TimeUnitMillis), parquet.Types.Int32, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "time32", Type: arrow.FixedWidthTypes.Time32ms, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("time64", parquet.Repetitions.Required,
+		schema.NewTimeLogicalType(true, schema.TimeUnitMicros), parquet.Types.Int64, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "time64", Type: arrow.FixedWidthTypes.Time64us, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.NewInt96Node("timestamp96", parquet.Repetitions.Required, -1))
+	arrowFields = append(arrowFields, arrow.Field{Name: "timestamp96", Type: arrow.FixedWidthTypes.Timestamp_ns, Nullable: false})
+
+	parquetFields = append(parquetFields, schema.NewFloat32Node("float", parquet.Repetitions.Optional, -1))
+	arrowFields = append(arrowFields, arrow.Field{Name: "float", Type: arrow.PrimitiveTypes.Float32, Nullable: true})
+
+	parquetFields = append(parquetFields, schema.NewFloat64Node("double", parquet.Repetitions.Optional, -1))
+	arrowFields = append(arrowFields, arrow.Field{Name: "double", Type: arrow.PrimitiveTypes.Float64, Nullable: true})
+
+	parquetFields = append(parquetFields, schema.NewByteArrayNode("binary", parquet.Repetitions.Optional, -1))
+	arrowFields = append(arrowFields, arrow.Field{Name: "binary", Type: arrow.BinaryTypes.Binary, Nullable: true})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("string", parquet.Repetitions.Optional,
+		schema.StringLogicalType{}, parquet.Types.ByteArray, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "string", Type: arrow.BinaryTypes.String, Nullable: true})
+
+	parquetFields = append(parquetFields, schema.NewFixedLenByteArrayNode("flba-binary", parquet.Repetitions.Optional, 12, -1))
+	arrowFields = append(arrowFields, arrow.Field{Name: "flba-binary", Type: &arrow.FixedSizeBinaryType{ByteWidth: 12}, Nullable: true})
+
+	arrowSchema := arrow.NewSchema(arrowFields, nil)
+	parquetSchema := schema.NewSchema(schema.MustGroup(schema.NewGroupNode("schema", parquet.Repetitions.Repeated, parquetFields, -1)))
+
+	result, err := pqarrow.ToParquet(arrowSchema, nil, pqarrow.NewArrowWriterProperties(pqarrow.WithDeprecatedInt96Timestamps(true)))
+	if !assert.NoError(t, err) {
+		// result is not meaningful after a failed conversion; stop here so the
+		// reported failure is the conversion error rather than a nil dereference.
+		return
+	}
+	assert.True(t, parquetSchema.Equals(result))
+	for i := 0; i < parquetSchema.NumColumns(); i++ {
+		assert.Truef(t, parquetSchema.Column(i).Equals(result.Column(i)), "Column %d didn't match: %s", i, parquetSchema.Column(i).Name())
+	}
+}
+
+// TestConvertArrowParquetLists converts an arrow schema holding a non-nullable
+// and a nullable list-of-string field with pqarrow.ToParquet and checks that
+// the result equals the parquet schema built from schema.ListOf groups with
+// Required and Optional repetition respectively.
+func TestConvertArrowParquetLists(t *testing.T) {
+	parquetFields := make(schema.FieldList, 0)
+	arrowFields := make([]arrow.Field, 0)
+
+	parquetFields = append(parquetFields, schema.MustGroup(schema.ListOf(schema.Must(schema.NewPrimitiveNodeLogical("my_list",
+		parquet.Repetitions.Optional, schema.StringLogicalType{}, parquet.Types.ByteArray, 0, -1)), parquet.Repetitions.Required, -1)))
+
+	arrowFields = append(arrowFields, arrow.Field{Name: "my_list", Type: arrow.ListOf(arrow.BinaryTypes.String)})
+
+	parquetFields = append(parquetFields, schema.MustGroup(schema.ListOf(schema.Must(schema.NewPrimitiveNodeLogical("my_list",
+		parquet.Repetitions.Optional, schema.StringLogicalType{}, parquet.Types.ByteArray, 0, -1)), parquet.Repetitions.Optional, -1)))
+
+	arrowFields = append(arrowFields, arrow.Field{Name: "my_list", Type: arrow.ListOf(arrow.BinaryTypes.String), Nullable: true})
+
+	arrowSchema := arrow.NewSchema(arrowFields, nil)
+	parquetSchema := schema.NewSchema(schema.MustGroup(schema.NewGroupNode("schema", parquet.Repetitions.Repeated, parquetFields, -1)))
+
+	result, err := pqarrow.ToParquet(arrowSchema, nil, pqarrow.NewArrowWriterProperties(pqarrow.WithDeprecatedInt96Timestamps(true)))
+	if !assert.NoError(t, err) {
+		// result is not meaningful after a failed conversion; stop here so the
+		// reported failure is the conversion error rather than a nil dereference.
+		return
+	}
+	// Use an explicit format string: passing the schema dump as the first
+	// message argument would make testify treat it as the format itself.
+	assert.Truef(t, parquetSchema.Equals(result), "expected schema:\n%s\nactual schema:\n%s", parquetSchema.String(), result.String())
+	for i := 0; i < parquetSchema.NumColumns(); i++ {
+		assert.Truef(t, parquetSchema.Column(i).Equals(result.Column(i)), "Column %d didn't match: %s", i, parquetSchema.Column(i).Name())
+	}
+}
+
+// TestConvertArrowDecimals converts arrow Decimal128 fields of several
+// precisions with pqarrow.ToParquet and checks that each becomes a required
+// FixedLenByteArray node with a decimal logical type and the byte width
+// listed in the expected parquet schema.
+func TestConvertArrowDecimals(t *testing.T) {
+	parquetFields := make(schema.FieldList, 0)
+	arrowFields := make([]arrow.Field, 0)
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("decimal_8_4", parquet.Repetitions.Required,
+		schema.NewDecimalLogicalType(8, 4), parquet.Types.FixedLenByteArray, 4, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "decimal_8_4", Type: &arrow.Decimal128Type{Precision: 8, Scale: 4}})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("decimal_20_4", parquet.Repetitions.Required,
+		schema.NewDecimalLogicalType(20, 4), parquet.Types.FixedLenByteArray, 9, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "decimal_20_4", Type: &arrow.Decimal128Type{Precision: 20, Scale: 4}})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("decimal_77_4", parquet.Repetitions.Required,
+		schema.NewDecimalLogicalType(77, 4), parquet.Types.FixedLenByteArray, 34, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "decimal_77_4", Type: &arrow.Decimal128Type{Precision: 77, Scale: 4}})
+
+	arrowSchema := arrow.NewSchema(arrowFields, nil)
+	parquetSchema := schema.NewSchema(schema.MustGroup(schema.NewGroupNode("schema", parquet.Repetitions.Repeated, parquetFields, -1)))
+
+	result, err := pqarrow.ToParquet(arrowSchema, nil, pqarrow.NewArrowWriterProperties(pqarrow.WithDeprecatedInt96Timestamps(true)))
+	if !assert.NoError(t, err) {
+		// result is not meaningful after a failed conversion; stop here so the
+		// reported failure is the conversion error rather than a nil dereference.
+		return
+	}
+	assert.True(t, parquetSchema.Equals(result))
+	for i := 0; i < parquetSchema.NumColumns(); i++ {
+		assert.Truef(t, parquetSchema.Column(i).Equals(result.Column(i)), "Column %d didn't match: %s", i, parquetSchema.Column(i).Name())
+	}
+}
+
+// TestCoerceTImestampV1 converts a millisecond arrow timestamp field with a
+// time zone using parquet version V1_0 writer properties and
+// WithCoerceTimestamps(arrow.Microsecond), and checks that the resulting
+// parquet column is an Int64 node with a microsecond timestamp logical type.
+func TestCoerceTImestampV1(t *testing.T) {
+	parquetFields := make(schema.FieldList, 0)
+	arrowFields := make([]arrow.Field, 0)
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("timestamp", parquet.Repetitions.Required,
+		schema.NewTimestampLogicalTypeForce(false, schema.TimeUnitMicros), parquet.Types.Int64, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "timestamp", Type: &arrow.TimestampType{Unit: arrow.Millisecond, TimeZone: "EST"}})
+
+	arrowSchema := arrow.NewSchema(arrowFields, nil)
+	parquetSchema := schema.NewSchema(schema.MustGroup(schema.NewGroupNode("schema", parquet.Repetitions.Repeated, parquetFields, -1)))
+
+	result, err := pqarrow.ToParquet(arrowSchema, parquet.NewWriterProperties(parquet.WithVersion(parquet.V1_0)), pqarrow.NewArrowWriterProperties(pqarrow.WithCoerceTimestamps(arrow.Microsecond)))
+	if !assert.NoError(t, err) {
+		// result is not meaningful after a failed conversion; stop here so the
+		// reported failure is the conversion error rather than a nil dereference.
+		return
+	}
+	assert.True(t, parquetSchema.Equals(result))
+	for i := 0; i < parquetSchema.NumColumns(); i++ {
+		assert.Truef(t, parquetSchema.Column(i).Equals(result.Column(i)), "Column %d didn't match: %s", i, parquetSchema.Column(i).Name())
+	}
+}
+
+// TestAutoCoerceTImestampV1 converts a nanosecond and a second arrow timestamp
+// field using parquet version V1_0 writer properties and default arrow writer
+// properties (no explicit coercion), and checks that the resulting parquet
+// columns carry microsecond and millisecond timestamp logical types
+// respectively.
+func TestAutoCoerceTImestampV1(t *testing.T) {
+	parquetFields := make(schema.FieldList, 0)
+	arrowFields := make([]arrow.Field, 0)
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("timestamp", parquet.Repetitions.Required,
+		schema.NewTimestampLogicalTypeForce(false, schema.TimeUnitMicros), parquet.Types.Int64, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "timestamp", Type: &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "EST"}})
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("timestamp[ms]", parquet.Repetitions.Required,
+		schema.NewTimestampLogicalTypeForce(true, schema.TimeUnitMillis), parquet.Types.Int64, 0, -1)))
+	arrowFields = append(arrowFields, arrow.Field{Name: "timestamp[ms]", Type: &arrow.TimestampType{Unit: arrow.Second}})
+
+	arrowSchema := arrow.NewSchema(arrowFields, nil)
+	parquetSchema := schema.NewSchema(schema.MustGroup(schema.NewGroupNode("schema", parquet.Repetitions.Repeated, parquetFields, -1)))
+
+	result, err := pqarrow.ToParquet(arrowSchema, parquet.NewWriterProperties(parquet.WithVersion(parquet.V1_0)), pqarrow.NewArrowWriterProperties())
+	if !assert.NoError(t, err) {
+		// result is not meaningful after a failed conversion; stop here so the
+		// reported failure is the conversion error rather than a nil dereference.
+		return
+	}
+	assert.True(t, parquetSchema.Equals(result))
+	for i := 0; i < parquetSchema.NumColumns(); i++ {
+		assert.Truef(t, parquetSchema.Column(i).Equals(result.Column(i)), "Column %d didn't match: %s", i, parquetSchema.Column(i).Name())
+	}
+}
+
+// TestConvertArrowStruct converts an arrow schema containing a top-level int32
+// leaf and a two-level nested struct with pqarrow.ToParquet and checks that
+// the result equals the parquet schema built from nested required group nodes
+// with optional int32 leaves.
+func TestConvertArrowStruct(t *testing.T) {
+	parquetFields := make(schema.FieldList, 0)
+	arrowFields := make([]arrow.Field, 0)
+
+	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("leaf1", parquet.Repetitions.Optional, schema.NewIntLogicalType(32, true), parquet.Types.Int32, 0, -1)))
+	parquetFields = append(parquetFields, schema.Must(schema.NewGroupNode("outerGroup", parquet.Repetitions.Required, schema.FieldList{
+		schema.Must(schema.NewPrimitiveNodeLogical("leaf2", parquet.Repetitions.Optional, schema.NewIntLogicalType(32, true), parquet.Types.Int32, 0, -1)),
+		schema.Must(schema.NewGroupNode("innerGroup", parquet.Repetitions.Required, schema.FieldList{
+			schema.Must(schema.NewPrimitiveNodeLogical("leaf3", parquet.Repetitions.Optional, schema.NewIntLogicalType(32, true), parquet.Types.Int32, 0, -1)),
+		}, -1)),
+	}, -1)))
+
+	arrowFields = append(arrowFields, arrow.Field{Name: "leaf1", Type: arrow.PrimitiveTypes.Int32, Nullable: true})
+	arrowFields = append(arrowFields, arrow.Field{Name: "outerGroup", Type: arrow.StructOf(
+		arrow.Field{Name: "leaf2", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
+		arrow.Field{Name: "innerGroup", Type: arrow.StructOf(
+			arrow.Field{Name: "leaf3", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
+		)},
+	)})
+
+	arrowSchema := arrow.NewSchema(arrowFields, nil)
+	parquetSchema := schema.NewSchema(schema.MustGroup(schema.NewGroupNode("schema", parquet.Repetitions.Repeated, parquetFields, -1)))
+
+	result, err := pqarrow.ToParquet(arrowSchema, nil, pqarrow.NewArrowWriterProperties())
+	if !assert.NoError(t, err) {
+		// result is not meaningful after a failed conversion; stop here so the
+		// reported failure is the conversion error rather than a nil dereference.
+		return
+	}
+	assert.True(t, parquetSchema.Equals(result))
+	for i := 0; i < parquetSchema.NumColumns(); i++ {
+		assert.Truef(t, parquetSchema.Column(i).Equals(result.Column(i)), "Column %d didn't match: %s", i, parquetSchema.Column(i).Name())
+	}
+}