You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ze...@apache.org on 2022/04/21 14:40:47 UTC
[arrow] branch master updated: ARROW-3039: [Go] Add support for DictionaryArray
This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 4608157f8d ARROW-3039: [Go] Add support for DictionaryArray
4608157f8d is described below
commit 4608157f8d492354db80456626129216d2acd855
Author: Matthew Topol <mt...@factset.com>
AuthorDate: Thu Apr 21 10:42:34 2022 -0400
ARROW-3039: [Go] Add support for DictionaryArray
Also resolves ARROW-5267, ARROW-7286, and ARROW-9378 which all are different pieces of the DictionaryArray support.
This *does not* implement Dictionary support for scalars yet, nor does it yet support concatenating Dictionary Arrays and dictionary unification. Cards will be created to track work for those two pieces of functionality separately.
Closes #12158 from zeroshade/goarrow-dictionaries
Lead-authored-by: Matthew Topol <mt...@factset.com>
Co-authored-by: Matt Topol <zo...@gmail.com>
Signed-off-by: Matthew Topol <mt...@factset.com>
---
dev/archery/archery/integration/datagen.py | 6 +-
docs/source/status.rst | 6 +-
go/arrow/_tools/tools.go | 2 +
go/arrow/array/array.go | 2 +-
go/arrow/array/array_test.go | 26 +-
go/arrow/array/builder.go | 2 +
go/arrow/array/compare.go | 6 +
go/arrow/array/data.go | 65 +-
go/arrow/array/dictionary.go | 1294 +++++++++
go/arrow/array/dictionary_test.go | 1183 +++++++++
go/arrow/datatype_fixedwidth.go | 30 +
go/arrow/internal/arrjson/arrjson.go | 457 ++--
go/arrow/internal/arrjson/arrjson_test.go | 448 +++-
go/arrow/internal/arrjson/reader.go | 9 +-
go/arrow/internal/arrjson/writer.go | 35 +-
go/arrow/internal/dictutils/dict.go | 399 +++
go/arrow/{ipc => internal/dictutils}/dict_test.go | 68 +-
go/arrow/ipc/dict.go | 85 -
go/arrow/ipc/file_reader.go | 232 +-
go/arrow/ipc/file_writer.go | 19 +-
go/arrow/ipc/ipc_test.go | 56 +
go/arrow/ipc/metadata.go | 247 +-
go/arrow/ipc/metadata_test.go | 8 +-
go/arrow/ipc/reader.go | 70 +-
go/arrow/ipc/writer.go | 147 +-
go/go.mod | 2 +-
go/go.sum | 2 +
go/{parquet => }/internal/hashing/hashing_test.go | 0
go/internal/hashing/types.tmpldata | 42 +
go/internal/hashing/xxh3_memo_table.gen.go | 2783 ++++++++++++++++++++
.../internal/hashing/xxh3_memo_table.gen.go.tmpl | 24 +-
.../internal/hashing/xxh3_memo_table.go | 92 +-
go/{parquet => }/internal/utils/Makefile | 32 +-
.../utils/_lib/arch.h} | 22 +-
go/{parquet => }/internal/utils/_lib/min_max.c | 52 +
go/internal/utils/_lib/min_max_avx2_amd64.s | 1009 +++++++
.../internal/utils/_lib/min_max_neon.s | 0
go/internal/utils/_lib/min_max_sse4_amd64.s | 1091 ++++++++
.../utils/endians_default.go} | 19 +-
.../utils/endians_s390x.go} | 23 +-
go/{parquet => }/internal/utils/min_max.go | 92 +
go/{parquet => }/internal/utils/min_max_amd64.go | 13 +
go/{parquet => }/internal/utils/min_max_arm64.go | 9 +-
.../internal/utils/min_max_avx2_amd64.go | 33 +
go/internal/utils/min_max_avx2_amd64.s | 927 +++++++
.../internal/utils/min_max_neon_arm64.go | 0
.../internal/utils/min_max_neon_arm64.s | 0
go/{parquet => }/internal/utils/min_max_noasm.go | 5 +
go/{parquet => }/internal/utils/min_max_s390x.go | 5 +
.../internal/utils/min_max_sse4_amd64.go | 33 +
.../internal/utils/min_max_sse4_amd64.s | 664 ++++-
.../internal/encoding/encoding_benchmarks_test.go | 6 +-
go/parquet/internal/encoding/memo_table.go | 28 +-
go/parquet/internal/encoding/memo_table_test.go | 18 +-
go/parquet/internal/encoding/typed_encoder.gen.go | 35 +-
.../internal/encoding/typed_encoder.gen.go.tmpl | 5 +-
go/parquet/internal/hashing/types.tmpldata | 18 -
go/parquet/internal/hashing/xxh3_memo_table.gen.go | 1103 --------
go/parquet/internal/utils/Makefile | 17 +-
go/parquet/internal/utils/_lib/min_max_avx2.s | 473 ----
go/parquet/internal/utils/_lib/min_max_sse4.s | 613 -----
go/parquet/internal/utils/min_max_avx2_amd64.s | 443 ----
go/parquet/metadata/statistics_types.gen.go | 17 +-
go/parquet/metadata/statistics_types.gen.go.tmpl | 9 +-
64 files changed, 11154 insertions(+), 3507 deletions(-)
diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py
index 611a63d96e..ea69fe2956 100644
--- a/dev/archery/archery/integration/datagen.py
+++ b/dev/archery/archery/integration/datagen.py
@@ -1661,24 +1661,20 @@ def get_generated_json_files(tempdir=None):
# TODO(ARROW-3039, ARROW-5267): Dictionaries in GO
generate_dictionary_case()
- .skip_category('C#')
- .skip_category('Go'),
+ .skip_category('C#'),
generate_dictionary_unsigned_case()
.skip_category('C#')
- .skip_category('Go') # TODO(ARROW-9378)
.skip_category('Java'), # TODO(ARROW-9377)
generate_nested_dictionary_case()
.skip_category('C#')
- .skip_category('Go')
.skip_category('Java') # TODO(ARROW-7779)
.skip_category('JS')
.skip_category('Rust'),
generate_extension_case()
.skip_category('C#')
- .skip_category('Go') # TODO(ARROW-3039): requires dictionaries
.skip_category('JS')
.skip_category('Rust'),
]
diff --git a/docs/source/status.rst b/docs/source/status.rst
index c30caed2f8..b89ba95c7a 100644
--- a/docs/source/status.rst
+++ b/docs/source/status.rst
@@ -90,7 +90,7 @@ Data Types
| Data type | C++ | Java | Go | JavaScript | C# | Rust | Julia |
| (special) | | | | | | | |
+===================+=======+=======+=======+============+=======+=======+=======+
-| Dictionary | ✓ | ✓ (1) | | ✓ (1) | ✓ (1) | ✓ (1) | ✓ |
+| Dictionary | ✓ | ✓ (1) | ✓ | ✓ (1) | ✓ (1) | ✓ (1) | ✓ |
+-------------------+-------+-------+-------+------------+-------+-------+-------+
| Extension | ✓ | ✓ | ✓ | | | | ✓ |
+-------------------+-------+-------+-------+------------+-------+-------+-------+
@@ -118,9 +118,9 @@ IPC Format
+-----------------------------+-------+-------+-------+------------+-------+-------+-------+
| Dictionaries | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
+-----------------------------+-------+-------+-------+------------+-------+-------+-------+
-| Replacement dictionaries | ✓ | ✓ | | | | | ✓ |
+| Replacement dictionaries | ✓ | ✓ | ✓ | | | | ✓ |
+-----------------------------+-------+-------+-------+------------+-------+-------+-------+
-| Delta dictionaries | ✓ (1) | | | | ✓ | | ✓ |
+| Delta dictionaries | ✓ (1) | | ✓ (1) | | ✓ | | ✓ |
+-----------------------------+-------+-------+-------+------------+-------+-------+-------+
| Tensors | ✓ | | | | | | |
+-----------------------------+-------+-------+-------+------------+-------+-------+-------+
diff --git a/go/arrow/_tools/tools.go b/go/arrow/_tools/tools.go
index 6c494bb4a3..262880bca8 100644
--- a/go/arrow/_tools/tools.go
+++ b/go/arrow/_tools/tools.go
@@ -14,10 +14,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+//go:build tools
// +build tools
package _tools
import (
+ _ "golang.org/x/tools/cmd/goimports"
_ "golang.org/x/tools/cmd/stringer"
)
diff --git a/go/arrow/array/array.go b/go/arrow/array/array.go
index b5538ed4c7..74fd79060b 100644
--- a/go/arrow/array/array.go
+++ b/go/arrow/array/array.go
@@ -177,7 +177,7 @@ func init() {
arrow.STRUCT: func(data arrow.ArrayData) arrow.Array { return NewStructData(data) },
arrow.SPARSE_UNION: unsupportedArrayType,
arrow.DENSE_UNION: unsupportedArrayType,
- arrow.DICTIONARY: unsupportedArrayType,
+ arrow.DICTIONARY: func(data arrow.ArrayData) Interface { return NewDictionaryData(data) },
arrow.MAP: func(data arrow.ArrayData) arrow.Array { return NewMapData(data) },
arrow.EXTENSION: func(data arrow.ArrayData) arrow.Array { return NewExtensionData(data) },
arrow.FIXED_SIZE_LIST: func(data arrow.ArrayData) arrow.Array { return NewFixedSizeListData(data) },
diff --git a/go/arrow/array/array_test.go b/go/arrow/array/array_test.go
index 1af1bc6544..ee0ce0e81c 100644
--- a/go/arrow/array/array_test.go
+++ b/go/arrow/array/array_test.go
@@ -42,6 +42,7 @@ func TestMakeFromData(t *testing.T) {
d arrow.DataType
size int
child []arrow.ArrayData
+ dict *array.Data
expPanic bool
expError string
}{
@@ -95,13 +96,23 @@ func TestMakeFromData(t *testing.T) {
}, 0 /* nulls */, 0 /* offset */)},
},
+ // various dictionary index types and value types
+ {name: "dictionary", d: &testDataType{arrow.DICTIONARY}, expPanic: true, expError: "arrow/array: no dictionary set in Data for Dictionary array"},
+ {name: "dictionary", d: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int8, ValueType: &testDataType{arrow.INT64}}, dict: array.NewData(&testDataType{arrow.INT64}, 0 /* length */, make([]*memory.Buffer, 2 /*null bitmap, values*/), nil /* childData */, 0 /* nulls */, 0 /* offset */)},
+ {name: "dictionary", d: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Uint8, ValueType: &testDataType{arrow.INT32}}, dict: array.NewData(&testDataType{arrow.INT32}, 0 /* length */, make([]*memory.Buffer, 2 /*null bitmap, values*/), nil /* childData */, 0 /* nulls */, 0 /* offset */)},
+ {name: "dictionary", d: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int16, ValueType: &testDataType{arrow.UINT16}}, dict: array.NewData(&testDataType{arrow.UINT16}, 0 /* length */, make([]*memory.Buffer, 2 /*null bitmap, values*/), nil /* childData */, 0 /* nulls */, 0 /* offset */)},
+ {name: "dictionary", d: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Uint16, ValueType: &testDataType{arrow.INT64}}, dict: array.NewData(&testDataType{arrow.INT64}, 0 /* length */, make([]*memory.Buffer, 2 /*null bitmap, values*/), nil /* childData */, 0 /* nulls */, 0 /* offset */)},
+ {name: "dictionary", d: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: &testDataType{arrow.UINT32}}, dict: array.NewData(&testDataType{arrow.UINT32}, 0 /* length */, make([]*memory.Buffer, 2 /*null bitmap, values*/), nil /* childData */, 0 /* nulls */, 0 /* offset */)},
+ {name: "dictionary", d: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Uint32, ValueType: &testDataType{arrow.TIMESTAMP}}, dict: array.NewData(&testDataType{arrow.TIMESTAMP}, 0 /* length */, make([]*memory.Buffer, 2 /*null bitmap, values*/), nil /* childData */, 0 /* nulls */, 0 /* offset */)},
+ {name: "dictionary", d: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int64, ValueType: &testDataType{arrow.UINT32}}, dict: array.NewData(&testDataType{arrow.UINT32}, 0 /* length */, make([]*memory.Buffer, 2 /*null bitmap, values*/), nil /* childData */, 0 /* nulls */, 0 /* offset */)},
+ {name: "dictionary", d: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Uint64, ValueType: &testDataType{arrow.TIMESTAMP}}, dict: array.NewData(&testDataType{arrow.TIMESTAMP}, 0 /* length */, make([]*memory.Buffer, 2 /*null bitmap, values*/), nil /* childData */, 0 /* nulls */, 0 /* offset */)},
+
{name: "extension", d: &testDataType{arrow.EXTENSION}, expPanic: true, expError: "arrow/array: DataType for ExtensionArray must implement arrow.ExtensionType"},
{name: "extension", d: types.NewUUIDType()},
// unsupported types
{name: "sparse union", d: &testDataType{arrow.SPARSE_UNION}, expPanic: true, expError: "unsupported data type: SPARSE_UNION"},
{name: "dense union", d: &testDataType{arrow.DENSE_UNION}, expPanic: true, expError: "unsupported data type: DENSE_UNION"},
- {name: "dictionary", d: &testDataType{arrow.DICTIONARY}, expPanic: true, expError: "unsupported data type: DICTIONARY"},
{name: "large string", d: &testDataType{arrow.LARGE_STRING}, expPanic: true, expError: "unsupported data type: LARGE_STRING"},
{name: "large binary", d: &testDataType{arrow.LARGE_BINARY}, expPanic: true, expError: "unsupported data type: LARGE_BINARY"},
{name: "large list", d: &testDataType{arrow.LARGE_LIST}, expPanic: true, expError: "unsupported data type: LARGE_LIST"},
@@ -113,12 +124,19 @@ func TestMakeFromData(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- var b [4]*memory.Buffer
- var n = 4
+ var (
+ b [4]*memory.Buffer
+ n = 4
+ data arrow.ArrayData
+ )
if test.size != 0 {
n = test.size
}
- data := array.NewData(test.d, 0, b[:n], test.child, 0, 0)
+ if test.dict != nil {
+ data = array.NewDataWithDictionary(test.d, 0, b[:n], 0, 0, test.dict)
+ } else {
+ data = array.NewData(test.d, 0, b[:n], test.child, 0, 0)
+ }
if test.expPanic {
assert.PanicsWithValue(t, test.expError, func() {
diff --git a/go/arrow/array/builder.go b/go/arrow/array/builder.go
index 43710fedc6..91905a26c8 100644
--- a/go/arrow/array/builder.go
+++ b/go/arrow/array/builder.go
@@ -296,6 +296,8 @@ func NewBuilder(mem memory.Allocator, dtype arrow.DataType) Builder {
case arrow.SPARSE_UNION:
case arrow.DENSE_UNION:
case arrow.DICTIONARY:
+ typ := dtype.(*arrow.DictionaryType)
+ return NewDictionaryBuilder(mem, typ)
case arrow.LARGE_STRING:
case arrow.LARGE_BINARY:
case arrow.LARGE_LIST:
diff --git a/go/arrow/array/compare.go b/go/arrow/array/compare.go
index 6ef5faa99f..49e0199fc1 100644
--- a/go/arrow/array/compare.go
+++ b/go/arrow/array/compare.go
@@ -303,6 +303,9 @@ func ArrayEqual(left, right arrow.Array) bool {
case ExtensionArray:
r := right.(ExtensionArray)
return arrayEqualExtension(l, r)
+ case *Dictionary:
+ r := right.(*Dictionary)
+ return arrayEqualDict(l, r)
default:
panic(fmt.Errorf("arrow/array: unknown array type %T", l))
}
@@ -507,6 +510,9 @@ func arrayApproxEqual(left, right arrow.Array, opt equalOption) bool {
case *Map:
r := right.(*Map)
return arrayApproxEqualList(l.List, r.List, opt)
+ case *Dictionary:
+ r := right.(*Dictionary)
+ return arrayApproxEqualDict(l, r, opt)
case ExtensionArray:
r := right.(ExtensionArray)
return arrayApproxEqualExtension(l, r, opt)
diff --git a/go/arrow/array/data.go b/go/arrow/array/data.go
index 8386e60af7..03d3f6d328 100644
--- a/go/arrow/array/data.go
+++ b/go/arrow/array/data.go
@@ -29,13 +29,17 @@ import (
// Data represents the memory and metadata of an Arrow array.
type Data struct {
- refCount int64
- dtype arrow.DataType
- nulls int
- offset int
- length int
- buffers []*memory.Buffer // TODO(sgc): should this be an interface?
- childData []arrow.ArrayData // TODO(sgc): managed by ListArray, StructArray and UnionArray types
+ refCount int64
+ dtype arrow.DataType
+ nulls int
+ offset int
+ length int
+
+ // for dictionary arrays: buffers will be the null validity bitmap and the indexes that reference
+ // values in the dictionary member. childData would be empty in a dictionary array
+ buffers []*memory.Buffer // TODO(sgc): should this be an interface?
+ childData []arrow.ArrayData // TODO(sgc): managed by ListArray, StructArray and UnionArray types
+ dictionary *Data // only populated for dictionary arrays
}
// NewData creates a new Data.
@@ -63,6 +67,16 @@ func NewData(dtype arrow.DataType, length int, buffers []*memory.Buffer, childDa
}
}
+// NewDataWithDictionary creates a new data object with no child data, and also sets (and retains) the provided dictionary if it's not nil
+func NewDataWithDictionary(dtype arrow.DataType, length int, buffers []*memory.Buffer, nulls, offset int, dict *Data) *Data {
+ data := NewData(dtype, length, buffers, nil, nulls, offset)
+ if dict != nil {
+ dict.Retain() // the new Data holds its own reference; Release() drops it again
+ }
+ data.dictionary = dict
+ return data
+}
+
// Reset sets the Data for re-use.
func (d *Data) Reset(dtype arrow.DataType, length int, buffers []*memory.Buffer, childData []arrow.ArrayData, nulls, offset int) {
// Retain new buffers before releasing existing buffers in-case they're the same ones to prevent accidental premature
@@ -121,7 +135,11 @@ func (d *Data) Release() {
for _, b := range d.childData {
b.Release()
}
- d.buffers, d.childData = nil, nil
+
+ if d.dictionary != nil {
+ d.dictionary.Release()
+ }
+ d.dictionary, d.buffers, d.childData = nil, nil, nil
}
}
@@ -142,6 +160,18 @@ func (d *Data) Buffers() []*memory.Buffer { return d.buffers }
func (d *Data) Children() []arrow.ArrayData { return d.childData }
+// Dictionary returns the ArrayData for the dictionary member. NOTE(review): with no dictionary set this wraps a nil *Data, so the result may not compare == nil — confirm callers.
+func (d *Data) Dictionary() arrow.ArrayData { return d.dictionary }
+
+// SetDictionary replaces the dictionary of this Data object: dict (which must be a *Data) is retained and any previous dictionary is released
+func (d *Data) SetDictionary(dict arrow.ArrayData) {
+ newDict := dict.(*Data) // assert first so a wrong type panics before any refcount changes
+ newDict.Retain() // retain before releasing the old one in case both are the same object
+ if d.dictionary != nil {
+ d.dictionary.Release()
+ }
+ d.dictionary = newDict
+}
+
// NewSliceData returns a new slice that shares backing data with the input.
// The returned Data slice starts at i and extends j-i elements, such as:
// slice := data[i:j]
@@ -166,14 +196,19 @@ func NewSliceData(data arrow.ArrayData, i, j int64) arrow.ArrayData {
}
}
+ if data.(*Data).dictionary != nil {
+ data.(*Data).dictionary.Retain()
+ }
+
o := &Data{
- refCount: 1,
- dtype: data.DataType(),
- nulls: UnknownNullCount,
- length: int(j - i),
- offset: data.Offset() + int(i),
- buffers: data.Buffers(),
- childData: data.Children(),
+ refCount: 1,
+ dtype: data.DataType(),
+ nulls: UnknownNullCount,
+ length: int(j - i),
+ offset: data.Offset() + int(i),
+ buffers: data.Buffers(),
+ childData: data.Children(),
+ dictionary: data.(*Data).dictionary,
}
if data.NullN() == 0 {
diff --git a/go/arrow/array/dictionary.go b/go/arrow/array/dictionary.go
new file mode 100644
index 0000000000..2a62a24d49
--- /dev/null
+++ b/go/arrow/array/dictionary.go
@@ -0,0 +1,1294 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package array
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "math/bits"
+ "sync/atomic"
+ "unsafe"
+
+ "github.com/apache/arrow/go/v8/arrow"
+ "github.com/apache/arrow/go/v8/arrow/bitutil"
+ "github.com/apache/arrow/go/v8/arrow/decimal128"
+ "github.com/apache/arrow/go/v8/arrow/float16"
+ "github.com/apache/arrow/go/v8/arrow/internal/debug"
+ "github.com/apache/arrow/go/v8/arrow/memory"
+ "github.com/apache/arrow/go/v8/internal/hashing"
+ "github.com/apache/arrow/go/v8/internal/utils"
+ "github.com/goccy/go-json"
+)
+
+// Dictionary represents the type for dictionary-encoded data with a
+// data-dependent dictionary.
+//
+// A dictionary array contains an array of non-negative integers (the "dictionary
+// indices") along with a data type containing a "dictionary" corresponding to
+// the distinct values represented in the data.
+//
+// For example, the array:
+//
+// ["foo", "bar", "foo", "bar", "foo", "bar"]
+//
+// with dictionary ["bar", "foo"], would have the representation of:
+//
+// indices: [1, 0, 1, 0, 1, 0]
+// dictionary: ["bar", "foo"]
+//
+// The indices in principle may be any integer type.
+type Dictionary struct {
+ array
+
+ indices Interface // index array built by setData over this array's own buffers
+ dict Interface // values array; stays nil until Dictionary() is first called
+}
+
+// NewDictionaryArray constructs a dictionary array of the given type that shares the
+// buffers of indices and uses dict as its dictionary. The indices are not validated.
+func NewDictionaryArray(typ arrow.DataType, indices, dict Interface) *Dictionary {
+ a := &Dictionary{}
+ a.array.refCount = 1
+ dictdata := NewData(typ, indices.Len(), indices.Data().Buffers(), indices.Data().Children(), indices.NullN(), indices.Data().Offset())
+ dictdata.dictionary = dict.Data().(*Data) // panics unless dict is backed by a *Data
+ dict.Data().Retain() // dictdata now holds its own reference to the dictionary
+
+ defer dictdata.Release() // drop the local reference on return; presumably setData retained its own — confirm
+ a.setData(dictdata)
+ return a
+}
+
+// checkIndexBounds returns an error if any value in the provided integer array data is
+// negative or >= upperlimit, and nil otherwise (including when the data is empty).
+func checkIndexBounds(indices *Data, upperlimit uint64) error {
+ if indices.length == 0 {
+ return nil
+ }
+
+ var maxval uint64
+ switch indices.dtype.ID() {
+ case arrow.UINT8:
+ maxval = math.MaxUint8
+ case arrow.UINT16:
+ maxval = math.MaxUint16
+ case arrow.UINT32:
+ maxval = math.MaxUint32
+ case arrow.UINT64:
+ maxval = math.MaxUint64
+ }
+ // maxval stays 0 for signed (and unsupported) index types. For unsigned
+ // types, if the dictionary has more entries than the largest representable
+ // index, every index is in range and the scan is skipped. Signed types are
+ // always scanned because a value could be < 0.
+ isSigned := maxval == 0
+ if !isSigned && upperlimit > maxval {
+ return nil
+ }
+
+ start := indices.offset
+ end := indices.offset + indices.length
+
+ // TODO(ARROW-15950): lift BitSetRunReader from parquet to utils and use it here for performance.
+ // Bounds are compared as uint64: narrowing upperlimit to the index type would wrap around.
+
+ switch indices.dtype.ID() {
+ case arrow.INT8:
+ data := arrow.Int8Traits.CastFromBytes(indices.buffers[1].Bytes())
+ min, max := utils.GetMinMaxInt8(data[start:end])
+ if min < 0 || uint64(max) >= upperlimit { // max >= min >= 0 here, so the widening is lossless
+ return fmt.Errorf("contains out of bounds index: min: %d, max: %d", min, max)
+ }
+ case arrow.UINT8:
+ data := arrow.Uint8Traits.CastFromBytes(indices.buffers[1].Bytes())
+ _, max := utils.GetMinMaxUint8(data[start:end])
+ if uint64(max) >= upperlimit {
+ return fmt.Errorf("contains out of bounds index: max: %d", max)
+ }
+ case arrow.INT16:
+ data := arrow.Int16Traits.CastFromBytes(indices.buffers[1].Bytes())
+ min, max := utils.GetMinMaxInt16(data[start:end])
+ if min < 0 || uint64(max) >= upperlimit {
+ return fmt.Errorf("contains out of bounds index: min: %d, max: %d", min, max)
+ }
+ case arrow.UINT16:
+ data := arrow.Uint16Traits.CastFromBytes(indices.buffers[1].Bytes())
+ _, max := utils.GetMinMaxUint16(data[start:end])
+ if uint64(max) >= upperlimit {
+ return fmt.Errorf("contains out of bounds index: max: %d", max)
+ }
+ case arrow.INT32:
+ data := arrow.Int32Traits.CastFromBytes(indices.buffers[1].Bytes())
+ min, max := utils.GetMinMaxInt32(data[start:end])
+ if min < 0 || uint64(max) >= upperlimit {
+ return fmt.Errorf("contains out of bounds index: min: %d, max: %d", min, max)
+ }
+ case arrow.UINT32:
+ data := arrow.Uint32Traits.CastFromBytes(indices.buffers[1].Bytes())
+ _, max := utils.GetMinMaxUint32(data[start:end])
+ if uint64(max) >= upperlimit {
+ return fmt.Errorf("contains out of bounds index: max: %d", max)
+ }
+ case arrow.INT64:
+ data := arrow.Int64Traits.CastFromBytes(indices.buffers[1].Bytes())
+ min, max := utils.GetMinMaxInt64(data[start:end])
+ if min < 0 || uint64(max) >= upperlimit {
+ return fmt.Errorf("contains out of bounds index: min: %d, max: %d", min, max)
+ }
+ case arrow.UINT64:
+ data := arrow.Uint64Traits.CastFromBytes(indices.buffers[1].Bytes())
+ _, max := utils.GetMinMaxUint64(data[start:end])
+ if max >= upperlimit {
+ return fmt.Errorf("contains out of bounds value: max: %d", max)
+ }
+ default:
+ return fmt.Errorf("invalid type for bounds checking: %T", indices.dtype)
+ }
+
+ return nil
+}
+
+// NewValidatedDictionaryArray constructs a dictionary array from the provided indices
+// and dictionary arrays, while also performing validation checks to ensure correctness,
+// such as bounds checking, that are usually skipped for performance.
+func NewValidatedDictionaryArray(typ *arrow.DictionaryType, indices, dict Interface) (*Dictionary, error) {
+ if indices.DataType().ID() != typ.IndexType.ID() { // index types are matched by type ID only
+ return nil, fmt.Errorf("dictionary type index (%T) does not match indices array type (%T)", typ.IndexType, indices.DataType())
+ }
+
+ if !arrow.TypeEqual(typ.ValueType, dict.DataType()) {
+ return nil, fmt.Errorf("dictionary value type (%T) does not match dict array type (%T)", typ.ValueType, dict.DataType())
+ }
+
+ if err := checkIndexBounds(indices.Data().(*Data), uint64(dict.Len())); err != nil { // every index must address an entry of dict
+ return nil, err
+ }
+
+ return NewDictionaryArray(typ, indices, dict), nil
+}
+
+// NewDictionaryData creates a strongly typed Dictionary array from an ArrayData object
+// with a dictionary datatype; it panics if data is not a *Data or has no dictionary set.
+func NewDictionaryData(data arrow.ArrayData) *Dictionary {
+ a := &Dictionary{}
+ a.refCount = 1
+ a.setData(data.(*Data))
+ return a
+}
+
+func (d *Dictionary) Retain() {
+ atomic.AddInt64(&d.refCount, 1)
+}
+
+func (d *Dictionary) Release() {
+ debug.Assert(atomic.LoadInt64(&d.refCount) > 0, "too many releases")
+
+ if atomic.AddInt64(&d.refCount, -1) == 0 {
+ d.data.Release()
+ d.data, d.nullBitmapBytes = nil, nil
+ d.indices.Release()
+ d.indices = nil
+ if d.dict != nil {
+ d.dict.Release()
+ d.dict = nil
+ }
+ }
+}
+
+func (d *Dictionary) setData(data *Data) {
+ d.array.setData(data)
+
+ if data.dictionary == nil {
+ panic("arrow/array: no dictionary set in Data for Dictionary array")
+ }
+
+ dictType := data.dtype.(*arrow.DictionaryType)
+ debug.Assert(arrow.TypeEqual(dictType.ValueType, data.dictionary.DataType()), "mismatched dictionary value types")
+
+ indexData := NewData(dictType.IndexType, data.length, data.buffers, data.childData, data.nulls, data.offset)
+ defer indexData.Release()
+ d.indices = MakeFromData(indexData)
+}
+
+// Dictionary returns the values array that makes up the dictionary for this array.
+// It is built from the dictionary data on first use, cached, and released with d.
+func (d *Dictionary) Dictionary() Interface {
+ if d.dict == nil {
+ d.dict = MakeFromData(d.data.dictionary) // lazy construction; later calls reuse the cached array
+ }
+ return d.dict
+}
+
+// Indices returns the underlying array of indices as its own array; no extra reference is retained for the caller
+func (d *Dictionary) Indices() Interface {
+ return d.indices
+}
+
+// CanCompareIndices returns true if the dictionary arrays can be compared
+// without having to unify the dictionaries themselves first: the index types
+// must be equal and both dictionaries must agree over their first minlen entries.
+func (d *Dictionary) CanCompareIndices(other *Dictionary) bool {
+ if !arrow.TypeEqual(d.indices.DataType(), other.indices.DataType()) {
+ return false
+ }
+
+ minlen := int64(min(d.data.dictionary.length, other.data.dictionary.length)) // min is defined outside this view; presumably the smaller of the two lengths
+ return ArraySliceEqual(d.Dictionary(), 0, minlen, other.Dictionary(), 0, minlen)
+}
+
+func (d *Dictionary) String() string {
+ return fmt.Sprintf("{ dictionary: %v\n indices: %v }", d.Dictionary(), d.Indices())
+}
+
+// GetValueIndex returns the dictionary index for the value at index i of the array; validity of slot i is not checked.
+// The actual value can be retrieved by using d.Dictionary().(valuetype).Value(d.GetValueIndex(i))
+func (d *Dictionary) GetValueIndex(i int) int {
+ indiceData := d.data.buffers[1].Bytes() // raw index bytes; the array offset is applied to each lookup below
+ // we know the value is non-negative per the spec, so
+ // we can use the unsigned value regardless.
+ switch d.indices.DataType().ID() {
+ case arrow.UINT8, arrow.INT8:
+ return int(uint8(indiceData[d.data.offset+i]))
+ case arrow.UINT16, arrow.INT16:
+ return int(arrow.Uint16Traits.CastFromBytes(indiceData)[d.data.offset+i])
+ case arrow.UINT32, arrow.INT32:
+ idx := arrow.Uint32Traits.CastFromBytes(indiceData)[d.data.offset+i]
+ debug.Assert(bits.UintSize == 64 || idx <= math.MaxInt32, "arrow/dictionary: truncation of index value") // a 32-bit index can only overflow int when int is 32 bits wide
+ return int(idx)
+ case arrow.UINT64, arrow.INT64:
+ idx := arrow.Uint64Traits.CastFromBytes(indiceData)[d.data.offset+i]
+ debug.Assert((bits.UintSize == 32 && idx <= math.MaxInt32) || (bits.UintSize == 64 && idx <= math.MaxInt64), "arrow/dictionary: truncation of index value")
+ return int(idx)
+ }
+ debug.Assert(false, "unreachable dictionary index")
+ return -1 // only reached when the index type is none of the integer types above
+}
+
+func (d *Dictionary) getOneForMarshal(i int) interface{} {
+ if d.IsNull(i) {
+ return nil
+ }
+ vidx := d.GetValueIndex(i)
+ return d.Dictionary().(arraymarshal).getOneForMarshal(vidx)
+}
+
+func (d *Dictionary) MarshalJSON() ([]byte, error) {
+ vals := make([]interface{}, d.Len())
+ for i := 0; i < d.Len(); i++ {
+ vals[i] = d.getOneForMarshal(i)
+ }
+ return json.Marshal(vals)
+}
+
+func arrayEqualDict(l, r *Dictionary) bool {
+ return ArrayEqual(l.Dictionary(), r.Dictionary()) && ArrayEqual(l.indices, r.indices)
+}
+
+func arrayApproxEqualDict(l, r *Dictionary, opt equalOption) bool {
+ return arrayApproxEqual(l.Dictionary(), r.Dictionary(), opt) && arrayApproxEqual(l.indices, r.indices, opt)
+}
+
+// helper for building the properly typed indices of the dictionary builder
+type indexBuilder struct {
+ Builder
+ Append func(int)
+}
+
+func createIndexBuilder(mem memory.Allocator, dt arrow.FixedWidthDataType) (ret indexBuilder, err error) {
+ ret = indexBuilder{Builder: NewBuilder(mem, dt)}
+ switch dt.ID() {
+ case arrow.INT8:
+ ret.Append = func(idx int) {
+ ret.Builder.(*Int8Builder).Append(int8(idx))
+ }
+ case arrow.UINT8:
+ ret.Append = func(idx int) {
+ ret.Builder.(*Uint8Builder).Append(uint8(idx))
+ }
+ case arrow.INT16:
+ ret.Append = func(idx int) {
+ ret.Builder.(*Int16Builder).Append(int16(idx))
+ }
+ case arrow.UINT16:
+ ret.Append = func(idx int) {
+ ret.Builder.(*Uint16Builder).Append(uint16(idx))
+ }
+ case arrow.INT32:
+ ret.Append = func(idx int) {
+ ret.Builder.(*Int32Builder).Append(int32(idx))
+ }
+ case arrow.UINT32:
+ ret.Append = func(idx int) {
+ ret.Builder.(*Uint32Builder).Append(uint32(idx))
+ }
+ case arrow.INT64:
+ ret.Append = func(idx int) {
+ ret.Builder.(*Int64Builder).Append(int64(idx))
+ }
+ case arrow.UINT64:
+ ret.Append = func(idx int) {
+ ret.Builder.(*Uint64Builder).Append(uint64(idx))
+ }
+ default:
+ debug.Assert(false, "dictionary index type must be integral")
+ err = fmt.Errorf("dictionary index type must be integral, not %s", dt)
+ }
+
+ return
+}
+
+// helper function to construct an appropriately typed memo table based on
+// the value type for the dictionary; arrow.NULL yields a nil table and a nil error
+func createMemoTable(mem memory.Allocator, dt arrow.DataType) (ret hashing.MemoTable, err error) {
+ switch dt.ID() {
+ case arrow.INT8:
+ ret = hashing.NewInt8MemoTable(0)
+ case arrow.UINT8:
+ ret = hashing.NewUint8MemoTable(0)
+ case arrow.INT16:
+ ret = hashing.NewInt16MemoTable(0)
+ case arrow.UINT16:
+ ret = hashing.NewUint16MemoTable(0)
+ case arrow.INT32:
+ ret = hashing.NewInt32MemoTable(0)
+ case arrow.UINT32:
+ ret = hashing.NewUint32MemoTable(0)
+ case arrow.INT64:
+ ret = hashing.NewInt64MemoTable(0)
+ case arrow.UINT64:
+ ret = hashing.NewUint64MemoTable(0)
+ case arrow.DURATION, arrow.TIMESTAMP, arrow.DATE64, arrow.TIME64: // these value types reuse the int64 memo table
+ ret = hashing.NewInt64MemoTable(0)
+ case arrow.TIME32, arrow.DATE32, arrow.INTERVAL_MONTHS: // these value types reuse the int32 memo table
+ ret = hashing.NewInt32MemoTable(0)
+ case arrow.FLOAT16: // presumably keyed by the raw 16-bit pattern of the half float — confirm
+ ret = hashing.NewUint16MemoTable(0)
+ case arrow.FLOAT32:
+ ret = hashing.NewFloat32MemoTable(0)
+ case arrow.FLOAT64:
+ ret = hashing.NewFloat64MemoTable(0)
+ case arrow.BINARY, arrow.FIXED_SIZE_BINARY, arrow.DECIMAL128, arrow.INTERVAL_DAY_TIME, arrow.INTERVAL_MONTH_DAY_NANO: // all use a binary memo table backed by a Binary builder
+ ret = hashing.NewBinaryMemoTable(0, 0, NewBinaryBuilder(mem, arrow.BinaryTypes.Binary))
+ case arrow.STRING:
+ ret = hashing.NewBinaryMemoTable(0, 0, NewBinaryBuilder(mem, arrow.BinaryTypes.String))
+ case arrow.NULL: // deliberately empty: ret and err both stay nil
+ default:
+ debug.Assert(false, "unimplemented dictionary value type")
+ err = fmt.Errorf("unimplemented dictionary value type, %s", dt)
+ }
+
+ return
+}
+
+type DictionaryBuilder interface {
+ Builder
+
+ NewDictionaryArray() *Dictionary
+ NewDelta() (indices, delta Interface, err error)
+ AppendArray(Interface) error
+ ResetFull()
+}
+
+type dictionaryBuilder struct {
+ builder
+
+ dt *arrow.DictionaryType
+ deltaOffset int
+ memoTable hashing.MemoTable
+ idxBuilder indexBuilder
+}
+
+// NewDictionaryBuilderWithDict initializes a dictionary builder and inserts the values from `init` as the first
+// values in the dictionary, but does not insert them as values into the array.
+func NewDictionaryBuilderWithDict(mem memory.Allocator, dt *arrow.DictionaryType, init Interface) DictionaryBuilder {
+ if init != nil && !arrow.TypeEqual(dt.ValueType, init.DataType()) {
+ panic(fmt.Errorf("arrow/array: cannot initialize dictionary type %T with array of type %T", dt.ValueType, init.DataType()))
+ }
+
+ idxbldr, err := createIndexBuilder(mem, dt.IndexType.(arrow.FixedWidthDataType))
+ if err != nil {
+ panic(fmt.Errorf("arrow/array: unsupported builder for index type of %T", dt))
+ }
+
+ memo, err := createMemoTable(mem, dt.ValueType)
+ if err != nil {
+ panic(fmt.Errorf("arrow/array: unsupported builder for value type of %T", dt))
+ }
+
+ bldr := dictionaryBuilder{
+ builder: builder{refCount: 1, mem: mem},
+ idxBuilder: idxbldr,
+ memoTable: memo,
+ dt: dt,
+ }
+
+ switch dt.ValueType.ID() {
+ case arrow.NULL:
+ ret := &NullDictionaryBuilder{bldr}
+ debug.Assert(init == nil, "arrow/array: doesn't make sense to init a null dictionary")
+ return ret
+ case arrow.UINT8:
+ ret := &Uint8DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Uint8)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.INT8:
+ ret := &Int8DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Int8)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.UINT16:
+ ret := &Uint16DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Uint16)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.INT16:
+ ret := &Int16DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Int16)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.UINT32:
+ ret := &Uint32DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Uint32)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.INT32:
+ ret := &Int32DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Int32)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.UINT64:
+ ret := &Uint64DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Uint64)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.INT64:
+ ret := &Int64DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Int64)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.FLOAT16:
+ ret := &Float16DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Float16)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.FLOAT32:
+ ret := &Float32DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Float32)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.FLOAT64:
+ ret := &Float64DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Float64)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.STRING:
+ ret := &BinaryDictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertStringDictValues(init.(*String)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.BINARY:
+ ret := &BinaryDictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Binary)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.FIXED_SIZE_BINARY:
+ ret := &FixedSizeBinaryDictionaryBuilder{
+ bldr, dt.ValueType.(*arrow.FixedSizeBinaryType).ByteWidth,
+ }
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*FixedSizeBinary)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.DATE32:
+ ret := &Date32DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Date32)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.DATE64:
+ ret := &Date64DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Date64)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.TIMESTAMP:
+ ret := &TimestampDictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Timestamp)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.TIME32:
+ ret := &Time32DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Time32)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.TIME64:
+ ret := &Time64DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Time64)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.INTERVAL_MONTHS:
+ ret := &MonthIntervalDictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*MonthInterval)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.INTERVAL_DAY_TIME:
+ ret := &DayTimeDictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*DayTimeInterval)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.DECIMAL128:
+ ret := &Decimal128DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Decimal128)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.DECIMAL256:
+ case arrow.LIST:
+ case arrow.STRUCT:
+ case arrow.SPARSE_UNION:
+ case arrow.DENSE_UNION:
+ case arrow.DICTIONARY:
+ case arrow.MAP:
+ case arrow.EXTENSION:
+ case arrow.FIXED_SIZE_LIST:
+ case arrow.DURATION:
+ ret := &DurationDictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Duration)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.LARGE_STRING:
+ case arrow.LARGE_BINARY:
+ case arrow.LARGE_LIST:
+ case arrow.INTERVAL_MONTH_DAY_NANO:
+ ret := &MonthDayNanoDictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*MonthDayNanoInterval)); err != nil {
+ panic(err)
+ }
+ }
+ return ret
+ }
+
+ panic("arrow/array: unimplemented dictionary key type")
+}
+
+ // NewDictionaryBuilder returns a dictionary builder for dt whose dictionary
+ // starts out empty; it is NewDictionaryBuilderWithDict with a nil initial array.
+ func NewDictionaryBuilder(mem memory.Allocator, dt *arrow.DictionaryType) DictionaryBuilder {
+ return NewDictionaryBuilderWithDict(mem, dt, nil)
+ }
+
+ // Release decrements the reference count. When it drops to zero the index
+ // builder is released and, if the memo table is a binary memo table, that is
+ // released as well; both references are then cleared.
+ func (b *dictionaryBuilder) Release() {
+ 	debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases")
+
+ 	if atomic.AddInt64(&b.refCount, -1) != 0 {
+ 		return
+ 	}
+
+ 	b.idxBuilder.Release()
+ 	b.idxBuilder.Builder = nil
+ 	if tbl, isBinary := b.memoTable.(*hashing.BinaryMemoTable); isBinary {
+ 		tbl.Release()
+ 	}
+ 	b.memoTable = nil
+ }
+
+ // AppendNull appends a null entry: it bumps the length and null counters and
+ // appends a null to the index builder. The dictionary itself is not touched.
+ func (b *dictionaryBuilder) AppendNull() {
+ 	b.length++
+ 	b.nulls++
+ 	b.idxBuilder.AppendNull()
+ }
+
+ // Reserve forwards the request to the index builder.
+ func (b *dictionaryBuilder) Reserve(n int) {
+ b.idxBuilder.Reserve(n)
+ }
+
+ // Resize resizes the index builder and then syncs this builder's length with
+ // the index builder's resulting length.
+ func (b *dictionaryBuilder) Resize(n int) {
+ b.idxBuilder.Resize(n)
+ b.length = b.idxBuilder.Len()
+ }
+
+ // ResetFull resets the builder to a completely empty state: the base builder
+ // counters are reset, pending indices are discarded and the memo table (and
+ // therefore the dictionary) is cleared.
+ func (b *dictionaryBuilder) ResetFull() {
+ 	b.builder.reset()
+ 	b.idxBuilder.NewArray().Release()
+ 	b.memoTable.Reset()
+ 	// The memo table is empty again, so delta dictionaries must start from
+ 	// position 0. Keeping a stale offset would make newWithDictOffset compute
+ 	// a negative dictionary length on the next NewDelta call.
+ 	b.deltaOffset = 0
+ }
+
+ // Cap returns the capacity reported by the index builder.
+ func (b *dictionaryBuilder) Cap() int { return b.idxBuilder.Cap() }
+
+ // UnmarshalJSON is not yet implemented for dictionary builders and will always error.
+ // The input bytes are ignored.
+ func (b *dictionaryBuilder) UnmarshalJSON([]byte) error {
+ return errors.New("unmarshal json to dictionary not yet implemented")
+ }
+
+ // unmarshal is not yet implemented for dictionary builders; it always returns
+ // an error without reading from dec.
+ func (b *dictionaryBuilder) unmarshal(dec *json.Decoder) error {
+ return errors.New("unmarshal json to dictionary not yet implemented")
+ }
+
+ // unmarshalOne is not yet implemented for dictionary builders; it always
+ // returns an error without reading from dec.
+ func (b *dictionaryBuilder) unmarshalOne(dec *json.Decoder) error {
+ return errors.New("unmarshal json to dictionary not yet implemented")
+ }
+
+ // NewArray returns the result of NewDictionaryArray as an Interface.
+ func (b *dictionaryBuilder) NewArray() Interface {
+ return b.NewDictionaryArray()
+ }
+
+ // NewDictionaryArray builds a *Dictionary from the pending indices and the
+ // complete dictionary (memo table offset 0). It panics if building the
+ // underlying data fails.
+ func (b *dictionaryBuilder) NewDictionaryArray() *Dictionary {
+ a := &Dictionary{}
+ a.refCount = 1
+
+ indices, dict, err := b.newWithDictOffset(0)
+ if err != nil {
+ panic(err)
+ }
+ // Drop the reference obtained from newWithDictOffset once this function
+ // returns, after setData has been called on the new array.
+ defer indices.Release()
+
+ // Turn the plain index data into dictionary data by swapping in the
+ // dictionary type and attaching the dictionary values.
+ indices.dtype = b.dt
+ indices.dictionary = dict
+ a.setData(indices)
+ return a
+ }
+
+ // newWithDictOffset finishes the pending index array and materializes the
+ // dictionary entries stored in the memo table from position offset onwards.
+ //
+ // It returns the index data (with an extra reference the caller must release)
+ // and the dictionary data holding Size()-offset values. As a side effect
+ // deltaOffset is set to the current memo table size and the builder counters
+ // are reset. The returned error is currently always nil.
+ func (b *dictionaryBuilder) newWithDictOffset(offset int) (indices, dict *Data, err error) {
+ 	idxarr := b.idxBuilder.NewArray()
+ 	defer idxarr.Release()
+
+ 	indices = idxarr.Data().(*Data)
+ 	indices.Retain()
+
+ 	// Buffer 0 is the optional validity bitmap, buffer 1 holds offsets (for
+ 	// binary/string) or the fixed-width values.
+ 	dictBuffers := make([]*memory.Buffer, 2)
+
+ 	dictLength := b.memoTable.Size() - offset
+ 	dictBuffers[1] = memory.NewResizableBuffer(b.mem)
+ 	defer dictBuffers[1].Release()
+
+ 	if bintbl, ok := b.memoTable.(*hashing.BinaryMemoTable); ok {
+ 		switch b.dt.ValueType.ID() {
+ 		case arrow.BINARY, arrow.STRING:
+ 			// Variable-length values need a third buffer for the value bytes.
+ 			dictBuffers = append(dictBuffers, memory.NewResizableBuffer(b.mem))
+ 			defer dictBuffers[2].Release()
+
+ 			dictBuffers[1].Resize(arrow.Int32SizeBytes * (dictLength + 1))
+ 			offsets := arrow.Int32Traits.CastFromBytes(dictBuffers[1].Bytes())
+ 			bintbl.CopyOffsetsSubset(offset, offsets)
+
+ 			valuesz := offsets[len(offsets)-1] - offsets[0]
+ 			dictBuffers[2].Resize(int(valuesz))
+ 			bintbl.CopyValuesSubset(offset, dictBuffers[2].Bytes())
+ 		default: // fixed size
+ 			bw := int(bitutil.BytesForBits(int64(b.dt.ValueType.(arrow.FixedWidthDataType).BitWidth())))
+ 			dictBuffers[1].Resize(dictLength * bw)
+ 			bintbl.CopyFixedWidthValues(offset, bw, dictBuffers[1].Bytes())
+ 		}
+ 	} else {
+ 		dictBuffers[1].Resize(b.memoTable.TypeTraits().BytesRequired(dictLength))
+ 		b.memoTable.WriteOutSubset(offset, dictBuffers[1].Bytes())
+ 	}
+
+ 	var nullcount int
+ 	if idx, ok := b.memoTable.GetNull(); ok && idx >= offset {
+ 		dictBuffers[0] = memory.NewResizableBuffer(b.mem)
+ 		defer dictBuffers[0].Release()
+
+ 		nullcount = 1
+
+ 		dictBuffers[0].Resize(int(bitutil.BytesForBits(int64(dictLength))))
+ 		memory.Set(dictBuffers[0].Bytes(), 0xFF)
+ 		// idx is a position in the whole memo table, while the bitmap only
+ 		// covers the entries starting at offset, so translate it.
+ 		bitutil.ClearBit(dictBuffers[0].Bytes(), idx-offset)
+ 	}
+
+ 	b.deltaOffset = b.memoTable.Size()
+ 	dict = NewData(b.dt.ValueType, dictLength, dictBuffers, nil, nullcount, 0)
+ 	b.reset()
+ 	return
+ }
+
+ // NewDelta produces the array of pending indices plus a delta dictionary made
+ // of the values added since NewArray, NewDictionaryArray or NewDelta was last
+ // called. The builder state is reset afterwards, except for the dictionary
+ // (memo table), which is kept.
+ func (b *dictionaryBuilder) NewDelta() (indices, delta Interface, err error) {
+ 	idxData, dictData, err := b.newWithDictOffset(b.deltaOffset)
+ 	if err != nil {
+ 		return nil, nil, err
+ 	}
+
+ 	// The arrays created below hold their own references to the data.
+ 	defer idxData.Release()
+ 	defer dictData.Release()
+
+ 	return MakeFromData(idxData), MakeFromData(dictData), nil
+ }
+
+ // insertDictValue passes val to the memo table's GetOrInsert without appending
+ // anything to the index builder, and returns the memo table's error.
+ func (b *dictionaryBuilder) insertDictValue(val interface{}) error {
+ _, _, err := b.memoTable.GetOrInsert(val)
+ return err
+ }
+
+ // appendValue passes val to the memo table's GetOrInsert and appends the
+ // returned index to the index builder.
+ //
+ // If the memo table reports an error nothing is appended, so the builder is
+ // never extended with an index that does not refer to a stored value.
+ func (b *dictionaryBuilder) appendValue(val interface{}) error {
+ 	idx, _, err := b.memoTable.GetOrInsert(val)
+ 	if err != nil {
+ 		return err
+ 	}
+
+ 	b.idxBuilder.Append(idx)
+ 	b.length += 1
+ 	return nil
+ }
+
+ // getvalFn returns an accessor that yields element i of arr as an
+ // interface{} value, converted as shown per case below. It panics for array
+ // types that are not listed.
+ func getvalFn(arr Interface) func(i int) interface{} {
+ switch typedarr := arr.(type) {
+ case *Int8:
+ return func(i int) interface{} { return typedarr.Value(i) }
+ case *Uint8:
+ return func(i int) interface{} { return typedarr.Value(i) }
+ case *Int16:
+ return func(i int) interface{} { return typedarr.Value(i) }
+ case *Uint16:
+ return func(i int) interface{} { return typedarr.Value(i) }
+ case *Int32:
+ return func(i int) interface{} { return typedarr.Value(i) }
+ case *Uint32:
+ return func(i int) interface{} { return typedarr.Value(i) }
+ case *Int64:
+ return func(i int) interface{} { return typedarr.Value(i) }
+ case *Uint64:
+ return func(i int) interface{} { return typedarr.Value(i) }
+ case *Float16:
+ // float16 values are handed over as their uint16 representation.
+ return func(i int) interface{} { return typedarr.Value(i).Uint16() }
+ case *Float32:
+ return func(i int) interface{} { return typedarr.Value(i) }
+ case *Float64:
+ return func(i int) interface{} { return typedarr.Value(i) }
+ // The temporal types below are converted to their plain int64 / int32
+ // representation.
+ case *Duration:
+ return func(i int) interface{} { return int64(typedarr.Value(i)) }
+ case *Timestamp:
+ return func(i int) interface{} { return int64(typedarr.Value(i)) }
+ case *Date64:
+ return func(i int) interface{} { return int64(typedarr.Value(i)) }
+ case *Time64:
+ return func(i int) interface{} { return int64(typedarr.Value(i)) }
+ case *Time32:
+ return func(i int) interface{} { return int32(typedarr.Value(i)) }
+ case *Date32:
+ return func(i int) interface{} { return int32(typedarr.Value(i)) }
+ case *MonthInterval:
+ return func(i int) interface{} { return int32(typedarr.Value(i)) }
+ case *Binary:
+ return func(i int) interface{} { return typedarr.Value(i) }
+ case *FixedSizeBinary:
+ return func(i int) interface{} { return typedarr.Value(i) }
+ case *String:
+ return func(i int) interface{} { return typedarr.Value(i) }
+ // The struct-like fixed-width types below are reinterpreted as a byte slice
+ // over a local copy of the value.
+ case *Decimal128:
+ return func(i int) interface{} {
+ val := typedarr.Value(i)
+ return (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+ }
+ case *DayTimeInterval:
+ return func(i int) interface{} {
+ val := typedarr.Value(i)
+ return (*(*[arrow.DayTimeIntervalSizeBytes]byte)(unsafe.Pointer(&val)))[:]
+ }
+ case *MonthDayNanoInterval:
+ return func(i int) interface{} {
+ val := typedarr.Value(i)
+ return (*(*[arrow.MonthDayNanoIntervalSizeBytes]byte)(unsafe.Pointer(&val)))[:]
+ }
+ }
+
+ panic("arrow/array: invalid dictionary value type")
+ }
+
+ // AppendArray appends every element of arr to the builder: null elements are
+ // appended as nulls, all others through appendValue, which adds unseen
+ // values to the dictionary.
+ //
+ // An error is returned if arr's type differs from the dictionary's value type
+ // or if a value cannot be stored in the memo table; in the latter case the
+ // elements preceding the failure have already been appended.
+ func (b *dictionaryBuilder) AppendArray(arr Interface) error {
+ 	// Validate unconditionally rather than via a debug-only assert: a
+ 	// mismatched array would otherwise hand values of the wrong Go type to
+ 	// the memo table in release builds.
+ 	if !arrow.TypeEqual(b.dt.ValueType, arr.DataType()) {
+ 		return fmt.Errorf("arrow/array: cannot append array of type %s to dictionary with value type %s", arr.DataType(), b.dt.ValueType)
+ 	}
+
+ 	valfn := getvalFn(arr)
+ 	for i := 0; i < arr.Len(); i++ {
+ 		if arr.IsNull(i) {
+ 			b.AppendNull()
+ 			continue
+ 		}
+ 		if err := b.appendValue(valfn(i)); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // NullDictionaryBuilder is the dictionary builder used when the dictionary's
+ // value type is the null type.
+ type NullDictionaryBuilder struct {
+ dictionaryBuilder
+ }
+
+ // NewArray returns the result of NewDictionaryArray as an Interface.
+ func (b *NullDictionaryBuilder) NewArray() Interface {
+ return b.NewDictionaryArray()
+ }
+
+ // NewDictionaryArray builds a *Dictionary from the pending indices, using an
+ // empty null array as the dictionary, and resets the builder's counters.
+ func (b *NullDictionaryBuilder) NewDictionaryArray() *Dictionary {
+ 	idxarr := b.idxBuilder.NewArray()
+ 	defer idxarr.Release()
+
+ 	out := idxarr.Data().(*Data)
+ 	dictarr := NewNull(0)
+ 	defer dictarr.Release()
+
+ 	// out keeps its own reference to the dictionary data beyond dictarr's
+ 	// deferred release.
+ 	dictarr.data.Retain()
+ 	out.dtype = b.dt
+ 	out.dictionary = dictarr.data
+
+ 	// The index builder was reset by NewArray above; reset the length and
+ 	// null counters kept by this builder too, as newWithDictOffset does for
+ 	// the other dictionary builders. Otherwise Len and NullN stay stale.
+ 	b.reset()
+
+ 	return NewDictionaryData(out)
+ }
+
+ // AppendArray appends one null per element of arr. It returns an error if
+ // arr is not of the null type.
+ func (b *NullDictionaryBuilder) AppendArray(arr Interface) error {
+ 	if arr.DataType().ID() != arrow.NULL {
+ 		return fmt.Errorf("cannot append non-null array to null dictionary")
+ 	}
+
+ 	nulls := arr.(*Null)
+ 	for remaining := nulls.Len(); remaining > 0; remaining-- {
+ 		b.AppendNull()
+ 	}
+ 	return nil
+ }
+
+ // Int8DictionaryBuilder is the dictionary builder for int8 values.
+ type Int8DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Int8DictionaryBuilder) Append(v int8) error {
+ 	return b.appendValue(v)
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Int8DictionaryBuilder) InsertDictValues(arr *Int8) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(arr.values[i]); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Uint8DictionaryBuilder is the dictionary builder for uint8 values.
+ type Uint8DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Uint8DictionaryBuilder) Append(v uint8) error {
+ 	return b.appendValue(v)
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Uint8DictionaryBuilder) InsertDictValues(arr *Uint8) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(arr.values[i]); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Int16DictionaryBuilder is the dictionary builder for int16 values.
+ type Int16DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Int16DictionaryBuilder) Append(v int16) error {
+ 	return b.appendValue(v)
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Int16DictionaryBuilder) InsertDictValues(arr *Int16) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(arr.values[i]); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Uint16DictionaryBuilder is the dictionary builder for uint16 values.
+ type Uint16DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Uint16DictionaryBuilder) Append(v uint16) error {
+ 	return b.appendValue(v)
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Uint16DictionaryBuilder) InsertDictValues(arr *Uint16) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(arr.values[i]); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Int32DictionaryBuilder is the dictionary builder for int32 values.
+ type Int32DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Int32DictionaryBuilder) Append(v int32) error {
+ 	return b.appendValue(v)
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Int32DictionaryBuilder) InsertDictValues(arr *Int32) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(arr.values[i]); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Uint32DictionaryBuilder is the dictionary builder for uint32 values.
+ type Uint32DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Uint32DictionaryBuilder) Append(v uint32) error {
+ 	return b.appendValue(v)
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Uint32DictionaryBuilder) InsertDictValues(arr *Uint32) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(arr.values[i]); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Int64DictionaryBuilder is the dictionary builder for int64 values.
+ type Int64DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Int64DictionaryBuilder) Append(v int64) error {
+ 	return b.appendValue(v)
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Int64DictionaryBuilder) InsertDictValues(arr *Int64) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(arr.values[i]); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Uint64DictionaryBuilder is the dictionary builder for uint64 values.
+ type Uint64DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Uint64DictionaryBuilder) Append(v uint64) error {
+ 	return b.appendValue(v)
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Uint64DictionaryBuilder) InsertDictValues(arr *Uint64) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(arr.values[i]); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // DurationDictionaryBuilder is the dictionary builder for arrow.Duration
+ // values, which are handed to the memo table as int64.
+ type DurationDictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *DurationDictionaryBuilder) Append(v arrow.Duration) error {
+ 	return b.appendValue(int64(v))
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *DurationDictionaryBuilder) InsertDictValues(arr *Duration) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(int64(arr.values[i])); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // TimestampDictionaryBuilder is the dictionary builder for arrow.Timestamp
+ // values, which are handed to the memo table as int64.
+ type TimestampDictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *TimestampDictionaryBuilder) Append(v arrow.Timestamp) error {
+ 	return b.appendValue(int64(v))
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *TimestampDictionaryBuilder) InsertDictValues(arr *Timestamp) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(int64(arr.values[i])); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Time32DictionaryBuilder is the dictionary builder for arrow.Time32 values,
+ // which are handed to the memo table as int32.
+ type Time32DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Time32DictionaryBuilder) Append(v arrow.Time32) error {
+ 	return b.appendValue(int32(v))
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Time32DictionaryBuilder) InsertDictValues(arr *Time32) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(int32(arr.values[i])); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Time64DictionaryBuilder is the dictionary builder for arrow.Time64 values,
+ // which are handed to the memo table as int64.
+ type Time64DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Time64DictionaryBuilder) Append(v arrow.Time64) error {
+ 	return b.appendValue(int64(v))
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Time64DictionaryBuilder) InsertDictValues(arr *Time64) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(int64(arr.values[i])); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Date32DictionaryBuilder is the dictionary builder for arrow.Date32 values,
+ // which are handed to the memo table as int32.
+ type Date32DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Date32DictionaryBuilder) Append(v arrow.Date32) error {
+ 	return b.appendValue(int32(v))
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Date32DictionaryBuilder) InsertDictValues(arr *Date32) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(int32(arr.values[i])); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Date64DictionaryBuilder is the dictionary builder for arrow.Date64 values,
+ // which are handed to the memo table as int64.
+ type Date64DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Date64DictionaryBuilder) Append(v arrow.Date64) error {
+ 	return b.appendValue(int64(v))
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Date64DictionaryBuilder) InsertDictValues(arr *Date64) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(int64(arr.values[i])); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // MonthIntervalDictionaryBuilder is the dictionary builder for
+ // arrow.MonthInterval values, which are handed to the memo table as int32.
+ type MonthIntervalDictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *MonthIntervalDictionaryBuilder) Append(v arrow.MonthInterval) error {
+ 	months := int32(v)
+ 	return b.appendValue(months)
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *MonthIntervalDictionaryBuilder) InsertDictValues(arr *MonthInterval) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(int32(arr.values[i])); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Float16DictionaryBuilder is the dictionary builder for float16 values,
+ // which are handed to the memo table as their uint16 representation.
+ type Float16DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Float16DictionaryBuilder) Append(v float16.Num) error {
+ 	return b.appendValue(v.Uint16())
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Float16DictionaryBuilder) InsertDictValues(arr *Float16) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(arr.values[i].Uint16()); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Float32DictionaryBuilder is the dictionary builder for float32 values.
+ type Float32DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Float32DictionaryBuilder) Append(v float32) error {
+ 	return b.appendValue(v)
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Float32DictionaryBuilder) InsertDictValues(arr *Float32) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(arr.values[i]); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Float64DictionaryBuilder is the dictionary builder for float64 values.
+ type Float64DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append looks v up in (or inserts it into) the memo table and appends the
+ // resulting index to the array.
+ func (b *Float64DictionaryBuilder) Append(v float64) error {
+ 	return b.appendValue(v)
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Float64DictionaryBuilder) InsertDictValues(arr *Float64) (err error) {
+ 	for i := range arr.values {
+ 		if err = b.insertDictValue(arr.values[i]); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // BinaryDictionaryBuilder is the dictionary builder used for both binary and
+ // string value types.
+ type BinaryDictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append appends v to the array; a nil slice is appended as a null.
+ func (b *BinaryDictionaryBuilder) Append(v []byte) error {
+ 	if v != nil {
+ 		return b.appendValue(v)
+ 	}
+ 	b.AppendNull()
+ 	return nil
+ }
+
+ // AppendString appends the string v to the array.
+ func (b *BinaryDictionaryBuilder) AppendString(v string) error {
+ 	return b.appendValue(v)
+ }
+
+ // InsertDictValues inserts the values of the binary array arr into the
+ // dictionary without appending them to the array. It fails if arr's type is
+ // not the dictionary's value type and stops at the first insertion error.
+ func (b *BinaryDictionaryBuilder) InsertDictValues(arr *Binary) (err error) {
+ 	if valueType := b.dt.ValueType; !arrow.TypeEqual(arr.DataType(), valueType) {
+ 		return fmt.Errorf("dictionary insert type mismatch: cannot insert values of type %T to dictionary type %T", arr.DataType(), valueType)
+ 	}
+
+ 	for i, n := 0, arr.Len(); i < n; i++ {
+ 		if err = b.insertDictValue(arr.Value(i)); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // InsertStringDictValues is the counterpart of InsertDictValues for string
+ // arrays.
+ func (b *BinaryDictionaryBuilder) InsertStringDictValues(arr *String) (err error) {
+ 	if valueType := b.dt.ValueType; !arrow.TypeEqual(arr.DataType(), valueType) {
+ 		return fmt.Errorf("dictionary insert type mismatch: cannot insert values of type %T to dictionary type %T", arr.DataType(), valueType)
+ 	}
+
+ 	for i, n := 0, arr.Len(); i < n; i++ {
+ 		if err = b.insertDictValue(arr.Value(i)); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // FixedSizeBinaryDictionaryBuilder is the dictionary builder for fixed size
+ // binary values of byteWidth bytes each.
+ type FixedSizeBinaryDictionaryBuilder struct {
+ 	dictionaryBuilder
+ 	byteWidth int
+ }
+
+ // Append appends the first byteWidth bytes of v to the array.
+ //
+ // It returns an error if v holds fewer than byteWidth bytes. Without this
+ // check, v[:byteWidth] would either panic or, when v has spare capacity,
+ // silently read bytes beyond the slice's length.
+ func (b *FixedSizeBinaryDictionaryBuilder) Append(v []byte) error {
+ 	if len(v) < b.byteWidth {
+ 		return fmt.Errorf("arrow/array: fixed size binary dictionary value has %d bytes, expected at least %d", len(v), b.byteWidth)
+ 	}
+ 	return b.appendValue(v[:b.byteWidth])
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array. It fails if arr's type (including its byte
+ // width) is not the dictionary's value type and stops at the first insertion
+ // error.
+ func (b *FixedSizeBinaryDictionaryBuilder) InsertDictValues(arr *FixedSizeBinary) (err error) {
+ 	// The slicing below assumes arr's values are byteWidth bytes wide.
+ 	if !arrow.TypeEqual(arr.DataType(), b.dt.ValueType) {
+ 		return fmt.Errorf("dictionary insert type mismatch: cannot insert values of type %T to dictionary type %T", arr.DataType(), b.dt.ValueType)
+ 	}
+
+ 	var (
+ 		beg = arr.array.data.offset * b.byteWidth
+ 		end = (arr.array.data.offset + arr.data.length) * b.byteWidth
+ 	)
+ 	data := arr.valueBytes[beg:end]
+ 	for len(data) > 0 {
+ 		if err = b.insertDictValue(data[:b.byteWidth]); err != nil {
+ 			break
+ 		}
+ 		data = data[b.byteWidth:]
+ 	}
+ 	return
+ }
+
+ // Decimal128DictionaryBuilder is the dictionary builder for decimal128
+ // values, which are handed to the memo table as their raw bytes.
+ type Decimal128DictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append appends v to the array, passing its in-memory bytes to the memo table.
+ func (b *Decimal128DictionaryBuilder) Append(v decimal128.Num) error {
+ 	raw := (*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&v))
+ 	return b.appendValue(raw[:])
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *Decimal128DictionaryBuilder) InsertDictValues(arr *Decimal128) (err error) {
+ 	raw := arrow.Decimal128Traits.CastToBytes(arr.values)
+ 	for pos := 0; pos < len(raw); pos += arrow.Decimal128SizeBytes {
+ 		if err = b.insertDictValue(raw[pos : pos+arrow.Decimal128SizeBytes]); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // MonthDayNanoDictionaryBuilder is the dictionary builder for
+ // arrow.MonthDayNanoInterval values, which are handed to the memo table as
+ // their raw bytes.
+ type MonthDayNanoDictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append appends v to the array, passing its in-memory bytes to the memo table.
+ func (b *MonthDayNanoDictionaryBuilder) Append(v arrow.MonthDayNanoInterval) error {
+ 	raw := (*[arrow.MonthDayNanoIntervalSizeBytes]byte)(unsafe.Pointer(&v))
+ 	return b.appendValue(raw[:])
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *MonthDayNanoDictionaryBuilder) InsertDictValues(arr *MonthDayNanoInterval) (err error) {
+ 	raw := arrow.MonthDayNanoIntervalTraits.CastToBytes(arr.values)
+ 	for pos := 0; pos < len(raw); pos += arrow.MonthDayNanoIntervalSizeBytes {
+ 		if err = b.insertDictValue(raw[pos : pos+arrow.MonthDayNanoIntervalSizeBytes]); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // DayTimeDictionaryBuilder is the dictionary builder for
+ // arrow.DayTimeInterval values, which are handed to the memo table as their
+ // raw bytes.
+ type DayTimeDictionaryBuilder struct {
+ 	dictionaryBuilder
+ }
+
+ // Append appends v to the array, passing its in-memory bytes to the memo table.
+ func (b *DayTimeDictionaryBuilder) Append(v arrow.DayTimeInterval) error {
+ 	raw := (*[arrow.DayTimeIntervalSizeBytes]byte)(unsafe.Pointer(&v))
+ 	return b.appendValue(raw[:])
+ }
+
+ // InsertDictValues inserts the values of arr into the dictionary without
+ // appending them to the array, stopping at the first error.
+ func (b *DayTimeDictionaryBuilder) InsertDictValues(arr *DayTimeInterval) (err error) {
+ 	raw := arrow.DayTimeIntervalTraits.CastToBytes(arr.values)
+ 	for pos := 0; pos < len(raw); pos += arrow.DayTimeIntervalSizeBytes {
+ 		if err = b.insertDictValue(raw[pos : pos+arrow.DayTimeIntervalSizeBytes]); err != nil {
+ 			return err
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Compile-time checks that the types in this file satisfy the interfaces
+ // they are used as.
+ var (
+ _ Interface = (*Dictionary)(nil)
+ _ Builder = (*dictionaryBuilder)(nil)
+ )
diff --git a/go/arrow/array/dictionary_test.go b/go/arrow/array/dictionary_test.go
new file mode 100644
index 0000000000..906b0f54b8
--- /dev/null
+++ b/go/arrow/array/dictionary_test.go
@@ -0,0 +1,1183 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package array_test
+
+import (
+ "fmt"
+ "math"
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/apache/arrow/go/v8/arrow"
+ "github.com/apache/arrow/go/v8/arrow/array"
+ "github.com/apache/arrow/go/v8/arrow/bitutil"
+ "github.com/apache/arrow/go/v8/arrow/decimal128"
+ "github.com/apache/arrow/go/v8/arrow/memory"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/suite"
+)
+
+ // PrimitiveDictionaryTestSuite runs the dictionary builder tests for a single
+ // primitive value type.
+ type PrimitiveDictionaryTestSuite struct {
+ suite.Suite
+
+ // mem is a checked allocator, recreated for every test.
+ mem *memory.CheckedAllocator
+ // typ is the arrow value type under test.
+ typ arrow.DataType
+ // reftyp is the Go type that test values are converted to before being
+ // passed to the builder's Append method via reflection.
+ reftyp reflect.Type
+ }
+
+ // SetupTest creates a fresh checked allocator wrapping the default allocator.
+ func (p *PrimitiveDictionaryTestSuite) SetupTest() {
+ p.mem = memory.NewCheckedAllocator(memory.DefaultAllocator)
+ }
+
+ // TearDownTest asserts that the allocator's size is back to 0 after the test.
+ func (p *PrimitiveDictionaryTestSuite) TearDownTest() {
+ p.mem.AssertSize(p.T(), 0)
+ }
+
+func TestPrimitiveDictionaryBuilders(t *testing.T) {
+ tests := []struct {
+ name string
+ typ arrow.DataType
+ reftyp reflect.Type
+ }{
+ {"int8", arrow.PrimitiveTypes.Int8, reflect.TypeOf(int8(0))},
+ {"uint8", arrow.PrimitiveTypes.Uint8, reflect.TypeOf(uint8(0))},
+ {"int16", arrow.PrimitiveTypes.Int16, reflect.TypeOf(int16(0))},
+ {"uint16", arrow.PrimitiveTypes.Uint16, reflect.TypeOf(uint16(0))},
+ {"int32", arrow.PrimitiveTypes.Int32, reflect.TypeOf(int32(0))},
+ {"uint32", arrow.PrimitiveTypes.Uint32, reflect.TypeOf(uint32(0))},
+ {"int64", arrow.PrimitiveTypes.Int64, reflect.TypeOf(int64(0))},
+ {"uint64", arrow.PrimitiveTypes.Uint64, reflect.TypeOf(uint64(0))},
+ {"float32", arrow.PrimitiveTypes.Float32, reflect.TypeOf(float32(0))},
+ {"float64", arrow.PrimitiveTypes.Float64, reflect.TypeOf(float64(0))},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ suite.Run(t, &PrimitiveDictionaryTestSuite{typ: tt.typ, reftyp: tt.reftyp})
+ })
+ }
+}
+
+ // TestDictionaryBuilderBasic appends 1, 2, 1 and a null, then checks the
+ // builder's counters and that the resulting array has dictionary [1, 2] and
+ // indices [0, 1, 0, null].
+ func (p *PrimitiveDictionaryTestSuite) TestDictionaryBuilderBasic() {
+ 	expectedType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: p.typ}
+ 	bldr := array.NewDictionaryBuilder(p.mem, expectedType)
+ 	defer bldr.Release()
+
+ 	// Append is specific to each typed builder, so it is invoked through
+ 	// reflection with the value converted to the suite's Go type.
+ 	appfn := reflect.ValueOf(bldr).MethodByName("Append")
+ 	appendVal := func(v int) {
+ 		out := appfn.Call([]reflect.Value{reflect.ValueOf(v).Convert(p.reftyp)})
+ 		p.Nil(out[0].Interface())
+ 	}
+ 	appendVal(1)
+ 	appendVal(2)
+ 	appendVal(1)
+ 	bldr.AppendNull()
+
+ 	p.EqualValues(4, bldr.Len())
+ 	p.EqualValues(1, bldr.NullN())
+
+ 	arr := bldr.NewArray().(*array.Dictionary)
+ 	defer arr.Release()
+
+ 	p.True(arrow.TypeEqual(expectedType, arr.DataType()))
+ 	expectedDict, _, err := array.FromJSON(p.mem, expectedType.ValueType, strings.NewReader("[1, 2]"))
+ 	p.NoError(err)
+ 	defer expectedDict.Release()
+
+ 	expectedIndices, _, err := array.FromJSON(p.mem, expectedType.IndexType, strings.NewReader("[0, 1, 0, null]"))
+ 	p.NoError(err)
+ 	defer expectedIndices.Release()
+
+ 	expected := array.NewDictionaryArray(expectedType, expectedIndices, expectedDict)
+ 	defer expected.Release()
+
+ 	p.True(array.ArrayEqual(expected, arr))
+ }
+
+ // TestDictionaryBuilderInit creates a builder pre-seeded with dictionary
+ // [1, 2], appends 1, 2, 1 and a null, and checks that the result uses the
+ // seeded dictionary with indices [0, 1, 0, null].
+ func (p *PrimitiveDictionaryTestSuite) TestDictionaryBuilderInit() {
+ 	valueType := p.typ
+ 	dictArr, _, err := array.FromJSON(p.mem, valueType, strings.NewReader("[1, 2]"))
+ 	p.NoError(err)
+ 	defer dictArr.Release()
+
+ 	dictType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: valueType}
+ 	bldr := array.NewDictionaryBuilderWithDict(p.mem, dictType, dictArr)
+ 	defer bldr.Release()
+
+ 	// Append is specific to each typed builder, so it is invoked through
+ 	// reflection with the value converted to the suite's Go type.
+ 	appfn := reflect.ValueOf(bldr).MethodByName("Append")
+ 	appendVal := func(v int) {
+ 		out := appfn.Call([]reflect.Value{reflect.ValueOf(v).Convert(p.reftyp)})
+ 		p.Nil(out[0].Interface())
+ 	}
+ 	appendVal(1)
+ 	appendVal(2)
+ 	appendVal(1)
+ 	bldr.AppendNull()
+
+ 	p.EqualValues(4, bldr.Len())
+ 	p.EqualValues(1, bldr.NullN())
+
+ 	arr := bldr.NewDictionaryArray()
+ 	defer arr.Release()
+
+ 	expectedIndices, _, err := array.FromJSON(p.mem, dictType.IndexType, strings.NewReader("[0, 1, 0, null]"))
+ 	p.NoError(err)
+ 	defer expectedIndices.Release()
+
+ 	expected := array.NewDictionaryArray(dictType, expectedIndices, dictArr)
+ 	defer expected.Release()
+
+ 	p.True(array.ArrayEqual(expected, arr))
+ }
+
+ // TestDictionaryNewBuilder obtains the builder through the generic
+ // array.NewBuilder, appends 1, 2, 1 and a null, and checks that the result
+ // has dictionary [1, 2] and indices [0, 1, 0, null].
+ func (p *PrimitiveDictionaryTestSuite) TestDictionaryNewBuilder() {
+ 	valueType := p.typ
+ 	dictArr, _, err := array.FromJSON(p.mem, valueType, strings.NewReader("[1, 2]"))
+ 	p.NoError(err)
+ 	defer dictArr.Release()
+
+ 	dictType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: valueType}
+ 	bldr := array.NewBuilder(p.mem, dictType)
+ 	defer bldr.Release()
+
+ 	// Append is specific to each typed builder, so it is invoked through
+ 	// reflection with the value converted to the suite's Go type.
+ 	appfn := reflect.ValueOf(bldr).MethodByName("Append")
+ 	appendVal := func(v int) {
+ 		out := appfn.Call([]reflect.Value{reflect.ValueOf(v).Convert(p.reftyp)})
+ 		p.Nil(out[0].Interface())
+ 	}
+ 	appendVal(1)
+ 	appendVal(2)
+ 	appendVal(1)
+ 	bldr.AppendNull()
+
+ 	p.EqualValues(4, bldr.Len())
+ 	p.EqualValues(1, bldr.NullN())
+
+ 	arr := bldr.NewArray().(*array.Dictionary)
+ 	defer arr.Release()
+
+ 	expectedIndices, _, err := array.FromJSON(p.mem, dictType.IndexType, strings.NewReader("[0, 1, 0, null]"))
+ 	p.NoError(err)
+ 	defer expectedIndices.Release()
+
+ 	expected := array.NewDictionaryArray(dictType, expectedIndices, dictArr)
+ 	defer expected.Release()
+
+ 	p.True(array.ArrayEqual(expected, arr))
+ }
+
+ // TestDictionaryBuilderAppendArr appends the array [1, 2, 1] through
+ // AppendArray and checks that the result has dictionary [1, 2] and indices
+ // [0, 1, 0].
+ func (p *PrimitiveDictionaryTestSuite) TestDictionaryBuilderAppendArr() {
+ 	valueType := p.typ
+ 	intermediate, _, err := array.FromJSON(p.mem, valueType, strings.NewReader("[1, 2, 1]"))
+ 	p.NoError(err)
+ 	defer intermediate.Release()
+
+ 	expectedType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: p.typ}
+ 	bldr := array.NewDictionaryBuilder(p.mem, expectedType)
+ 	defer bldr.Release()
+
+ 	// AppendArray reports failures through its return value; do not drop it.
+ 	p.NoError(bldr.AppendArray(intermediate))
+ 	result := bldr.NewArray()
+ 	defer result.Release()
+
+ 	expectedDict, _, err := array.FromJSON(p.mem, expectedType.ValueType, strings.NewReader("[1, 2]"))
+ 	p.NoError(err)
+ 	defer expectedDict.Release()
+
+ 	expectedIndices, _, err := array.FromJSON(p.mem, expectedType.IndexType, strings.NewReader("[0, 1, 0]"))
+ 	p.NoError(err)
+ 	defer expectedIndices.Release()
+
+ 	expected := array.NewDictionaryArray(expectedType, expectedIndices, expectedDict)
+ 	defer expected.Release()
+
+ 	p.True(array.ArrayEqual(expected, result))
+ }
+
+ // TestDictionaryBuilderDeltaDictionary builds a first array from 1, 2, 1, 2,
+ // then appends 2, 3, 3, 1, 3 and checks that NewDelta returns indices
+ // [1, 2, 2, 0, 2] together with a delta dictionary holding only the new
+ // value [3].
+ func (p *PrimitiveDictionaryTestSuite) TestDictionaryBuilderDeltaDictionary() {
+ 	expectedType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: p.typ}
+ 	bldr := array.NewDictionaryBuilder(p.mem, expectedType)
+ 	defer bldr.Release()
+
+ 	// Append is specific to each typed builder, so it is invoked through
+ 	// reflection with the value converted to the suite's Go type.
+ 	appfn := reflect.ValueOf(bldr).MethodByName("Append")
+ 	appendVal := func(v int) {
+ 		out := appfn.Call([]reflect.Value{reflect.ValueOf(v).Convert(p.reftyp)})
+ 		p.Nil(out[0].Interface())
+ 	}
+ 	appendVal(1)
+ 	appendVal(2)
+ 	appendVal(1)
+ 	appendVal(2)
+
+ 	result := bldr.NewArray()
+ 	defer result.Release()
+
+ 	exdict, _, err := array.FromJSON(p.mem, p.typ, strings.NewReader("[1, 2]"))
+ 	p.NoError(err)
+ 	defer exdict.Release()
+ 	exindices, _, err := array.FromJSON(p.mem, arrow.PrimitiveTypes.Int8, strings.NewReader("[0, 1, 0, 1]"))
+ 	p.NoError(err)
+ 	defer exindices.Release()
+ 	expected := array.NewDictionaryArray(result.DataType().(*arrow.DictionaryType), exindices, exdict)
+ 	defer expected.Release()
+ 	p.True(array.ArrayEqual(expected, result))
+
+ 	appendVal(2)
+ 	appendVal(3)
+ 	appendVal(3)
+ 	appendVal(1)
+ 	appendVal(3)
+
+ 	indices, delta, err := bldr.NewDelta()
+ 	p.NoError(err)
+ 	defer indices.Release()
+ 	defer delta.Release()
+
+ 	// Check the FromJSON errors instead of discarding them, so a parse
+ 	// failure is reported as such rather than as a nil dereference below.
+ 	exindices, _, err = array.FromJSON(p.mem, arrow.PrimitiveTypes.Int8, strings.NewReader("[1, 2, 2, 0, 2]"))
+ 	p.NoError(err)
+ 	defer exindices.Release()
+ 	exdelta, _, err := array.FromJSON(p.mem, p.typ, strings.NewReader("[3]"))
+ 	p.NoError(err)
+ 	defer exdelta.Release()
+
+ 	p.True(array.ArrayEqual(exindices, indices))
+ 	p.True(array.ArrayEqual(exdelta, delta))
+ }
+
+func (p *PrimitiveDictionaryTestSuite) TestDictionaryBuilderDoubleDeltaDictionary() { // two consecutive NewDelta calls must each return only the values added since the previous finish
+ expectedType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: p.typ}
+ bldr := array.NewDictionaryBuilder(p.mem, expectedType)
+ defer bldr.Release()
+
+ builder := reflect.ValueOf(bldr)
+ appfn := builder.MethodByName("Append") // looked up by name; every argument is converted to p.reftyp before the call
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(1).Convert(p.reftyp)})[0].Interface())
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(2).Convert(p.reftyp)})[0].Interface())
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(1).Convert(p.reftyp)})[0].Interface())
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(2).Convert(p.reftyp)})[0].Interface())
+
+ result := bldr.NewArray() // first batch: values 1, 2, 1, 2
+ defer result.Release()
+
+ exdict, _, err := array.FromJSON(p.mem, p.typ, strings.NewReader("[1, 2]"))
+ p.NoError(err)
+ defer exdict.Release()
+ exindices, _, err := array.FromJSON(p.mem, arrow.PrimitiveTypes.Int8, strings.NewReader("[0, 1, 0, 1]"))
+ p.NoError(err)
+ defer exindices.Release()
+ expected := array.NewDictionaryArray(result.DataType().(*arrow.DictionaryType), exindices, exdict)
+ defer expected.Release()
+ p.True(array.ArrayEqual(expected, result))
+
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(2).Convert(p.reftyp)})[0].Interface()) // second batch: 2, 3, 3, 1, 3 (only 3 is new)
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(3).Convert(p.reftyp)})[0].Interface())
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(3).Convert(p.reftyp)})[0].Interface())
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(1).Convert(p.reftyp)})[0].Interface())
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(3).Convert(p.reftyp)})[0].Interface())
+
+ indices, delta, err := bldr.NewDelta()
+ p.NoError(err)
+ defer indices.Release()
+ defer delta.Release()
+
+ exindices, _, _ = array.FromJSON(p.mem, arrow.PrimitiveTypes.Int8, strings.NewReader("[1, 2, 2, 0, 2]"))
+ defer exindices.Release()
+ exdelta, _, _ := array.FromJSON(p.mem, p.typ, strings.NewReader("[3]")) // first delta: only the new value 3
+ defer exdelta.Release()
+
+ p.True(array.ArrayEqual(exindices, indices))
+ p.True(array.ArrayEqual(exdelta, delta))
+
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(1).Convert(p.reftyp)})[0].Interface()) // third batch: 1..5 (4 and 5 are new)
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(2).Convert(p.reftyp)})[0].Interface())
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(3).Convert(p.reftyp)})[0].Interface())
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(4).Convert(p.reftyp)})[0].Interface())
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(5).Convert(p.reftyp)})[0].Interface())
+
+ indices, delta, err = bldr.NewDelta() // variables are reused; the earlier defers already captured the previous arrays
+ p.NoError(err)
+ defer indices.Release()
+ defer delta.Release()
+
+ exindices, _, _ = array.FromJSON(p.mem, arrow.PrimitiveTypes.Int8, strings.NewReader("[0, 1, 2, 3, 4]"))
+ defer exindices.Release()
+ exdelta, _, _ = array.FromJSON(p.mem, p.typ, strings.NewReader("[4, 5]")) // second delta: only the values not seen in any earlier batch
+ defer exdelta.Release()
+
+ p.True(array.ArrayEqual(exindices, indices))
+ p.True(array.ArrayEqual(exdelta, delta))
+}
+
+func (p *PrimitiveDictionaryTestSuite) TestNewResetBehavior() { // NewDictionaryArray must zero Len/Cap/NullN while the dictionary values carry over to the next array
+ expectedType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: p.typ}
+ bldr := array.NewDictionaryBuilder(p.mem, expectedType)
+ defer bldr.Release()
+
+ builder := reflect.ValueOf(bldr)
+ appfn := builder.MethodByName("Append") // looked up by name; every argument is converted to p.reftyp before the call
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(1).Convert(p.reftyp)})[0].Interface())
+ bldr.AppendNull()
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(1).Convert(p.reftyp)})[0].Interface())
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(2).Convert(p.reftyp)})[0].Interface())
+
+ p.Less(0, bldr.Cap())
+ p.Less(0, bldr.NullN())
+ p.Equal(4, bldr.Len()) // three values plus one null
+
+ result := bldr.NewDictionaryArray()
+ defer result.Release()
+
+ p.Zero(bldr.Cap()) // the builder's counters must be reset after producing an array
+ p.Zero(bldr.Len())
+ p.Zero(bldr.NullN())
+
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(3).Convert(p.reftyp)})[0].Interface())
+ bldr.AppendNull()
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(4).Convert(p.reftyp)})[0].Interface())
+
+ result = bldr.NewDictionaryArray() // the earlier defer already captured the first array
+ defer result.Release()
+
+ p.Equal(4, result.Dictionary().Len()) // 1 and 2 from the first batch plus 3 and 4 from the second
+}
+
+func (p *PrimitiveDictionaryTestSuite) TestResetFull() { // ResetFull must also discard the accumulated dictionary, unlike the reset done by NewDictionaryArray
+ expectedType := &arrow.DictionaryType{IndexType: &arrow.Int32Type{}, ValueType: p.typ}
+ bldr := array.NewDictionaryBuilder(p.mem, expectedType)
+ defer bldr.Release()
+
+ builder := reflect.ValueOf(bldr)
+ appfn := builder.MethodByName("Append") // looked up by name; every argument is converted to p.reftyp before the call
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(1).Convert(p.reftyp)})[0].Interface())
+ bldr.AppendNull()
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(1).Convert(p.reftyp)})[0].Interface())
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(2).Convert(p.reftyp)})[0].Interface())
+
+ result := bldr.NewDictionaryArray()
+ defer result.Release()
+
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(3).Convert(p.reftyp)})[0].Interface())
+ result = bldr.NewDictionaryArray() // the earlier defer already captured the first array
+ defer result.Release()
+
+ exindices, _, _ := array.FromJSON(p.mem, arrow.PrimitiveTypes.Int32, strings.NewReader("[2]")) // 3 lands at index 2 because 1 and 2 are still in the dictionary
+ exdict, _, _ := array.FromJSON(p.mem, p.typ, strings.NewReader("[1, 2, 3]"))
+ defer exindices.Release()
+ defer exdict.Release()
+
+ p.True(array.ArrayEqual(exindices, result.Indices()))
+ p.True(array.ArrayEqual(exdict, result.Dictionary()))
+
+ bldr.ResetFull()
+ p.Nil(appfn.Call([]reflect.Value{reflect.ValueOf(4).Convert(p.reftyp)})[0].Interface())
+ result = bldr.NewDictionaryArray()
+ defer result.Release()
+
+ exindices, _, _ = array.FromJSON(p.mem, arrow.PrimitiveTypes.Int32, strings.NewReader("[0]")) // after ResetFull the dictionary starts over, so 4 gets index 0
+ exdict, _, _ = array.FromJSON(p.mem, p.typ, strings.NewReader("[4]"))
+ defer exindices.Release()
+ defer exdict.Release()
+
+ p.True(array.ArrayEqual(exindices, result.Indices()))
+ p.True(array.ArrayEqual(exdict, result.Dictionary()))
+}
+
+func TestBasicStringDictionaryBuilder(t *testing.T) { // a string dictionary builder must dedupe values appended via both Append([]byte) and AppendString
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ dictType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: arrow.BinaryTypes.String}
+ bldr := array.NewDictionaryBuilder(mem, dictType)
+ defer bldr.Release()
+
+ builder := bldr.(*array.BinaryDictionaryBuilder)
+ assert.NoError(t, builder.Append([]byte("test")))
+ assert.NoError(t, builder.AppendString("test2"))
+ assert.NoError(t, builder.AppendString("test")) // repeats the first value, so index 0 is expected again
+
+ result := bldr.NewDictionaryArray()
+ defer result.Release()
+
+ exdict, _, _ := array.FromJSON(mem, arrow.BinaryTypes.String, strings.NewReader(`["test", "test2"]`))
+ defer exdict.Release()
+ exint, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int8, strings.NewReader("[0, 1, 0]"))
+ defer exint.Release()
+
+ assert.True(t, arrow.TypeEqual(dictType, result.DataType()))
+ expected := array.NewDictionaryArray(dictType, exint, exdict)
+ defer expected.Release()
+
+ assert.True(t, array.ArrayEqual(expected, result))
+}
+
+func TestStringDictionaryInsertValues(t *testing.T) { // pre-seeding the dictionary via InsertStringDictValues must fix the index assigned to each later append
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ exdict, _, _ := array.FromJSON(mem, arrow.BinaryTypes.String, strings.NewReader(`["c", "a", "b", "d"]`))
+ defer exdict.Release()
+
+ invalidDict, _, err := array.FromJSON(mem, arrow.BinaryTypes.Binary, strings.NewReader(`["ZQ==", "Zg=="]`)) // Binary-typed values, used below against a String-valued builder
+ assert.NoError(t, err)
+ defer invalidDict.Release()
+
+ dictType := &arrow.DictionaryType{IndexType: &arrow.Int16Type{}, ValueType: arrow.BinaryTypes.String}
+ bldr := array.NewDictionaryBuilder(mem, dictType)
+ defer bldr.Release()
+
+ builder := bldr.(*array.BinaryDictionaryBuilder)
+ assert.NoError(t, builder.InsertStringDictValues(exdict.(*array.String)))
+ // inserting again should have no effect
+ assert.NoError(t, builder.InsertStringDictValues(exdict.(*array.String)))
+
+ assert.Error(t, builder.InsertDictValues(invalidDict.(*array.Binary))) // inserting the Binary array into this builder must be rejected
+
+ for i := 0; i < 2; i++ {
+ assert.NoError(t, builder.AppendString("c")) // append errors are now checked instead of being dropped
+ assert.NoError(t, builder.AppendString("a"))
+ assert.NoError(t, builder.AppendString("b"))
+ builder.AppendNull()
+ assert.NoError(t, builder.AppendString("d"))
+ }
+
+ assert.Equal(t, 10, bldr.Len()) // two rounds of four values plus one null
+
+ result := bldr.NewDictionaryArray()
+ defer result.Release()
+
+ exindices, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int16, strings.NewReader("[0, 1, 2, null, 3, 0, 1, 2, null, 3]")) // indices follow the inserted order c, a, b, d
+ defer exindices.Release()
+ expected := array.NewDictionaryArray(dictType, exindices, exdict)
+ defer expected.Release()
+ assert.True(t, array.ArrayEqual(expected, result))
+}
+
+func TestStringDictionaryBuilderInit(t *testing.T) { // a builder created with NewDictionaryBuilderWithDict must reuse the supplied dictionary's indices
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ dictArr, _, _ := array.FromJSON(mem, arrow.BinaryTypes.String, strings.NewReader(`["test", "test2"]`))
+ defer dictArr.Release()
+ intarr, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int8, strings.NewReader("[0, 1, 0]"))
+ defer intarr.Release()
+
+ dictType := &arrow.DictionaryType{IndexType: intarr.DataType().(arrow.FixedWidthDataType), ValueType: arrow.BinaryTypes.String}
+ bldr := array.NewDictionaryBuilderWithDict(mem, dictType, dictArr)
+ defer bldr.Release()
+
+ builder := bldr.(*array.BinaryDictionaryBuilder)
+ assert.NoError(t, builder.AppendString("test")) // all three values are already present in dictArr
+ assert.NoError(t, builder.AppendString("test2"))
+ assert.NoError(t, builder.AppendString("test"))
+
+ result := bldr.NewDictionaryArray()
+ defer result.Release()
+
+ expected := array.NewDictionaryArray(dictType, intarr, dictArr) // expected output is exactly the seed dictionary with indices [0, 1, 0]
+ defer expected.Release()
+
+ assert.True(t, array.ArrayEqual(expected, result))
+}
+
+func TestStringDictionaryBuilderOnlyNull(t *testing.T) { // appending only a null must give one null index and an empty dictionary
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ dictType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: arrow.BinaryTypes.String}
+ bldr := array.NewDictionaryBuilder(mem, dictType)
+ defer bldr.Release()
+
+ bldr.AppendNull()
+ result := bldr.NewDictionaryArray()
+ defer result.Release()
+
+ dict, _, _ := array.FromJSON(mem, arrow.BinaryTypes.String, strings.NewReader("[]")) // expected dictionary: no values
+ defer dict.Release()
+ intarr, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int8, strings.NewReader("[null]"))
+ defer intarr.Release()
+
+ expected := array.NewDictionaryArray(dictType, intarr, dict)
+ defer expected.Release()
+
+ assert.True(t, array.ArrayEqual(expected, result))
+}
+
+func TestStringDictionaryBuilderDelta(t *testing.T) { // NewDelta on a string builder must return only the strings added after the previous array was produced
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ dictType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: arrow.BinaryTypes.String}
+ bldr := array.NewDictionaryBuilder(mem, dictType)
+ defer bldr.Release()
+
+ builder := bldr.(*array.BinaryDictionaryBuilder)
+ assert.NoError(t, builder.AppendString("test"))
+ assert.NoError(t, builder.AppendString("test2"))
+ assert.NoError(t, builder.AppendString("test"))
+
+ result := bldr.NewDictionaryArray() // first batch
+ defer result.Release()
+
+ exdict, _, _ := array.FromJSON(mem, arrow.BinaryTypes.String, strings.NewReader(`["test", "test2"]`))
+ defer exdict.Release()
+ exint, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int8, strings.NewReader("[0, 1, 0]"))
+ defer exint.Release()
+
+ assert.True(t, arrow.TypeEqual(dictType, result.DataType()))
+ expected := array.NewDictionaryArray(dictType, exint, exdict)
+ defer expected.Release()
+
+ assert.True(t, array.ArrayEqual(expected, result))
+
+ assert.NoError(t, builder.AppendString("test2")) // second batch: only "test3" is new
+ assert.NoError(t, builder.AppendString("test3"))
+ assert.NoError(t, builder.AppendString("test2"))
+
+ indices, delta, err := builder.NewDelta()
+ assert.NoError(t, err)
+ defer indices.Release()
+ defer delta.Release()
+
+ exdelta, _, _ := array.FromJSON(mem, arrow.BinaryTypes.String, strings.NewReader(`["test3"]`))
+ defer exdelta.Release()
+ exint, _, _ = array.FromJSON(mem, arrow.PrimitiveTypes.Int8, strings.NewReader("[1, 2, 1]")) // "test2" keeps index 1; "test3" takes the next index, 2
+ defer exint.Release() // releases the new array; the earlier defer already captured the previous one
+
+ assert.True(t, array.ArrayEqual(exdelta, delta))
+ assert.True(t, array.ArrayEqual(exint, indices))
+}
+
+func TestStringDictionaryBuilderBigDelta(t *testing.T) { // delta dictionaries on top of a 2048-entry dictionary: each NewDelta must report just the single new string
+ const testlen = 2048
+
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ dictType := &arrow.DictionaryType{IndexType: &arrow.Int16Type{}, ValueType: arrow.BinaryTypes.String}
+ bldr := array.NewDictionaryBuilder(mem, dictType)
+ defer bldr.Release()
+ builder := bldr.(*array.BinaryDictionaryBuilder)
+
+ strbldr := array.NewStringBuilder(mem) // collects the expected dictionary values for the first batch
+ defer strbldr.Release()
+
+ intbldr := array.NewInt16Builder(mem) // collects the expected indices for the first batch
+ defer intbldr.Release()
+
+ for idx := int16(0); idx < testlen; idx++ {
+ var b strings.Builder
+ b.WriteString("test")
+ fmt.Fprint(&b, idx) // produces "test0", "test1", ... so every value is distinct
+
+ val := b.String()
+ assert.NoError(t, builder.AppendString(val))
+ strbldr.Append(val)
+ intbldr.Append(idx)
+ }
+
+ result := bldr.NewDictionaryArray()
+ defer result.Release()
+ strarr := strbldr.NewStringArray()
+ defer strarr.Release()
+ intarr := intbldr.NewInt16Array()
+ defer intarr.Release()
+
+ expected := array.NewDictionaryArray(dictType, intarr, strarr)
+ defer expected.Release()
+
+ assert.True(t, array.ArrayEqual(expected, result))
+
+ strbldr2 := array.NewStringBuilder(mem)
+ defer strbldr2.Release()
+ intbldr2 := array.NewInt16Builder(mem)
+ defer intbldr2.Release()
+
+ for idx := int16(0); idx < testlen; idx++ {
+ assert.NoError(t, builder.AppendString("test1")) // error now checked, matching the first and third batches
+ intbldr2.Append(1) // "test1" was the second value of the first batch, so index 1 is expected
+ }
+ for idx := int16(0); idx < testlen; idx++ {
+ assert.NoError(t, builder.AppendString("test_new_value1")) // error now checked, matching the first and third batches
+ intbldr2.Append(testlen) // first new value after testlen existing entries
+ }
+ strbldr2.Append("test_new_value1")
+
+ indices2, delta2, err := bldr.NewDelta()
+ assert.NoError(t, err)
+ defer indices2.Release()
+ defer delta2.Release()
+ strarr2 := strbldr2.NewStringArray()
+ defer strarr2.Release()
+ intarr2 := intbldr2.NewInt16Array()
+ defer intarr2.Release()
+
+ assert.True(t, array.ArrayEqual(intarr2, indices2))
+ assert.True(t, array.ArrayEqual(strarr2, delta2))
+
+ strbldr3 := array.NewStringBuilder(mem)
+ defer strbldr3.Release()
+ intbldr3 := array.NewInt16Builder(mem)
+ defer intbldr3.Release()
+
+ for idx := int16(0); idx < testlen; idx++ {
+ assert.NoError(t, builder.AppendString("test2"))
+ intbldr3.Append(2)
+ }
+ for idx := int16(0); idx < testlen; idx++ {
+ assert.NoError(t, builder.AppendString("test_new_value2"))
+ intbldr3.Append(testlen + 1) // second new value overall, one past "test_new_value1"
+ }
+ strbldr3.Append("test_new_value2")
+
+ indices3, delta3, err := bldr.NewDelta()
+ assert.NoError(t, err)
+ defer indices3.Release()
+ defer delta3.Release()
+ strarr3 := strbldr3.NewStringArray()
+ defer strarr3.Release()
+ intarr3 := intbldr3.NewInt16Array()
+ defer intarr3.Release()
+
+ assert.True(t, array.ArrayEqual(intarr3, indices3))
+ assert.True(t, array.ArrayEqual(strarr3, delta3))
+}
+
+func TestFixedSizeBinaryDictionaryBuilder(t *testing.T) { // a 4-byte fixed-size-binary dictionary builder must dedupe repeated values
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ dictType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: &arrow.FixedSizeBinaryType{ByteWidth: 4}}
+ bldr := array.NewDictionaryBuilder(mem, dictType)
+ defer bldr.Release()
+
+ builder := bldr.(*array.FixedSizeBinaryDictionaryBuilder)
+ test := []byte{12, 12, 11, 12}
+ test2 := []byte{12, 12, 11, 11} // differs from test only in the last byte
+ assert.NoError(t, builder.Append(test))
+ assert.NoError(t, builder.Append(test2))
+ assert.NoError(t, builder.Append(test))
+
+ result := builder.NewDictionaryArray()
+ defer result.Release()
+
+ fsbBldr := array.NewFixedSizeBinaryBuilder(mem, dictType.ValueType.(*arrow.FixedSizeBinaryType)) // builds the expected dictionary [test, test2]
+ defer fsbBldr.Release()
+
+ fsbBldr.Append(test)
+ fsbBldr.Append(test2)
+ fsbArr := fsbBldr.NewFixedSizeBinaryArray()
+ defer fsbArr.Release()
+
+ intbldr := array.NewInt8Builder(mem) // builds the expected indices
+ defer intbldr.Release()
+
+ intbldr.AppendValues([]int8{0, 1, 0}, nil)
+ intArr := intbldr.NewInt8Array()
+ defer intArr.Release()
+
+ expected := array.NewDictionaryArray(dictType, intArr, fsbArr)
+ defer expected.Release()
+
+ assert.True(t, array.ArrayEqual(expected, result))
+}
+
+func TestFixedSizeBinaryDictionaryBuilderInit(t *testing.T) { // a fixed-size-binary builder seeded via NewDictionaryBuilderWithDict must reuse the seed's indices
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ fsbBldr := array.NewFixedSizeBinaryBuilder(mem, &arrow.FixedSizeBinaryType{ByteWidth: 4})
+ defer fsbBldr.Release()
+
+ test, test2 := []byte("abcd"), []byte("wxyz")
+ fsbBldr.AppendValues([][]byte{test, test2}, nil)
+ dictArr := fsbBldr.NewFixedSizeBinaryArray() // seed dictionary: ["abcd", "wxyz"]
+ defer dictArr.Release()
+
+ dictType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: dictArr.DataType()}
+ bldr := array.NewDictionaryBuilderWithDict(mem, dictType, dictArr)
+ defer bldr.Release()
+
+ builder := bldr.(*array.FixedSizeBinaryDictionaryBuilder)
+ assert.NoError(t, builder.Append(test)) // all three values are already present in dictArr
+ assert.NoError(t, builder.Append(test2))
+ assert.NoError(t, builder.Append(test))
+
+ result := builder.NewDictionaryArray()
+ defer result.Release()
+
+ indices, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int8, strings.NewReader("[0, 1, 0]"))
+ defer indices.Release()
+
+ expected := array.NewDictionaryArray(dictType, indices, dictArr)
+ defer expected.Release()
+
+ assert.True(t, array.ArrayEqual(expected, result))
+}
+
+func TestFixedSizeBinaryDictionaryBuilderMakeBuilder(t *testing.T) { // the generic array.NewBuilder must return a *FixedSizeBinaryDictionaryBuilder for this dictionary type
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ fsbBldr := array.NewFixedSizeBinaryBuilder(mem, &arrow.FixedSizeBinaryType{ByteWidth: 4})
+ defer fsbBldr.Release()
+
+ test, test2 := []byte("abcd"), []byte("wxyz")
+ fsbBldr.AppendValues([][]byte{test, test2}, nil)
+ dictArr := fsbBldr.NewFixedSizeBinaryArray() // used only for its type and as the expected dictionary; the builder below is not seeded with it
+ defer dictArr.Release()
+
+ dictType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: dictArr.DataType()}
+ bldr := array.NewBuilder(mem, dictType)
+ defer bldr.Release()
+
+ builder := bldr.(*array.FixedSizeBinaryDictionaryBuilder) // panics here if NewBuilder returned a different builder type
+ assert.NoError(t, builder.Append(test))
+ assert.NoError(t, builder.Append(test2))
+ assert.NoError(t, builder.Append(test))
+
+ result := builder.NewDictionaryArray()
+ defer result.Release()
+
+ indices, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int8, strings.NewReader("[0, 1, 0]"))
+ defer indices.Release()
+
+ expected := array.NewDictionaryArray(dictType, indices, dictArr)
+ defer expected.Release()
+
+ assert.True(t, array.ArrayEqual(expected, result))
+}
+
+func TestFixedSizeBinaryDictionaryBuilderDeltaDictionary(t *testing.T) { // NewDelta on a fixed-size-binary builder must return only the value added after the first array
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ dictType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: &arrow.FixedSizeBinaryType{ByteWidth: 4}}
+ bldr := array.NewDictionaryBuilder(mem, dictType)
+ defer bldr.Release()
+
+ builder := bldr.(*array.FixedSizeBinaryDictionaryBuilder)
+ test := []byte{12, 12, 11, 12}
+ test2 := []byte{12, 12, 11, 11}
+ test3 := []byte{12, 12, 11, 10} // only appended in the second batch
+
+ assert.NoError(t, builder.Append(test))
+ assert.NoError(t, builder.Append(test2))
+ assert.NoError(t, builder.Append(test))
+
+ result1 := bldr.NewDictionaryArray() // first batch
+ defer result1.Release()
+
+ fsbBuilder := array.NewFixedSizeBinaryBuilder(mem, dictType.ValueType.(*arrow.FixedSizeBinaryType))
+ defer fsbBuilder.Release()
+
+ fsbBuilder.AppendValues([][]byte{test, test2}, nil)
+ fsbArr1 := fsbBuilder.NewFixedSizeBinaryArray()
+ defer fsbArr1.Release()
+
+ intBuilder := array.NewInt8Builder(mem)
+ defer intBuilder.Release()
+ intBuilder.AppendValues([]int8{0, 1, 0}, nil)
+ intArr1 := intBuilder.NewInt8Array()
+ defer intArr1.Release()
+
+ expected := array.NewDictionaryArray(dictType, intArr1, fsbArr1)
+ defer expected.Release()
+ assert.True(t, array.ArrayEqual(expected, result1))
+
+ assert.NoError(t, builder.Append(test)) // second batch: test and test2 are known, test3 is new
+ assert.NoError(t, builder.Append(test2))
+ assert.NoError(t, builder.Append(test3))
+
+ indices2, delta2, err := builder.NewDelta()
+ assert.NoError(t, err)
+ defer indices2.Release()
+ defer delta2.Release()
+
+ fsbBuilder.Append(test3) // fsbBuilder is reused; only test3 goes into the expected delta
+ fsbArr2 := fsbBuilder.NewFixedSizeBinaryArray()
+ defer fsbArr2.Release()
+
+ intBuilder.AppendValues([]int8{0, 1, 2}, nil) // intBuilder is reused for the second batch's expected indices
+ intArr2 := intBuilder.NewInt8Array()
+ defer intArr2.Release()
+
+ assert.True(t, array.ArrayEqual(intArr2, indices2))
+ assert.True(t, array.ArrayEqual(fsbArr2, delta2))
+}
+
+func TestDecimalDictionaryBuilderBasic(t *testing.T) { // a Decimal128 dictionary builder must dedupe repeated decimal values
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ test := []decimal128.Num{decimal128.FromI64(12), decimal128.FromI64(12), decimal128.FromI64(11), decimal128.FromI64(12)}
+ dictType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: &arrow.Decimal128Type{Precision: 2, Scale: 0}}
+ bldr := array.NewDictionaryBuilder(mem, dictType)
+ defer bldr.Release()
+
+ builder := bldr.(*array.Decimal128DictionaryBuilder)
+ for _, v := range test {
+ assert.NoError(t, builder.Append(v))
+ }
+
+ result := bldr.NewDictionaryArray()
+ defer result.Release()
+
+ indices, _, _ := array.FromJSON(mem, dictType.IndexType, strings.NewReader("[0, 0, 1, 0]")) // 12 -> 0, 11 -> 1, in first-seen order
+ defer indices.Release()
+ dict, _, _ := array.FromJSON(mem, dictType.ValueType, strings.NewReader("[12, 11]"))
+ defer dict.Release()
+
+ expected := array.NewDictionaryArray(dictType, indices, dict)
+ defer expected.Release()
+
+ assert.True(t, array.ArrayApproxEqual(expected, result)) // note: compared with ArrayApproxEqual, unlike the other builder tests which use ArrayEqual
+}
+
+func TestNullDictionaryBuilderBasic(t *testing.T) { // a dictionary builder with Null value type must count every appended entry as a null
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ dictType := &arrow.DictionaryType{IndexType: &arrow.Int8Type{}, ValueType: arrow.Null}
+ bldr := array.NewBuilder(mem, dictType)
+ defer bldr.Release()
+
+ builder := bldr.(*array.NullDictionaryBuilder) // panics here if NewBuilder returned a different builder type
+ builder.AppendNull()
+ builder.AppendNull()
+ builder.AppendNull()
+ assert.Equal(t, 3, builder.Len())
+ assert.Equal(t, 3, builder.NullN())
+
+ nullarr, _, _ := array.FromJSON(mem, arrow.Null, strings.NewReader("[null, null, null]"))
+ defer nullarr.Release()
+
+ assert.NoError(t, builder.AppendArray(nullarr)) // appending a 3-element null array must add three more nulls
+ assert.Equal(t, 6, bldr.Len())
+ assert.Equal(t, 6, bldr.NullN())
+
+ result := builder.NewDictionaryArray()
+ defer result.Release()
+ assert.Equal(t, 6, result.Len())
+ assert.Equal(t, 6, result.NullN())
+}
+
+func TestDictionaryEquals(t *testing.T) { // ArrayEqual / ArraySliceEqual on dictionary arrays, including null-masked indices and slices
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ var (
+ isValid = []bool{true, true, false, true, true, true} // position 2 is null in every index array built below
+ dict, dict2 array.Interface
+ indices, indices2, indices3 array.Interface
+ )
+
+ dict, _, _ = array.FromJSON(mem, arrow.BinaryTypes.String, strings.NewReader(`["foo", "bar", "baz"]`))
+ defer dict.Release()
+ dictType := &arrow.DictionaryType{IndexType: &arrow.Uint16Type{}, ValueType: arrow.BinaryTypes.String}
+
+ dict2, _, _ = array.FromJSON(mem, arrow.BinaryTypes.String, strings.NewReader(`["foo", "bar", "baz", "qux"]`)) // same as dict plus one extra value
+ defer dict2.Release()
+ dictType2 := &arrow.DictionaryType{IndexType: &arrow.Uint16Type{}, ValueType: arrow.BinaryTypes.String}
+
+ idxbuilder := array.NewUint16Builder(mem) // reused for all three index arrays
+ defer idxbuilder.Release()
+
+ idxbuilder.AppendValues([]uint16{1, 2, math.MaxUint16, 0, 2, 0}, isValid) // the MaxUint16 entry sits at the null position
+ indices = idxbuilder.NewArray()
+ defer indices.Release()
+
+ idxbuilder.AppendValues([]uint16{1, 2, 0, 0, 2, 0}, isValid) // differs from indices only at the null position
+ indices2 = idxbuilder.NewArray()
+ defer indices2.Release()
+
+ idxbuilder.AppendValues([]uint16{1, 1, 0, 0, 2, 0}, isValid) // differs from indices at position 1, which is valid
+ indices3 = idxbuilder.NewArray()
+ defer indices3.Release()
+
+ var (
+ arr = array.NewDictionaryArray(dictType, indices, dict)
+ arr2 = array.NewDictionaryArray(dictType, indices2, dict)
+ arr3 = array.NewDictionaryArray(dictType2, indices, dict2)
+ arr4 = array.NewDictionaryArray(dictType, indices3, dict)
+ )
+ defer func() {
+ arr.Release()
+ arr2.Release()
+ arr3.Release()
+ arr4.Release()
+ }()
+
+ assert.True(t, array.ArrayEqual(arr, arr))
+ // equal because the unequal index is masked by null
+ assert.True(t, array.ArrayEqual(arr, arr2))
+ // unequal dictionaries
+ assert.False(t, array.ArrayEqual(arr, arr3))
+ // unequal indices
+ assert.False(t, array.ArrayEqual(arr, arr4))
+ assert.True(t, array.ArraySliceEqual(arr, 3, 6, arr4, 3, 6)) // range [3, 6) excludes the differing position 1
+ assert.False(t, array.ArraySliceEqual(arr, 1, 3, arr4, 1, 3)) // range [1, 3) includes the differing position 1
+
+ sz := arr.Len()
+ slice := array.NewSlice(arr, 2, int64(sz))
+ defer slice.Release()
+ slice2 := array.NewSlice(arr, 2, int64(sz))
+ defer slice2.Release()
+
+ assert.Equal(t, sz-2, slice.Len())
+ assert.True(t, array.ArrayEqual(slice, slice2))
+ assert.True(t, array.ArraySliceEqual(arr, 2, int64(arr.Len()), slice, 0, int64(slice.Len())))
+
+ // chained slice
+ slice2 = array.NewSlice(arr, 1, int64(arr.Len()))
+ defer slice2.Release() // each defer captures the array held by slice2 at this point
+ slice2 = array.NewSlice(slice2, 1, int64(slice2.Len())) // slice of a slice: equivalent to arr[2:]
+ defer slice2.Release()
+
+ assert.True(t, array.ArrayEqual(slice, slice2))
+ slice = array.NewSlice(arr, 1, 4)
+ defer slice.Release()
+ slice2 = array.NewSlice(arr, 1, 4)
+ defer slice2.Release()
+
+ assert.Equal(t, 3, slice.Len())
+ assert.True(t, array.ArrayEqual(slice, slice2))
+ assert.True(t, array.ArraySliceEqual(arr, 1, 4, slice, 0, int64(slice.Len())))
+}
+
+func TestDictionaryIndexTypes(t *testing.T) { // the string dictionary builder must produce the same indices for every signed and unsigned integer index type
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ dictIndexTypes := []arrow.DataType{
+ arrow.PrimitiveTypes.Int8, arrow.PrimitiveTypes.Uint8,
+ arrow.PrimitiveTypes.Int16, arrow.PrimitiveTypes.Uint16,
+ arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Uint32,
+ arrow.PrimitiveTypes.Int64, arrow.PrimitiveTypes.Uint64,
+ }
+
+ for _, indextyp := range dictIndexTypes {
+ t.Run(indextyp.Name(), func(t *testing.T) {
+ scope := memory.NewCheckedAllocatorScope(mem)
+ defer scope.CheckSize(t) // per-subtest allocation check on the shared allocator
+
+ dictType := &arrow.DictionaryType{IndexType: indextyp, ValueType: arrow.BinaryTypes.String}
+ bldr := array.NewDictionaryBuilder(mem, dictType)
+ defer bldr.Release()
+
+ builder := bldr.(*array.BinaryDictionaryBuilder)
+ assert.NoError(t, builder.AppendString("foo")) // append errors are now checked instead of being dropped
+ assert.NoError(t, builder.AppendString("bar"))
+ assert.NoError(t, builder.AppendString("foo"))
+ assert.NoError(t, builder.AppendString("baz"))
+ assert.NoError(t, builder.Append(nil)) // a nil slice is expected to be recorded as a null (see NullN below)
+
+ assert.Equal(t, 5, builder.Len())
+ assert.Equal(t, 1, builder.NullN())
+
+ result := builder.NewDictionaryArray()
+ defer result.Release()
+
+ expectedIndices, _, _ := array.FromJSON(mem, indextyp, strings.NewReader("[0, 1, 0, 2, null]"))
+ defer expectedIndices.Release()
+
+ assert.True(t, array.ArrayEqual(expectedIndices, result.Indices()))
+ })
+ }
+}
+
+func TestDictionaryFromArrays(t *testing.T) { // NewValidatedDictionaryArray must accept in-range indices and reject out-of-range or mistyped ones, for every index type
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ dict, _, _ := array.FromJSON(mem, arrow.BinaryTypes.String, strings.NewReader(`["foo", "bar", "baz"]`)) // three entries, so valid indices are 0..2
+ defer dict.Release()
+
+ dictIndexTypes := []arrow.DataType{
+ arrow.PrimitiveTypes.Int8, arrow.PrimitiveTypes.Uint8,
+ arrow.PrimitiveTypes.Int16, arrow.PrimitiveTypes.Uint16,
+ arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Uint32,
+ arrow.PrimitiveTypes.Int64, arrow.PrimitiveTypes.Uint64,
+ }
+
+ for _, indextyp := range dictIndexTypes {
+ t.Run(indextyp.Name(), func(t *testing.T) {
+ scope := memory.NewCheckedAllocatorScope(mem)
+ defer scope.CheckSize(t) // per-subtest allocation check on the shared allocator
+
+ dictType := &arrow.DictionaryType{IndexType: indextyp, ValueType: arrow.BinaryTypes.String}
+ indices1, _, _ := array.FromJSON(mem, indextyp, strings.NewReader("[1, 2, 0, 0, 2, 0]")) // all indices in range
+ defer indices1.Release()
+
+ indices2, _, _ := array.FromJSON(mem, indextyp, strings.NewReader("[1, 2, 0, 3, 2, 0]")) // index 3 is out of range for the 3-entry dict
+ defer indices2.Release()
+
+ arr1, err := array.NewValidatedDictionaryArray(dictType, indices1, dict)
+ assert.NoError(t, err)
+ defer arr1.Release()
+
+ _, err = array.NewValidatedDictionaryArray(dictType, indices2, dict)
+ assert.Error(t, err)
+
+ switch indextyp.ID() {
+ case arrow.INT8, arrow.INT16, arrow.INT32, arrow.INT64: // signed index types only
+ indices3, _, _ := array.FromJSON(mem, indextyp, strings.NewReader("[1, 2, 0, null, 2, 0]"))
+ defer indices3.Release()
+ bitutil.ClearBit(indices3.Data().Buffers()[0].Bytes(), 2) // presumably clears validity bit 2, i.e. marks the third entry null — NOTE(review): that entry is 0, a valid index, so no invalid index is actually being masked here; confirm intent
+ arr3, err := array.NewValidatedDictionaryArray(dictType, indices3, dict)
+ assert.NoError(t, err)
+ defer arr3.Release()
+ }
+
+ indices4, _, _ := array.FromJSON(mem, indextyp, strings.NewReader("[1, 2, null, 3, 2, 0]")) // out-of-range 3 at a non-null position
+ defer indices4.Release()
+ _, err = array.NewValidatedDictionaryArray(dictType, indices4, dict)
+ assert.Error(t, err)
+
+ diffIndexType := arrow.PrimitiveTypes.Int8 // pick an index type that differs from indextyp
+ if indextyp.ID() == arrow.INT8 {
+ diffIndexType = arrow.PrimitiveTypes.Uint8
+ }
+ _, err = array.NewValidatedDictionaryArray(&arrow.DictionaryType{IndexType: diffIndexType, ValueType: arrow.BinaryTypes.String}, indices4, dict) // NOTE(review): indices4 also holds the out-of-range 3, so this does not isolate the index-type mismatch
+ assert.Error(t, err)
+ })
+ }
+}
+
+func TestListOfDictionary(t *testing.T) { // a list builder whose value type is a dictionary must expose a dictionary value builder that accumulates all distinct strings
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ rootBuilder := array.NewBuilder(mem, arrow.ListOf(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int16, ValueType: arrow.BinaryTypes.String}))
+ defer rootBuilder.Release()
+
+ listBldr := rootBuilder.(*array.ListBuilder)
+ dictBldr := listBldr.ValueBuilder().(*array.BinaryDictionaryBuilder)
+
+ listBldr.Append(true)
+ expected := []string{}
+ for _, a := range []byte("abc") { // six nested 3-letter alphabets: 3^6 = 729 distinct 6-character strings
+ for _, d := range []byte("def") {
+ for _, g := range []byte("ghi") {
+ for _, j := range []byte("jkl") {
+ for _, m := range []byte("mno") {
+ for _, p := range []byte("pqr") {
+ if (a+d+g+j+m+p)%16 == 0 { // parenthesised: % previously applied to p alone. NOTE(review): for these alphabets the sum is 627..639 (remainders 3..15), so this branch never fires
+ listBldr.Append(true)
+ }
+
+ str := string([]byte{a, d, g, j, m, p})
+ assert.NoError(t, dictBldr.AppendString(str)) // append error is now checked instead of being dropped
+ expected = append(expected, str)
+ }
+ }
+ }
+ }
+ }
+ }
+
+ strbldr := array.NewStringBuilder(mem)
+ defer strbldr.Release()
+ strbldr.AppendValues(expected, nil)
+
+ expectedDict := strbldr.NewStringArray() // every generated string is unique, so the dictionary must equal the append order
+ defer expectedDict.Release()
+
+ arr := rootBuilder.NewArray()
+ defer arr.Release()
+
+ actualDict := arr.(*array.List).ListValues().(*array.Dictionary)
+ assert.True(t, array.ArrayEqual(expectedDict, actualDict.Dictionary()))
+}
+
+func TestDictionaryCanCompareIndices(t *testing.T) { // CanCompareIndices must be symmetric and hold only for same-index-type arrays whose dictionaries are equal or one a prefix of the other
+ makeDict := func(mem memory.Allocator, idxType, valueType arrow.DataType, dictJSON string) *array.Dictionary { // builds a dictionary array with empty indices and the given dictionary values
+ indices, _, _ := array.FromJSON(mem, idxType, strings.NewReader("[]"))
+ defer indices.Release()
+ dict, _, _ := array.FromJSON(mem, valueType, strings.NewReader(dictJSON))
+ defer dict.Release()
+
+ out, _ := array.NewValidatedDictionaryArray(&arrow.DictionaryType{IndexType: idxType, ValueType: valueType}, indices, dict) // validation error is ignored here
+ return out
+ }
+
+ compareSwap := func(t *testing.T, l, r *array.Dictionary, expected bool) { // checks the result in both argument orders
+ assert.Equalf(t, expected, l.CanCompareIndices(r), "left: %s\nright: %s\n", l, r)
+ assert.Equalf(t, expected, r.CanCompareIndices(l), "left: %s\nright: %s\n", r, l)
+ }
+
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ t.Run("same", func(t *testing.T) { // identical index type and dictionary values
+ arr := makeDict(mem, arrow.PrimitiveTypes.Int16, arrow.BinaryTypes.String, `["foo", "bar"]`)
+ defer arr.Release()
+ same := makeDict(mem, arrow.PrimitiveTypes.Int16, arrow.BinaryTypes.String, `["foo", "bar"]`)
+ defer same.Release()
+ compareSwap(t, arr, same, true)
+ })
+
+ t.Run("prefix dict", func(t *testing.T) { // one dictionary is a prefix of the other
+ arr := makeDict(mem, arrow.PrimitiveTypes.Int16, arrow.BinaryTypes.String, `["foo", "bar", "quux"]`)
+ defer arr.Release()
+ prefixDict := makeDict(mem, arrow.PrimitiveTypes.Int16, arrow.BinaryTypes.String, `["foo", "bar"]`)
+ defer prefixDict.Release()
+ compareSwap(t, arr, prefixDict, true)
+ })
+
+ t.Run("indices need cast", func(t *testing.T) { // same values but Int16 vs Int8 index types
+ arr := makeDict(mem, arrow.PrimitiveTypes.Int16, arrow.BinaryTypes.String, `["foo", "bar"]`)
+ defer arr.Release()
+ needcast := makeDict(mem, arrow.PrimitiveTypes.Int8, arrow.BinaryTypes.String, `["foo", "bar"]`)
+ defer needcast.Release()
+ compareSwap(t, arr, needcast, false)
+ })
+
+ t.Run("non prefix", func(t *testing.T) { // dictionaries diverge at the second value
+ arr := makeDict(mem, arrow.PrimitiveTypes.Int16, arrow.BinaryTypes.String, `["foo", "bar", "quux"]`)
+ defer arr.Release()
+ nonPrefix := makeDict(mem, arrow.PrimitiveTypes.Int16, arrow.BinaryTypes.String, `["foo", "blink"]`)
+ defer nonPrefix.Release()
+ compareSwap(t, arr, nonPrefix, false)
+ })
+}
+
+func TestDictionaryGetValueIndex(t *testing.T) { // GetValueIndex must return the stored index for every index type, on both whole and sliced arrays
+ mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer mem.AssertSize(t, 0) // asserts the checked allocator reports size 0 once all deferred Release calls have run
+
+ indicesJson := "[5, 0, 1, 3, 2, 4]"
+ indices64, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int64, strings.NewReader(indicesJson)) // Int64 copy of the indices, used as the reference values
+ defer indices64.Release()
+ dict, _, _ := array.FromJSON(mem, arrow.PrimitiveTypes.Int32, strings.NewReader("[10, 20, 30, 40, 50, 60]"))
+ defer dict.Release()
+
+ dictIndexTypes := []arrow.DataType{
+ arrow.PrimitiveTypes.Int8, arrow.PrimitiveTypes.Uint8,
+ arrow.PrimitiveTypes.Int16, arrow.PrimitiveTypes.Uint16,
+ arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Uint32,
+ arrow.PrimitiveTypes.Int64, arrow.PrimitiveTypes.Uint64,
+ }
+ i64Index := indices64.(*array.Int64)
+ for _, idxt := range dictIndexTypes {
+ t.Run(idxt.Name(), func(t *testing.T) {
+ indices, _, _ := array.FromJSON(mem, idxt, strings.NewReader(indicesJson)) // same JSON parsed as the index type under test
+ defer indices.Release()
+ dictType := &arrow.DictionaryType{IndexType: idxt, ValueType: arrow.PrimitiveTypes.Int32}
+
+ dictArr := array.NewDictionaryArray(dictType, indices, dict)
+ defer dictArr.Release()
+
+ const offset = 1
+ slicedDictArr := array.NewSlice(dictArr, offset, int64(dictArr.Len())) // drops the first element
+ defer slicedDictArr.Release()
+
+ for i := 0; i < indices.Len(); i++ {
+ assert.EqualValues(t, i64Index.Value(i), dictArr.GetValueIndex(i))
+ if i < slicedDictArr.Len() { // the slice is one element shorter than the full array
+ assert.EqualValues(t, i64Index.Value(i+offset), slicedDictArr.(*array.Dictionary).GetValueIndex(i)) // position i of the slice corresponds to position i+offset of the original
+ }
+ }
+ })
+ }
+}
diff --git a/go/arrow/datatype_fixedwidth.go b/go/arrow/datatype_fixedwidth.go
index 0d4f8347e7..84170e5631 100644
--- a/go/arrow/datatype_fixedwidth.go
+++ b/go/arrow/datatype_fixedwidth.go
@@ -447,6 +447,36 @@ func ConvertTimestampValue(in, out TimeUnit, value int64) int64 {
return 0
}
+// DictionaryType represents categorical or dictionary-encoded in-memory data.
+// It contains a dictionary-encoded value type (any type) and an index type
+// (any integer type).
+type DictionaryType struct {
+ IndexType DataType // type of the indices; BitWidth type-asserts it to FixedWidthDataType
+ ValueType DataType // type of the dictionary values
+ Ordered bool // printed by String and appended to Fingerprint as "1" or "0"
+}
+
+func (*DictionaryType) ID() Type { return DICTIONARY } // always the DICTIONARY type id
+func (*DictionaryType) Name() string { return "dictionary" } // fixed name, also used as the prefix in String
+func (d *DictionaryType) BitWidth() int { return d.IndexType.(FixedWidthDataType).BitWidth() } // bit width of the index type; the unchecked assertion panics if IndexType is not a FixedWidthDataType
+func (d *DictionaryType) String() string { // human-readable form: name, then value type, index type and ordered flag
+ return fmt.Sprintf("%s<values=%s, indices=%s, ordered=%t>",
+ d.Name(), d.ValueType, d.IndexType, d.Ordered) // note the order: values are printed before indices
+}
+func (d *DictionaryType) Fingerprint() string { // combines this type's fingerprint with those of the index and value types plus an ordered flag
+ indexFingerprint := d.IndexType.Fingerprint()
+ valueFingerprint := d.ValueType.Fingerprint()
+ ordered := "1" // "1" when Ordered is set, replaced by "0" just below otherwise
+ if !d.Ordered {
+ ordered = "0"
+ }
+
+ if len(valueFingerprint) > 0 { // the combined fingerprint is only built when the value type has one
+ return typeFingerprint(d) + indexFingerprint + valueFingerprint + ordered
+ }
+ return ordered // NOTE(review): with an empty value fingerprint this returns just "0"/"1", not "" — confirm callers expect that
+}
+
var (
FixedWidthTypes = struct {
Boolean FixedWidthDataType
diff --git a/go/arrow/internal/arrjson/arrjson.go b/go/arrow/internal/arrjson/arrjson.go
index cc09078da6..d818d7c0c5 100644
--- a/go/arrow/internal/arrjson/arrjson.go
+++ b/go/arrow/internal/arrjson/arrjson.go
@@ -32,6 +32,7 @@ import (
"github.com/apache/arrow/go/v8/arrow/bitutil"
"github.com/apache/arrow/go/v8/arrow/decimal128"
"github.com/apache/arrow/go/v8/arrow/float16"
+ "github.com/apache/arrow/go/v8/arrow/internal/dictutils"
"github.com/apache/arrow/go/v8/arrow/ipc"
"github.com/apache/arrow/go/v8/arrow/memory"
)
@@ -91,6 +92,13 @@ type FieldWrapper struct {
Field
}
+type FieldDict struct { // JSON shape of a field's "dictionary" entry
+ ID int `json:"id"`
+ Type json.RawMessage `json:"indexType"` // raw JSON of the index type; MarshalJSON fills it from idxType
+ idxType arrow.DataType `json:"-"` // index type as an arrow.DataType; never serialized, set by UnmarshalJSON and fieldsToJSON
+ Ordered bool `json:"isOrdered"`
+}
+
type Field struct {
Name string `json:"name"`
// the arrowType will get populated during unmarshalling by processing the
@@ -99,11 +107,12 @@ type Field struct {
// leave this as a json RawMessage in order to partially unmarshal as needed
// during marshal/unmarshal time so we can determine what the structure is
// actually expected to be.
- Type json.RawMessage `json:"type"`
- Nullable bool `json:"nullable"`
- Children []FieldWrapper `json:"children"`
- arrowMeta arrow.Metadata `json:"-"`
- Metadata []metaKV `json:"metadata,omitempty"`
+ Type json.RawMessage `json:"type"`
+ Nullable bool `json:"nullable"`
+ Children []FieldWrapper `json:"children"`
+ arrowMeta arrow.Metadata `json:"-"`
+ Dictionary *FieldDict `json:"dictionary,omitempty"`
+ Metadata []metaKV `json:"metadata,omitempty"`
}
type metaKV struct {
@@ -111,27 +120,9 @@ type metaKV struct {
Value string `json:"value"`
}
-func (f FieldWrapper) MarshalJSON() ([]byte, error) {
- // for extension types, add the extension type metadata appropriately
- // and then marshal as normal for the storage type.
- if f.arrowType.ID() == arrow.EXTENSION {
- exType := f.arrowType.(arrow.ExtensionType)
-
- mdkeys := append(f.arrowMeta.Keys(), ipc.ExtensionTypeKeyName)
- mdvals := append(f.arrowMeta.Values(), exType.ExtensionName())
-
- serializedData := exType.Serialize()
- if len(serializedData) > 0 {
- mdkeys = append(mdkeys, ipc.ExtensionMetadataKeyName)
- mdvals = append(mdvals, string(serializedData))
- }
-
- f.arrowMeta = arrow.NewMetadata(mdkeys, mdvals)
- f.arrowType = exType.StorageType()
- }
-
+func typeToJSON(arrowType arrow.DataType) (json.RawMessage, error) {
var typ interface{}
- switch dt := f.arrowType.(type) {
+ switch dt := arrowType.(type) {
case *arrow.NullType:
typ = nameJSON{"null"}
case *arrow.BooleanType:
@@ -221,11 +212,40 @@ func (f FieldWrapper) MarshalJSON() ([]byte, error) {
case *arrow.Decimal128Type:
typ = decimalJSON{"decimal", int(dt.Scale), int(dt.Precision)}
default:
- return nil, fmt.Errorf("unknown arrow.DataType %v", f.arrowType)
+ return nil, fmt.Errorf("unknown arrow.DataType %v", arrowType)
+ }
+
+ return json.Marshal(typ)
+}
+
+func (f FieldWrapper) MarshalJSON() ([]byte, error) {
+ // for extension types, add the extension type metadata appropriately
+ // and then marshal as normal for the storage type.
+ if f.arrowType.ID() == arrow.EXTENSION {
+ exType := f.arrowType.(arrow.ExtensionType)
+
+ mdkeys := append(f.arrowMeta.Keys(), ipc.ExtensionTypeKeyName)
+ mdvals := append(f.arrowMeta.Values(), exType.ExtensionName())
+
+ serializedData := exType.Serialize()
+ if len(serializedData) > 0 {
+ mdkeys = append(mdkeys, ipc.ExtensionMetadataKeyName)
+ mdvals = append(mdvals, string(serializedData))
+ }
+
+ f.arrowMeta = arrow.NewMetadata(mdkeys, mdvals)
+ f.arrowType = exType.StorageType()
}
var err error
- if f.Type, err = json.Marshal(typ); err != nil {
+ if f.arrowType.ID() == arrow.DICTIONARY {
+ f.arrowType = f.arrowType.(*arrow.DictionaryType).ValueType
+ if f.Dictionary.Type, err = typeToJSON(f.Dictionary.idxType); err != nil {
+ return nil, err
+ }
+ }
+
+ if f.Type, err = typeToJSON(f.arrowType); err != nil {
return nil, err
}
@@ -244,190 +264,205 @@ func (f FieldWrapper) MarshalJSON() ([]byte, error) {
return buf.Bytes(), err
}
-func (f *FieldWrapper) UnmarshalJSON(data []byte) error {
- if err := json.Unmarshal(data, &f.Field); err != nil {
- return err
- }
-
+func typeFromJSON(typ json.RawMessage, children []FieldWrapper) (arrowType arrow.DataType, err error) {
tmp := nameJSON{}
- if err := json.Unmarshal(f.Type, &tmp); err != nil {
- return err
+ if err = json.Unmarshal(typ, &tmp); err != nil {
+ return
}
switch tmp.Name {
case "null":
- f.arrowType = arrow.Null
+ arrowType = arrow.Null
case "bool":
- f.arrowType = arrow.FixedWidthTypes.Boolean
+ arrowType = arrow.FixedWidthTypes.Boolean
case "int":
t := bitWidthJSON{}
- if err := json.Unmarshal(f.Type, &t); err != nil {
- return err
+ if err = json.Unmarshal(typ, &t); err != nil {
+ return
}
switch t.Signed {
case true:
switch t.BitWidth {
case 8:
- f.arrowType = arrow.PrimitiveTypes.Int8
+ arrowType = arrow.PrimitiveTypes.Int8
case 16:
- f.arrowType = arrow.PrimitiveTypes.Int16
+ arrowType = arrow.PrimitiveTypes.Int16
case 32:
- f.arrowType = arrow.PrimitiveTypes.Int32
+ arrowType = arrow.PrimitiveTypes.Int32
case 64:
- f.arrowType = arrow.PrimitiveTypes.Int64
+ arrowType = arrow.PrimitiveTypes.Int64
}
default:
switch t.BitWidth {
case 8:
- f.arrowType = arrow.PrimitiveTypes.Uint8
+ arrowType = arrow.PrimitiveTypes.Uint8
case 16:
- f.arrowType = arrow.PrimitiveTypes.Uint16
+ arrowType = arrow.PrimitiveTypes.Uint16
case 32:
- f.arrowType = arrow.PrimitiveTypes.Uint32
+ arrowType = arrow.PrimitiveTypes.Uint32
case 64:
- f.arrowType = arrow.PrimitiveTypes.Uint64
+ arrowType = arrow.PrimitiveTypes.Uint64
}
}
case "floatingpoint":
t := floatJSON{}
- if err := json.Unmarshal(f.Type, &t); err != nil {
- return err
+ if err = json.Unmarshal(typ, &t); err != nil {
+ return
}
switch t.Precision {
case "HALF":
- f.arrowType = arrow.FixedWidthTypes.Float16
+ arrowType = arrow.FixedWidthTypes.Float16
case "SINGLE":
- f.arrowType = arrow.PrimitiveTypes.Float32
+ arrowType = arrow.PrimitiveTypes.Float32
case "DOUBLE":
- f.arrowType = arrow.PrimitiveTypes.Float64
+ arrowType = arrow.PrimitiveTypes.Float64
}
case "binary":
- f.arrowType = arrow.BinaryTypes.Binary
+ arrowType = arrow.BinaryTypes.Binary
case "utf8":
- f.arrowType = arrow.BinaryTypes.String
+ arrowType = arrow.BinaryTypes.String
case "date":
t := unitZoneJSON{}
- if err := json.Unmarshal(f.Type, &t); err != nil {
- return err
+ if err = json.Unmarshal(typ, &t); err != nil {
+ return
}
switch t.Unit {
case "DAY":
- f.arrowType = arrow.FixedWidthTypes.Date32
+ arrowType = arrow.FixedWidthTypes.Date32
case "MILLISECOND":
- f.arrowType = arrow.FixedWidthTypes.Date64
+ arrowType = arrow.FixedWidthTypes.Date64
}
case "time":
t := bitWidthJSON{}
- if err := json.Unmarshal(f.Type, &t); err != nil {
- return err
+ if err = json.Unmarshal(typ, &t); err != nil {
+ return
}
switch t.BitWidth {
case 32:
switch t.Unit {
case "SECOND":
- f.arrowType = arrow.FixedWidthTypes.Time32s
+ arrowType = arrow.FixedWidthTypes.Time32s
case "MILLISECOND":
- f.arrowType = arrow.FixedWidthTypes.Time32ms
+ arrowType = arrow.FixedWidthTypes.Time32ms
}
case 64:
switch t.Unit {
case "MICROSECOND":
- f.arrowType = arrow.FixedWidthTypes.Time64us
+ arrowType = arrow.FixedWidthTypes.Time64us
case "NANOSECOND":
- f.arrowType = arrow.FixedWidthTypes.Time64ns
+ arrowType = arrow.FixedWidthTypes.Time64ns
}
}
case "timestamp":
t := unitZoneJSON{}
- if err := json.Unmarshal(f.Type, &t); err != nil {
- return err
+ if err = json.Unmarshal(typ, &t); err != nil {
+ return
}
- f.arrowType = &arrow.TimestampType{TimeZone: t.TimeZone}
+ arrowType = &arrow.TimestampType{TimeZone: t.TimeZone}
switch t.Unit {
case "SECOND":
- f.arrowType.(*arrow.TimestampType).Unit = arrow.Second
+ arrowType.(*arrow.TimestampType).Unit = arrow.Second
case "MILLISECOND":
- f.arrowType.(*arrow.TimestampType).Unit = arrow.Millisecond
+ arrowType.(*arrow.TimestampType).Unit = arrow.Millisecond
case "MICROSECOND":
- f.arrowType.(*arrow.TimestampType).Unit = arrow.Microsecond
+ arrowType.(*arrow.TimestampType).Unit = arrow.Microsecond
case "NANOSECOND":
- f.arrowType.(*arrow.TimestampType).Unit = arrow.Nanosecond
+ arrowType.(*arrow.TimestampType).Unit = arrow.Nanosecond
}
case "list":
- f.arrowType = arrow.ListOfField(arrow.Field{
- Name: f.Children[0].Name,
- Type: f.Children[0].arrowType,
- Metadata: f.Children[0].arrowMeta,
- Nullable: f.Children[0].Nullable,
+ arrowType = arrow.ListOfField(arrow.Field{
+ Name: children[0].Name,
+ Type: children[0].arrowType,
+ Metadata: children[0].arrowMeta,
+ Nullable: children[0].Nullable,
})
case "map":
t := mapJSON{}
- if err := json.Unmarshal(f.Type, &t); err != nil {
- return err
+ if err = json.Unmarshal(typ, &t); err != nil {
+ return
}
- pairType := f.Children[0].arrowType
- f.arrowType = arrow.MapOf(pairType.(*arrow.StructType).Field(0).Type, pairType.(*arrow.StructType).Field(1).Type)
- f.arrowType.(*arrow.MapType).KeysSorted = t.KeysSorted
+ pairType := children[0].arrowType
+ arrowType = arrow.MapOf(pairType.(*arrow.StructType).Field(0).Type, pairType.(*arrow.StructType).Field(1).Type)
+ arrowType.(*arrow.MapType).KeysSorted = t.KeysSorted
case "struct":
- f.arrowType = arrow.StructOf(fieldsFromJSON(f.Children)...)
+ arrowType = arrow.StructOf(fieldsFromJSON(children)...)
case "fixedsizebinary":
t := byteWidthJSON{}
- if err := json.Unmarshal(f.Type, &t); err != nil {
- return err
+ if err = json.Unmarshal(typ, &t); err != nil {
+ return
}
- f.arrowType = &arrow.FixedSizeBinaryType{ByteWidth: t.ByteWidth}
+ arrowType = &arrow.FixedSizeBinaryType{ByteWidth: t.ByteWidth}
case "fixedsizelist":
t := listSizeJSON{}
- if err := json.Unmarshal(f.Type, &t); err != nil {
- return err
- }
- f.arrowType = arrow.FixedSizeListOfField(t.ListSize, arrow.Field{
- Name: f.Children[0].Name,
- Type: f.Children[0].arrowType,
- Metadata: f.Children[0].arrowMeta,
- Nullable: f.Children[0].Nullable,
+ if err = json.Unmarshal(typ, &t); err != nil {
+ return
+ }
+ arrowType = arrow.FixedSizeListOfField(t.ListSize, arrow.Field{
+ Name: children[0].Name,
+ Type: children[0].arrowType,
+ Metadata: children[0].arrowMeta,
+ Nullable: children[0].Nullable,
})
case "interval":
t := unitZoneJSON{}
- if err := json.Unmarshal(f.Type, &t); err != nil {
- return err
+ if err = json.Unmarshal(typ, &t); err != nil {
+ return
}
switch t.Unit {
case "YEAR_MONTH":
- f.arrowType = arrow.FixedWidthTypes.MonthInterval
+ arrowType = arrow.FixedWidthTypes.MonthInterval
case "DAY_TIME":
- f.arrowType = arrow.FixedWidthTypes.DayTimeInterval
+ arrowType = arrow.FixedWidthTypes.DayTimeInterval
case "MONTH_DAY_NANO":
- f.arrowType = arrow.FixedWidthTypes.MonthDayNanoInterval
+ arrowType = arrow.FixedWidthTypes.MonthDayNanoInterval
}
case "duration":
t := unitZoneJSON{}
- if err := json.Unmarshal(f.Type, &t); err != nil {
- return err
+ if err = json.Unmarshal(typ, &t); err != nil {
+ return
}
switch t.Unit {
case "SECOND":
- f.arrowType = arrow.FixedWidthTypes.Duration_s
+ arrowType = arrow.FixedWidthTypes.Duration_s
case "MILLISECOND":
- f.arrowType = arrow.FixedWidthTypes.Duration_ms
+ arrowType = arrow.FixedWidthTypes.Duration_ms
case "MICROSECOND":
- f.arrowType = arrow.FixedWidthTypes.Duration_us
+ arrowType = arrow.FixedWidthTypes.Duration_us
case "NANOSECOND":
- f.arrowType = arrow.FixedWidthTypes.Duration_ns
+ arrowType = arrow.FixedWidthTypes.Duration_ns
}
case "decimal":
t := decimalJSON{}
- if err := json.Unmarshal(f.Type, &t); err != nil {
- return err
+ if err = json.Unmarshal(typ, &t); err != nil {
+ return
}
- f.arrowType = &arrow.Decimal128Type{Precision: int32(t.Precision), Scale: int32(t.Scale)}
+ arrowType = &arrow.Decimal128Type{Precision: int32(t.Precision), Scale: int32(t.Scale)}
}
- if f.arrowType == nil {
- return fmt.Errorf("unhandled type unmarshalling from json: %s", tmp.Name)
+
+ if arrowType == nil {
+ err = fmt.Errorf("unhandled type unmarshalling from json: %s", tmp.Name)
}
+ return
+}
+func (f *FieldWrapper) UnmarshalJSON(data []byte) error {
var err error
+ if err = json.Unmarshal(data, &f.Field); err != nil {
+ return err
+ }
+
+ if f.arrowType, err = typeFromJSON(f.Type, f.Children); err != nil {
+ return err
+ }
+
+ if f.Dictionary != nil { // a "dictionary" entry means the field is dictionary-encoded: f.arrowType so far is only the value type
+ if f.Dictionary.idxType, err = typeFromJSON(f.Dictionary.Type, nil); err != nil { // index types have no children, hence nil
+ return err
+ }
+ f.arrowType = &arrow.DictionaryType{IndexType: f.Dictionary.idxType, ValueType: f.arrowType, Ordered: f.Dictionary.Ordered} // carry "isOrdered" over so the flag written by fieldsToJSON survives a JSON round trip
+ }
+
if len(f.Metadata) > 0 { // unmarshal the key/value metadata pairs
var (
mdkeys = make([]string, 0, len(f.Metadata))
@@ -538,20 +573,47 @@ type mapJSON struct {
KeysSorted bool `json:"keysSorted,omitempty"`
}
-func schemaToJSON(schema *arrow.Schema) Schema {
+func schemaToJSON(schema *arrow.Schema, mapper *dictutils.Mapper) Schema {
return Schema{
- Fields: fieldsToJSON(schema.Fields()),
+ Fields: fieldsToJSON(schema.Fields(), dictutils.NewFieldPos(), mapper),
arrowMeta: schema.Metadata(),
}
}
-func schemaFromJSON(schema Schema) *arrow.Schema {
- return arrow.NewSchema(fieldsFromJSON(schema.Fields), &schema.arrowMeta)
+func schemaFromJSON(schema Schema, memo *dictutils.Memo) *arrow.Schema {
+ sc := arrow.NewSchema(fieldsFromJSON(schema.Fields), &schema.arrowMeta)
+ dictInfoFromJSONFields(schema.Fields, dictutils.NewFieldPos(), memo)
+ return sc
+}
+
+func dictInfoFromJSONFields(fields []FieldWrapper, pos dictutils.FieldPos, memo *dictutils.Memo) { // applies dictInfoFromJSON to every field in fields
+ for i, f := range fields {
+ dictInfoFromJSON(f, pos.Child(int32(i)), memo) // each field is visited at the child position for its index i under pos
+ }
+}
-func fieldsToJSON(fields []arrow.Field) []FieldWrapper {
+func dictInfoFromJSON(field FieldWrapper, pos dictutils.FieldPos, memo *dictutils.Memo) { // registers a dictionary field's id with its path and value type in memo, then visits its children
+ if field.Dictionary != nil {
+ typ := field.arrowType
+ if typ.ID() == arrow.EXTENSION { // look through an extension type to its storage type
+ typ = typ.(arrow.ExtensionType).StorageType()
+ }
+ valueType := typ.(*arrow.DictionaryType).ValueType // unchecked assertion: panics if the (storage) type is not a dictionary type
+
+ if err := memo.Mapper.AddField(int64(field.Dictionary.ID), pos.Path()); err != nil {
+ panic(err) // registration failures are not returned to the caller; they panic
+ }
+ if err := memo.AddType(int64(field.Dictionary.ID), valueType); err != nil {
+ panic(err)
+ }
+ }
+ dictInfoFromJSONFields(field.Children, pos, memo) // runs for every field, dictionary-encoded or not
+}
+
+func fieldsToJSON(fields []arrow.Field, parentPos dictutils.FieldPos, mapper *dictutils.Mapper) []FieldWrapper {
o := make([]FieldWrapper, len(fields))
for i, f := range fields {
+ pos := parentPos.Child(int32(i))
o[i] = FieldWrapper{Field{
Name: f.Name,
arrowType: f.Type,
@@ -559,15 +621,33 @@ func fieldsToJSON(fields []arrow.Field) []FieldWrapper {
Children: []FieldWrapper{},
arrowMeta: f.Metadata,
}}
- switch dt := f.Type.(type) {
+ typ := f.Type
+ if typ.ID() == arrow.EXTENSION {
+ typ = typ.(arrow.ExtensionType).StorageType()
+ }
+ if typ.ID() == arrow.DICTIONARY {
+ dictType := typ.(*arrow.DictionaryType)
+ typ = dictType.ValueType
+ dictID, err := mapper.GetFieldID(pos.Path())
+ if err != nil {
+ panic(err)
+ }
+ o[i].Dictionary = &FieldDict{
+ idxType: dictType.IndexType,
+ ID: int(dictID),
+ Ordered: dictType.Ordered,
+ }
+ }
+
+ switch dt := typ.(type) {
case *arrow.ListType:
- o[i].Children = fieldsToJSON([]arrow.Field{dt.ElemField()})
+ o[i].Children = fieldsToJSON([]arrow.Field{dt.ElemField()}, pos, mapper)
case *arrow.FixedSizeListType:
- o[i].Children = fieldsToJSON([]arrow.Field{dt.ElemField()})
+ o[i].Children = fieldsToJSON([]arrow.Field{dt.ElemField()}, pos, mapper)
case *arrow.StructType:
- o[i].Children = fieldsToJSON(dt.Fields())
+ o[i].Children = fieldsToJSON(dt.Fields(), pos, mapper)
case *arrow.MapType:
- o[i].Children = fieldsToJSON([]arrow.Field{dt.ValueField()})
+ o[i].Children = fieldsToJSON([]arrow.Field{dt.ValueField()}, pos, mapper)
}
}
return o
@@ -590,27 +670,50 @@ func fieldFromJSON(f Field) arrow.Field {
}
}
+type Dictionary struct { // element type of the dicts slice consumed by dictionariesFromJSON
+ ID int64 `json:"id"` // key passed to memo.Type and memo.Add
+ Data Record `json:"data"` // dictionariesFromJSON reads only Data.Columns[0]
+}
+
+func dictionariesFromJSON(mem memory.Allocator, dicts []Dictionary, memo *dictutils.Memo) { // builds each JSON dictionary as array data of its memoized value type and adds it to memo
+ for _, d := range dicts {
+ valueType, exists := memo.Type(d.ID) // the value type must already be registered under this id
+ if !exists {
+ panic(fmt.Errorf("arrow/json: no corresponding dictionary memo for id=%d", d.ID))
+ }
+
+ dict := arrayFromJSON(mem, valueType, d.Data.Columns[0]) // only the first column is read; an empty Columns slice panics on the index
+ defer dict.Release() // deferred inside the loop, so all releases run when the function returns
+ memo.Add(d.ID, dict) // presumably memo takes its own reference, since dict is released on return — TODO confirm
+ }
+}
+
type Record struct {
Count int64 `json:"count"`
Columns []Array `json:"columns"`
}
-func recordsFromJSON(mem memory.Allocator, schema *arrow.Schema, recs []Record) []arrow.Record {
+func recordsFromJSON(mem memory.Allocator, schema *arrow.Schema, recs []Record, memo *dictutils.Memo) []arrow.Record {
vs := make([]arrow.Record, len(recs))
for i, rec := range recs {
- vs[i] = recordFromJSON(mem, schema, rec)
+ vs[i] = recordFromJSON(mem, schema, rec, memo)
}
return vs
}
-func recordFromJSON(mem memory.Allocator, schema *arrow.Schema, rec Record) arrow.Record {
+func recordFromJSON(mem memory.Allocator, schema *arrow.Schema, rec Record, memo *dictutils.Memo) arrow.Record {
arrs := arraysFromJSON(mem, schema, rec.Columns)
- defer func() {
- for _, arr := range arrs {
- arr.Release()
- }
- }()
- return array.NewRecord(schema, arrs, int64(rec.Count))
+ if err := dictutils.ResolveDictionaries(memo, arrs, dictutils.NewFieldPos(), mem); err != nil {
+ panic(err)
+ }
+
+ cols := make([]arrow.Array, len(arrs))
+ for i, d := range arrs {
+ cols[i] = array.MakeFromData(d)
+ defer d.Release()
+ defer cols[i].Release()
+ }
+ return array.NewRecord(schema, cols, int64(rec.Count))
}
func recordToJSON(rec arrow.Record) Record {
@@ -629,8 +732,8 @@ type Array struct {
Children []Array `json:"children,omitempty"`
}
-func arraysFromJSON(mem memory.Allocator, schema *arrow.Schema, arrs []Array) []arrow.Array {
- o := make([]arrow.Array, len(arrs))
+func arraysFromJSON(mem memory.Allocator, schema *arrow.Schema, arrs []Array) []arrow.ArrayData {
+ o := make([]arrow.ArrayData, len(arrs))
for i, v := range arrs {
o[i] = arrayFromJSON(mem, schema.Field(i).Type, v)
}
@@ -655,10 +758,17 @@ func validsToBitmap(valids []bool, mem memory.Allocator) *memory.Buffer {
return buf
}
-func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Array {
+func returnNewArrayData(bldr array.Builder) arrow.ArrayData { // builds the array from bldr and returns its underlying data
+ arr := bldr.NewArray()
+ defer arr.Release() // the array wrapper itself is released when this function returns
+ arr.Data().Retain() // retained before the deferred release runs, so the returned data keeps a reference of its own
+ return arr.Data()
+}
+
+func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.ArrayData {
switch dt := dt.(type) {
case *arrow.NullType:
- return array.NewNull(arr.Count)
+ return array.NewNull(arr.Count).Data()
case *arrow.BooleanType:
bldr := array.NewBooleanBuilder(mem)
@@ -666,7 +776,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := boolsFromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Int8Type:
bldr := array.NewInt8Builder(mem)
@@ -674,7 +784,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := i8FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Int16Type:
bldr := array.NewInt16Builder(mem)
@@ -682,7 +792,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := i16FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Int32Type:
bldr := array.NewInt32Builder(mem)
@@ -690,7 +800,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := i32FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Int64Type:
bldr := array.NewInt64Builder(mem)
@@ -698,7 +808,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := i64FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Uint8Type:
bldr := array.NewUint8Builder(mem)
@@ -706,7 +816,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := u8FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Uint16Type:
bldr := array.NewUint16Builder(mem)
@@ -714,7 +824,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := u16FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Uint32Type:
bldr := array.NewUint32Builder(mem)
@@ -722,7 +832,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := u32FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Uint64Type:
bldr := array.NewUint64Builder(mem)
@@ -730,7 +840,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := u64FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Float16Type:
bldr := array.NewFloat16Builder(mem)
@@ -738,7 +848,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := f16FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Float32Type:
bldr := array.NewFloat32Builder(mem)
@@ -746,7 +856,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := f32FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Float64Type:
bldr := array.NewFloat64Builder(mem)
@@ -754,7 +864,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := f64FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.StringType:
bldr := array.NewStringBuilder(mem)
@@ -762,7 +872,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := strFromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.BinaryType:
bldr := array.NewBinaryBuilder(mem, dt)
@@ -770,7 +880,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := bytesFromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.ListType:
valids := validsFromJSON(arr.Valids)
@@ -781,11 +891,9 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
defer bitmap.Release()
nulls := arr.Count - bitutil.CountSetBits(bitmap.Bytes(), 0, arr.Count)
- data := array.NewData(dt, arr.Count, []*memory.Buffer{bitmap,
+ return array.NewData(dt, arr.Count, []*memory.Buffer{bitmap,
memory.NewBufferBytes(arrow.Int32Traits.CastToBytes(arr.Offset))},
- []arrow.ArrayData{elems.Data()}, nulls, 0)
- defer data.Release()
- return array.NewListData(data)
+ []arrow.ArrayData{elems}, nulls, 0)
case *arrow.FixedSizeListType:
valids := validsFromJSON(arr.Valids)
@@ -796,9 +904,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
defer bitmap.Release()
nulls := arr.Count - bitutil.CountSetBits(bitmap.Bytes(), 0, arr.Count)
- data := array.NewData(dt, arr.Count, []*memory.Buffer{bitmap}, []arrow.ArrayData{elems.Data()}, nulls, 0)
- defer data.Release()
- return array.NewFixedSizeListData(data)
+ return array.NewData(dt, arr.Count, []*memory.Buffer{bitmap}, []arrow.ArrayData{elems}, nulls, 0)
case *arrow.StructType:
valids := validsFromJSON(arr.Valids)
@@ -811,13 +917,10 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
for i := range fields {
child := arrayFromJSON(mem, dt.Field(i).Type, arr.Children[i])
defer child.Release()
- fields[i] = child.Data()
+ fields[i] = child
}
- data := array.NewData(dt, arr.Count, []*memory.Buffer{bitmap}, fields, nulls, 0)
- defer data.Release()
-
- return array.NewStructData(data)
+ return array.NewData(dt, arr.Count, []*memory.Buffer{bitmap}, fields, nulls, 0)
case *arrow.FixedSizeBinaryType:
bldr := array.NewFixedSizeBinaryBuilder(mem, dt)
@@ -836,7 +939,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
}
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.MapType:
valids := validsFromJSON(arr.Valids)
@@ -847,11 +950,9 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
defer bitmap.Release()
nulls := arr.Count - bitutil.CountSetBits(bitmap.Bytes(), 0, arr.Count)
- data := array.NewData(dt, arr.Count, []*memory.Buffer{bitmap,
+ return array.NewData(dt, arr.Count, []*memory.Buffer{bitmap,
memory.NewBufferBytes(arrow.Int32Traits.CastToBytes(arr.Offset))},
- []arrow.ArrayData{elems.Data()}, nulls, 0)
- defer data.Release()
- return array.NewMapData(data)
+ []arrow.ArrayData{elems}, nulls, 0)
case *arrow.Date32Type:
bldr := array.NewDate32Builder(mem)
@@ -859,7 +960,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := date32FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Date64Type:
bldr := array.NewDate64Builder(mem)
@@ -867,7 +968,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := date64FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Time32Type:
bldr := array.NewTime32Builder(mem, dt)
@@ -875,7 +976,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := time32FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Time64Type:
bldr := array.NewTime64Builder(mem, dt)
@@ -883,7 +984,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := time64FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.TimestampType:
bldr := array.NewTimestampBuilder(mem, dt)
@@ -891,7 +992,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := timestampFromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.MonthIntervalType:
bldr := array.NewMonthIntervalBuilder(mem)
@@ -899,7 +1000,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := monthintervalFromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.DayTimeIntervalType:
bldr := array.NewDayTimeIntervalBuilder(mem)
@@ -907,7 +1008,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := daytimeintervalFromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.MonthDayNanoIntervalType:
bldr := array.NewMonthDayNanoIntervalBuilder(mem)
@@ -915,7 +1016,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := monthDayNanointervalFromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.DurationType:
bldr := array.NewDurationBuilder(mem, dt)
@@ -923,7 +1024,7 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := durationFromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case *arrow.Decimal128Type:
bldr := array.NewDecimal128Builder(mem, dt)
@@ -931,17 +1032,21 @@ func arrayFromJSON(mem memory.Allocator, dt arrow.DataType, arr Array) arrow.Arr
data := decimal128FromJSON(arr.Data)
valids := validsFromJSON(arr.Valids)
bldr.AppendValues(data, valids)
- return bldr.NewArray()
+ return returnNewArrayData(bldr)
case arrow.ExtensionType:
storage := arrayFromJSON(mem, dt.StorageType(), arr)
defer storage.Release()
- return array.NewExtensionArrayWithStorage(dt, storage)
+ return array.NewData(dt, storage.Len(), storage.Buffers(), storage.Children(), storage.NullN(), storage.Offset())
+
+ case *arrow.DictionaryType:
+ indices := arrayFromJSON(mem, dt.IndexType, arr)
+ defer indices.Release()
+ return array.NewData(dt, indices.Len(), indices.Buffers(), indices.Children(), indices.NullN(), indices.Offset())
default:
panic(fmt.Errorf("unknown data type %v %T", dt, dt))
}
- panic("impossible")
}
func arrayToJSON(field arrow.Field, arr arrow.Array) Array {
@@ -1209,10 +1314,12 @@ func arrayToJSON(field arrow.Field, arr arrow.Array) Array {
case array.ExtensionArray:
return arrayToJSON(field, arr.Storage())
+ case *array.Dictionary:
+ return arrayToJSON(field, arr.Indices())
+
default:
panic(fmt.Errorf("unknown array type %T", arr))
}
- panic("impossible")
}
func validsFromJSON(vs []int) []bool {
diff --git a/go/arrow/internal/arrjson/arrjson_test.go b/go/arrow/internal/arrjson/arrjson_test.go
index 3fc8dedb0a..ced0f48f0c 100644
--- a/go/arrow/internal/arrjson/arrjson_test.go
+++ b/go/arrow/internal/arrjson/arrjson_test.go
@@ -44,6 +44,7 @@ func TestReadWrite(t *testing.T) {
wantJSONs["decimal128"] = makeDecimal128sWantJSONs()
wantJSONs["maps"] = makeMapsWantJSONs()
wantJSONs["extension"] = makeExtensionsWantJSONs()
+ wantJSONs["dictionary"] = makeDictionaryWantJSONs()
tempDir, err := ioutil.TempDir("", "go-arrow-read-write-")
if err != nil {
@@ -3990,6 +3991,431 @@ func makeMapsWantJSONs() string {
}`
}
+func makeDictionaryWantJSONs() string {
+ return `{
+ "schema": {
+ "fields": [
+ {
+ "name": "dict0",
+ "type": {
+ "name": "utf8"
+ },
+ "nullable": true,
+ "children": [],
+ "dictionary": {
+ "id": 0,
+ "indexType": {
+ "name": "int",
+ "isSigned": true,
+ "bitWidth": 8
+ },
+ "isOrdered": false
+ }
+ },
+ {
+ "name": "dict1",
+ "type": {
+ "name": "utf8"
+ },
+ "nullable": true,
+ "children": [],
+ "dictionary": {
+ "id": 1,
+ "indexType": {
+ "name": "int",
+ "isSigned": true,
+ "bitWidth": 32
+ },
+ "isOrdered": false
+ }
+ },
+ {
+ "name": "dict2",
+ "type": {
+ "name": "int",
+ "isSigned": true,
+ "bitWidth": 64
+ },
+ "nullable": true,
+ "children": [],
+ "dictionary": {
+ "id": 2,
+ "indexType": {
+ "name": "int",
+ "isSigned": true,
+ "bitWidth": 16
+ },
+ "isOrdered": false
+ }
+ }
+ ]
+ },
+ "dictionaries": [
+ {
+ "id": 0,
+ "data": {
+ "count": 10,
+ "columns": [
+ {
+ "name": "DICT0",
+ "count": 10,
+ "VALIDITY": [
+ 1,
+ 1,
+ 0,
+ 0,
+ 0,
+ 1,
+ 1,
+ 0,
+ 1,
+ 0
+ ],
+ "OFFSET": [
+ 0,
+ 7,
+ 16,
+ 16,
+ 16,
+ 16,
+ 28,
+ 39,
+ 39,
+ 46,
+ 46
+ ],
+ "DATA": [
+ "gen3wjf",
+ "bbg61\u00b5\u00b0",
+ "",
+ "",
+ "",
+ "\u00f4\u00f42n\u20acm\u00a3",
+ "jb2b\u20acd\u20ac",
+ "",
+ "jfjddrg",
+ ""
+ ]
+ }
+ ]
+ }
+ },
+ {
+ "id": 1,
+ "data": {
+ "count": 5,
+ "columns": [
+ {
+ "name": "DICT1",
+ "count": 5,
+ "VALIDITY": [
+ 1,
+ 1,
+ 1,
+ 1,
+ 1
+ ],
+ "OFFSET": [
+ 0,
+ 8,
+ 18,
+ 27,
+ 35,
+ 45
+ ],
+ "DATA": [
+ "\u00c2arcall",
+ "\u77e23b\u00b0eif",
+ "i3ak\u00b0k\u00b5",
+ "gp16\u00a3nd",
+ "f4\u00b01e\u00c2\u00b0"
+ ]
+ }
+ ]
+ }
+ },
+ {
+ "id": 2,
+ "data": {
+ "count": 50,
+ "columns": [
+ {
+ "name": "DICT2",
+ "count": 50,
+ "VALIDITY": [
+ 1,
+ 0,
+ 0,
+ 1,
+ 1,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 1,
+ 1,
+ 0,
+ 0,
+ 1,
+ 1,
+ 0,
+ 1,
+ 1,
+ 1,
+ 1,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 1,
+ 0,
+ 1,
+ 1,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 1,
+ 0,
+ 1,
+ 1,
+ 1,
+ 1,
+ 0,
+ 0,
+ 1,
+ 1,
+ 0
+ ],
+ "DATA": [
+ "-2147483648",
+ "2147483647",
+ "97251241",
+ "-315526314",
+ "-256834552",
+ "-1159355470",
+ "800976983",
+ "-1728247486",
+ "-1784101814",
+ "1320684343",
+ "-788965748",
+ "1298782506",
+ "1971840342",
+ "686564052",
+ "-115364825",
+ "1787500433",
+ "-123446338",
+ "-1973712113",
+ "870684092",
+ "-994630427",
+ "-1826738974",
+ "461928552",
+ "1374967188",
+ "1317234669",
+ "1129789963",
+ "312195995",
+ "1535930156",
+ "-1610317326",
+ "-721673697",
+ "1443186644",
+ "-643456149",
+ "1132307434",
+ "1240578589",
+ "379611602",
+ "2011416968",
+ "165842874",
+ "-570054451",
+ "893435720",
+ "835998817",
+ "1223423131",
+ "-1677568310",
+ "-230900360",
+ "-229961726",
+ "2113303164",
+ "201112068",
+ "452691328",
+ "-1980985397",
+ "675701869",
+ "-1802109191",
+ "-669843831"
+ ]
+ }
+ ]
+ }
+ }
+ ],
+ "batches": [
+ {
+ "count": 7,
+ "columns": [
+ {
+ "name": "dict0",
+ "count": 7,
+ "VALIDITY": [
+ 1,
+ 1,
+ 0,
+ 1,
+ 0,
+ 1,
+ 1
+ ],
+ "DATA": [
+ 7,
+ 6,
+ 3,
+ 1,
+ 2,
+ 9,
+ 1
+ ]
+ },
+ {
+ "name": "dict1",
+ "count": 7,
+ "VALIDITY": [
+ 1,
+ 1,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0
+ ],
+ "DATA": [
+ 0,
+ 0,
+ 3,
+ 3,
+ 4,
+ 2,
+ 3
+ ]
+ },
+ {
+ "name": "dict2",
+ "count": 7,
+ "VALIDITY": [
+ 0,
+ 1,
+ 0,
+ 1,
+ 1,
+ 0,
+ 1
+ ],
+ "DATA": [
+ 3,
+ 11,
+ 0,
+ 33,
+ 5,
+ 21,
+ 9
+ ]
+ }
+ ]
+ },
+ {
+ "count": 10,
+ "columns": [
+ {
+ "name": "dict0",
+ "count": 10,
+ "VALIDITY": [
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 1,
+ 0,
+ 1,
+ 1
+ ],
+ "DATA": [
+ 9,
+ 4,
+ 3,
+ 9,
+ 5,
+ 7,
+ 9,
+ 4,
+ 0,
+ 9
+ ]
+ },
+ {
+ "name": "dict1",
+ "count": 10,
+ "VALIDITY": [
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 1,
+ 1,
+ 1,
+ 0
+ ],
+ "DATA": [
+ 1,
+ 2,
+ 4,
+ 3,
+ 3,
+ 3,
+ 2,
+ 4,
+ 4,
+ 4
+ ]
+ },
+ {
+ "name": "dict2",
+ "count": 10,
+ "VALIDITY": [
+ 0,
+ 0,
+ 1,
+ 1,
+ 1,
+ 1,
+ 0,
+ 0,
+ 1,
+ 0
+ ],
+ "DATA": [
+ 24,
+ 26,
+ 39,
+ 4,
+ 23,
+ 23,
+ 6,
+ 28,
+ 9,
+ 49
+ ]
+ }
+ ]
+ }
+ ]
+ }`
+}
+
func makeExtensionsWantJSONs() string {
return `{
"schema": {
@@ -4112,7 +4538,27 @@ func makeExtensionsWantJSONs() string {
"name": "struct"
},
"nullable": true,
- "children": [],
+ "children": [
+ {
+ "name": "a",
+ "type": {
+ "name": "int",
+ "isSigned": true,
+ "bitWidth": 64
+ },
+ "nullable": false,
+ "children": []
+ },
+ {
+ "name": "b",
+ "type": {
+ "name": "floatingpoint",
+ "precision": "DOUBLE"
+ },
+ "nullable": false,
+ "children": []
+ }
+ ],
"metadata": [
{
"key": "k1",
diff --git a/go/arrow/internal/arrjson/reader.go b/go/arrow/internal/arrjson/reader.go
index 470b02099e..7b371ce45d 100644
--- a/go/arrow/internal/arrjson/reader.go
+++ b/go/arrow/internal/arrjson/reader.go
@@ -24,6 +24,7 @@ import (
"github.com/apache/arrow/go/v8/arrow"
"github.com/apache/arrow/go/v8/arrow/arrio"
"github.com/apache/arrow/go/v8/arrow/internal/debug"
+ "github.com/apache/arrow/go/v8/arrow/internal/dictutils"
)
type Reader struct {
@@ -31,6 +32,7 @@ type Reader struct {
schema *arrow.Schema
recs []arrow.Record
+ memo *dictutils.Memo
irec int // current record index. used for the arrio.Reader interface.
}
@@ -49,11 +51,14 @@ func NewReader(r io.Reader, opts ...Option) (*Reader, error) {
opt(cfg)
}
- schema := schemaFromJSON(raw.Schema)
+ memo := dictutils.NewMemo()
+ schema := schemaFromJSON(raw.Schema, &memo)
+ dictionariesFromJSON(cfg.alloc, raw.Dictionaries, &memo)
rr := &Reader{
refs: 1,
schema: schema,
- recs: recordsFromJSON(cfg.alloc, schema, raw.Records),
+ recs: recordsFromJSON(cfg.alloc, schema, raw.Records, &memo),
+ memo: &memo,
}
return rr, nil
}
diff --git a/go/arrow/internal/arrjson/writer.go b/go/arrow/internal/arrjson/writer.go
index 9bd654af9e..1ac3738c9a 100644
--- a/go/arrow/internal/arrjson/writer.go
+++ b/go/arrow/internal/arrjson/writer.go
@@ -18,10 +18,13 @@ package arrjson
import (
"encoding/json"
+ "fmt"
"io"
"github.com/apache/arrow/go/v8/arrow"
+ "github.com/apache/arrow/go/v8/arrow/array"
"github.com/apache/arrow/go/v8/arrow/arrio"
+ "github.com/apache/arrow/go/v8/arrow/internal/dictutils"
)
const (
@@ -31,27 +34,49 @@ const (
)
type rawJSON struct {
- Schema Schema `json:"schema"`
- Records []Record `json:"batches"`
+ Schema Schema `json:"schema"`
+ Records []Record `json:"batches"`
+ Dictionaries []Dictionary `json:"dictionaries,omitempty"`
}
type Writer struct {
w io.Writer
- nrecs int64
- raw rawJSON
+ nrecs int64
+ raw rawJSON
+ mapper dictutils.Mapper
}
func NewWriter(w io.Writer, schema *arrow.Schema) (*Writer, error) {
ww := &Writer{
w: w,
}
- ww.raw.Schema = schemaToJSON(schema)
+ ww.mapper.ImportSchema(schema)
+ ww.raw.Schema = schemaToJSON(schema, &ww.mapper)
ww.raw.Records = make([]Record, 0)
return ww, nil
}
func (w *Writer) Write(rec arrow.Record) error {
+ if w.nrecs == 0 {
+ pairs, err := dictutils.CollectDictionaries(rec, &w.mapper)
+ if err != nil {
+ return err
+ }
+
+ if len(pairs) > 0 {
+ w.raw.Dictionaries = make([]Dictionary, 0, len(pairs))
+ }
+
+ for _, p := range pairs {
+ defer p.Dict.Release()
+ sc := arrow.NewSchema([]arrow.Field{{Name: fmt.Sprintf("DICT%d", p.ID), Type: p.Dict.DataType(), Nullable: true}}, nil)
+ dummy := array.NewRecord(sc, []arrow.Array{p.Dict}, int64(p.Dict.Len()))
+ defer dummy.Release()
+ w.raw.Dictionaries = append(w.raw.Dictionaries, Dictionary{ID: p.ID, Data: recordToJSON(dummy)})
+ }
+ }
+
w.raw.Records = append(w.raw.Records, recordToJSON(rec))
w.nrecs++
return nil
diff --git a/go/arrow/internal/dictutils/dict.go b/go/arrow/internal/dictutils/dict.go
new file mode 100644
index 0000000000..99ec104765
--- /dev/null
+++ b/go/arrow/internal/dictutils/dict.go
@@ -0,0 +1,399 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dictutils
+
+import (
+ "errors"
+ "fmt"
+ "hash/maphash"
+
+ "github.com/apache/arrow/go/v8/arrow"
+ "github.com/apache/arrow/go/v8/arrow/array"
+ "github.com/apache/arrow/go/v8/arrow/memory"
+ "golang.org/x/xerrors"
+)
+
+// Kind tags a dictionary with how it relates to what is already known for
+// its id.
+// NOTE(review): the exact meaning of each value is inferred from the constant
+// names (new / delta / replacement) — confirm against the code that produces
+// Kind values.
+type Kind int8
+
+const (
+ // KindNew is the zero value of Kind.
+ KindNew Kind = iota
+ KindDelta
+ KindReplacement
+)
+
+// FieldPos identifies a field by its position in a tree of fields. Every
+// value links back to its parent, so the path from the root can be rebuilt
+// by following the parent pointers.
+type FieldPos struct {
+	parent *FieldPos // enclosing position; nil for the root
+	index  int32     // ordinal of this field within its parent; -1 for the root
+	depth  int32     // number of ancestors; 0 for the root
+}
+
+// NewFieldPos returns the root position: no parent, index -1, depth 0.
+func NewFieldPos() FieldPos {
+	return FieldPos{index: -1}
+}
+
+// Child returns the position of the index-th child of f. The result stores a
+// pointer to f, so f has to stay valid for as long as the child is in use.
+func (f *FieldPos) Child(index int32) FieldPos {
+	child := FieldPos{parent: f}
+	child.index = index
+	child.depth = f.depth + 1
+	return child
+}
+
+// Path returns the child indices leading from the root down to f, outermost
+// first. The root itself yields an empty path.
+func (f *FieldPos) Path() []int32 {
+	out := make([]int32, f.depth)
+	node := f
+	// fill from the back: f's own index goes last, its parent's before it, ...
+	for slot := len(out) - 1; slot >= 0; slot-- {
+		out[slot] = node.index
+		node = node.parent
+	}
+	return out
+}
+
+// Mapper associates field paths (see FieldPos.Path) with dictionary ids.
+// Paths are not stored as such: each path is hashed to a uint64 and that
+// hash is used as the map key.
+type Mapper struct {
+ // pathToID maps the hash of a field path to its dictionary id.
+ pathToID map[uint64]int64
+ // hasher computes the path hashes; each method that writes to it resets it
+ // afterwards so that it can be reused by the next call.
+ hasher maphash.Hash
+}
+
+// NumDicts reports how many distinct dictionary ids are currently mapped.
+// Several field paths may point at the same id, so the result can be smaller
+// than the number of mapped paths.
+func (d *Mapper) NumDicts() int {
+	seen := make(map[int64]struct{}, len(d.pathToID))
+	for _, dictID := range d.pathToID {
+		seen[dictID] = struct{}{}
+	}
+	return len(seen)
+}
+
+// AddField records that the field at fieldPath uses dictionary id. If that
+// path has already been mapped, the mapping is left untouched and an error
+// is returned.
+func (d *Mapper) AddField(id int64, fieldPath []int32) error {
+	d.hasher.Write(arrow.Int32Traits.CastToBytes(fieldPath))
+	key := d.hasher.Sum64()
+	// the hasher is shared between calls; clear it now that the key is known
+	d.hasher.Reset()
+
+	if _, exists := d.pathToID[key]; exists {
+		return errors.New("field already mapped to id")
+	}
+	d.pathToID[key] = id
+	return nil
+}
+
+// GetFieldID looks up the dictionary id mapped to fieldPath. When the path is
+// unknown it returns -1 together with an error.
+func (d *Mapper) GetFieldID(fieldPath []int32) (int64, error) {
+	d.hasher.Write(arrow.Int32Traits.CastToBytes(fieldPath))
+	key := d.hasher.Sum64()
+	// leave the shared hasher clean for the next caller
+	d.hasher.Reset()
+
+	if id, found := d.pathToID[key]; found {
+		return id, nil
+	}
+	return -1, errors.New("arrow/ipc: dictionary field not found")
+}
+
+// NumFields reports how many field paths have a dictionary id mapped to them.
+func (d *Mapper) NumFields() int {
+	count := len(d.pathToID)
+	return count
+}
+
+// InsertPath maps the field at pos to the next sequential dictionary id,
+// which is the number of paths mapped so far. An entry that already exists
+// for the same path is overwritten.
+func (d *Mapper) InsertPath(pos FieldPos) {
+	nextID := int64(len(d.pathToID))
+
+	d.hasher.Write(arrow.Int32Traits.CastToBytes(pos.Path()))
+	key := d.hasher.Sum64()
+	d.pathToID[key] = nextID
+
+	// the hasher is reused by later calls and must start empty
+	d.hasher.Reset()
+}
+
+// ImportField registers field (located at pos) and everything nested below
+// it. Extension types are looked through to their storage type. A dictionary
+// field gets an id via InsertPath, after which the fields of its value type
+// are imported; for any other nested type the fields of the type itself are
+// imported.
+func (d *Mapper) ImportField(pos FieldPos, field *arrow.Field) {
+	typ := field.Type
+	if typ.ID() == arrow.EXTENSION {
+		typ = typ.(arrow.ExtensionType).StorageType()
+	}
+
+	// descendInto is the type whose child fields get imported: the value type
+	// for a dictionary, the (storage) type itself otherwise.
+	descendInto := typ
+	if typ.ID() == arrow.DICTIONARY {
+		d.InsertPath(pos)
+		descendInto = typ.(*arrow.DictionaryType).ValueType
+	}
+
+	if nested, isNested := descendInto.(arrow.NestedType); isNested {
+		d.ImportFields(pos, nested.Fields())
+	}
+}
+
+// ImportFields imports every entry of fields as a child of pos; the i-th
+// entry is placed at pos.Child(i).
+func (d *Mapper) ImportFields(pos FieldPos, fields []arrow.Field) {
+	for idx := 0; idx < len(fields); idx++ {
+		d.ImportField(pos.Child(int32(idx)), &fields[idx])
+	}
+}
+
+// ImportSchema discards any existing mapping and rebuilds it from the fields
+// of schema, starting at the root position.
+func (d *Mapper) ImportSchema(schema *arrow.Schema) {
+	// start from an empty mapping; whatever was registered before is dropped
+	d.pathToID = map[uint64]int64{}
+
+	root := NewFieldPos()
+	d.ImportFields(root, schema.Fields())
+}
+
+// hasUnresolvedNestedDict reports whether data, its dictionary, or any of
+// its children (recursively) is dictionary-typed but has no dictionary
+// attached. data must be an *array.Data; anything else makes the type
+// assertion panic.
+func hasUnresolvedNestedDict(data arrow.ArrayData) bool {
+	d := data.(*array.Data)
+
+	isDict := d.DataType().ID() == arrow.DICTIONARY
+	// for a dictionary: unresolved if nothing is attached, or if what is
+	// attached is itself unresolved somewhere below
+	if isDict && (d.Dictionary().(*array.Data) == nil || hasUnresolvedNestedDict(d.Dictionary())) {
+		return true
+	}
+
+	children := d.Children()
+	for idx := range children {
+		if hasUnresolvedNestedDict(children[idx]) {
+			return true
+		}
+	}
+	return false
+}
+
+// dictpair couples a dictionary array with the id it is mapped to.
+type dictpair struct {
+ // ID is the dictionary id obtained from the Mapper.
+ ID int64
+ // Dict is the dictionary array. dictCollector.visit retains it before
+ // storing it here, so the receiver of the pair owns one reference.
+ Dict arrow.Array
+}
+
+// dictCollector walks the columns of a record and gathers the dictionaries
+// it finds, paired with their ids.
+type dictCollector struct {
+ // dictionaries accumulates the result; collect re-creates it on each call.
+ dictionaries []dictpair
+ // mapper translates a field path into a dictionary id.
+ mapper *Mapper
+}
+
+// visitChildren wraps every child of arr's data in an array and visits it at
+// the matching child position of pos, stopping at the first error. The typ
+// argument is not used by the body.
+func (d *dictCollector) visitChildren(pos FieldPos, typ arrow.DataType, arr arrow.Array) error {
+	children := arr.Data().Children()
+	for idx := 0; idx < len(children); idx++ {
+		child := array.MakeFromData(children[idx])
+		// released when visitChildren returns, not at the end of the iteration
+		defer child.Release()
+
+		err := d.visit(pos.Child(int32(idx)), child)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// visit inspects arr, located at pos, and records every dictionary reachable
+// from it. Extension arrays are looked through to their storage. For a
+// dictionary array the dictionary's own children are visited first, so that
+// nested dictionaries end up in the output before their parent; the
+// dictionary is then retained and appended together with its id. Any other
+// array just has its children visited.
+//
+// The first error encountered is returned. That includes errors raised while
+// walking the children of a dictionary, which used to be dropped silently.
+func (d *dictCollector) visit(pos FieldPos, arr arrow.Array) error {
+	dt := arr.DataType()
+	if dt.ID() == arrow.EXTENSION {
+		dt = dt.(arrow.ExtensionType).StorageType()
+		arr = arr.(array.ExtensionArray).Storage()
+	}
+
+	if dt.ID() == arrow.DICTIONARY {
+		dictarr := arr.(*array.Dictionary)
+		dict := dictarr.Dictionary()
+
+		// traverse the dictionary to first gather any nested dictionaries
+		// so they appear in the output before their respective parents
+		dictType := dt.(*arrow.DictionaryType)
+		if err := d.visitChildren(pos, dictType.ValueType, dict); err != nil {
+			return err
+		}
+
+		id, err := d.mapper.GetFieldID(pos.Path())
+		if err != nil {
+			return err
+		}
+		// the collected pair holds its own reference to the dictionary
+		dict.Retain()
+		d.dictionaries = append(d.dictionaries, dictpair{ID: id, Dict: dict})
+		return nil
+	}
+	return d.visitChildren(pos, dt, arr)
+}
+
+func (d *dictCollector) collect(batch arrow.Record) error {
+ var (
+ pos = NewFieldPos()
+ schema = batch.Schema()
+ )
+ d.dictionaries = make([]dictpair, 0, d.mapper.NumFields())
+ for i := range schema.Fields() {
+ if err := d.visit(pos.Child(int32(i)), batch.Column(i)); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// dictMap maps a dictionary id to the data stored for it. When the slice has
+// more than one element, Memo.reify concatenates them into a single
+// dictionary.
+type dictMap map[int64][]arrow.ArrayData
+// dictTypeMap maps a dictionary id to the data type recorded for it.
+type dictTypeMap map[int64]arrow.DataType
+
+// Memo keeps track of dictionaries by id, together with the field-path
+// mapping and the data type recorded for each id.
+type Memo struct {
+ // Mapper resolves a field path to its dictionary id.
+ Mapper Mapper
+ // dict2id is the reverse lookup from dictionary data to id; it is filled
+ // by Add and consulted by HasDict.
+ dict2id map[arrow.ArrayData]int64
+
+ // id2type holds the type registered for each id through AddType.
+ id2type dictTypeMap
+ id2dict dictMap // map of dictionary ID to dictionary array
+}
+
+func NewMemo() Memo {
+ return Memo{
+ dict2id: make(map[arrow.ArrayData]int64),
+ id2dict: make(dictMap),
+ id2type: make(dictTypeMap),
+ Mapper: Mapper{
+ pathToID: make(map[uint64]int64),
+ },
+ }
+}
+
+// Len reports the number of dictionary ids that currently have data stored.
+func (memo *Memo) Len() int {
+	return len(memo.id2dict)
+}
+
+// Clear releases every stored dictionary and empties both the id-to-data
+// map and the matching reverse-lookup entries. The recorded types (id2type)
+// and the Mapper are left as they are.
+func (memo *Memo) Clear() {
+	for id, held := range memo.id2dict {
+		for idx := range held {
+			data := held[idx]
+			delete(memo.dict2id, data)
+			data.Release()
+		}
+		// removing the entry being visited is permitted during a range loop
+		delete(memo.id2dict, id)
+	}
+}
+
+// reify returns the single dictionary for id. When only one piece is stored
+// it is returned as is. When several pieces are stored (a dictionary followed
+// by deltas) they are concatenated using mem, the pieces are released, and
+// the combined dictionary replaces them in the memo.
+//
+// An error is returned when id is unknown, when a piece still contains an
+// unresolved nested dictionary, or when concatenation fails. In every error
+// case the memo is left exactly as it was: the stored pieces are released
+// only after concatenation has succeeded. Releasing them on the error paths
+// as well would drop the memo's references while id2dict still lists them.
+func (memo *Memo) reify(id int64, mem memory.Allocator) (arrow.ArrayData, error) {
+	v, ok := memo.id2dict[id]
+	if !ok {
+		return nil, fmt.Errorf("arrow/ipc: no dictionaries found for id=%d", id)
+	}
+
+	if len(v) == 1 {
+		return v[0], nil
+	}
+
+	// there are deltas: they need to be concatenated with the first dictionary
+	toCombine := make([]arrow.Array, 0, len(v))
+	// the temporary arrays are only needed for the concatenation itself
+	defer func() {
+		for _, arr := range toCombine {
+			arr.Release()
+		}
+	}()
+
+	// NOTE: at this point the dictionary data may not be trusted. it needs to
+	// be validated as concatenation can crash on invalid or corrupted data.
+	for _, data := range v {
+		if hasUnresolvedNestedDict(data) {
+			return nil, fmt.Errorf("arrow/ipc: delta dict with unresolved nested dictionary not implemented")
+		}
+		toCombine = append(toCombine, array.MakeFromData(data))
+	}
+
+	combined, err := array.Concatenate(toCombine, mem)
+	if err != nil {
+		return nil, err
+	}
+	defer combined.Release()
+	// this reference is the one the memo keeps for the combined dictionary
+	combined.Data().Retain()
+
+	// success: the pieces are superseded, so give up the memo's references
+	for _, data := range v {
+		data.Release()
+	}
+
+	memo.id2dict[id] = []arrow.ArrayData{combined.Data()}
+	return combined.Data(), nil
+}
+
+// Dict returns the dictionary stored for id, combining any stored pieces
+// into one first (see reify). It fails if nothing is stored under id.
+func (memo *Memo) Dict(id int64, mem memory.Allocator) (arrow.ArrayData, error) {
+	data, err := memo.reify(id, mem)
+	return data, err
+}
+
+// AddType records typ as the type for id. Recording the same (equal) type
+// again is allowed; recording a different type for an id that already has one
+// is rejected with an error and the stored type is kept.
+func (memo *Memo) AddType(id int64, typ arrow.DataType) error {
+	existing, known := memo.id2type[id]
+	if !known || arrow.TypeEqual(existing, typ) {
+		memo.id2type[id] = typ
+		return nil
+	}
+	return fmt.Errorf("arrow/ipc: conflicting dictionary types for id %d", id)
+}
+
+// Type returns the type recorded for id and whether one has been recorded.
+func (memo *Memo) Type(id int64) (arrow.DataType, bool) {
+	typ, found := memo.id2type[id]
+	return typ, found
+}
+
+// func (memo *dictMemo) ID(v arrow.Array) int64 {
+// id, ok := memo.dict2id[v]
+// if ok {
+// return id
+// }
+
+// v.Retain()
+// id = int64(len(memo.dict2id))
+// memo.dict2id[v] = id
+// memo.id2dict[id] = v
+// return id
+// }
+
+// HasDict reports whether v is present in the reverse (data-to-id) lookup.
+func (memo Memo) HasDict(v arrow.ArrayData) bool {
+	_, known := memo.dict2id[v]
+	return known
+}
+
+// HasID reports whether any dictionary data is stored under id.
+func (memo Memo) HasID(id int64) bool {
+	_, known := memo.id2dict[id]
+	return known
+}
+
+// Add stores v as the dictionary for id and registers it in the reverse
+// lookup. The memo takes its own reference to v. It panics if id already has
+// a dictionary.
+func (memo *Memo) Add(id int64, v arrow.ArrayData) {
+	_, taken := memo.id2dict[id]
+	if taken {
+		panic(xerrors.Errorf("arrow/ipc: duplicate id=%d", id))
+	}
+
+	v.Retain()
+	memo.id2dict[id] = []arrow.ArrayData{v}
+	memo.dict2id[v] = id
+}
+
+// AddDelta appends v to the data already stored for id, taking a reference
+// to it. It panics if nothing has been stored under id yet.
+func (memo *Memo) AddDelta(id int64, v arrow.ArrayData) {
+	pieces, known := memo.id2dict[id]
+	if !known {
+		panic(fmt.Errorf("arrow/ipc: adding delta to non-existing id=%d", id))
+	}
+
+	v.Retain()
+	pieces = append(pieces, v)
+	memo.id2dict[id] = pieces
+}
+
+// AddOrReplace makes v the one and only dictionary stored for id, taking a
+// reference to it. Anything previously stored under id (the dictionary and
+// any deltas) is released and dropped. It returns true when id had no
+// dictionary before the call, false when an existing one was replaced.
+//
+// Appending v to the existing data would make this behave like AddDelta:
+// the old dictionaries would never be released, and resolving the id later
+// would concatenate the old and new dictionaries.
+func (memo *Memo) AddOrReplace(id int64, v arrow.ArrayData) bool {
+	old, existed := memo.id2dict[id]
+
+	// take the new reference first so that v survives even if it is also one
+	// of the entries being released below
+	v.Retain()
+	for _, prev := range old {
+		prev.Release()
+	}
+
+	memo.id2dict[id] = []arrow.ArrayData{v}
+	return !existed
+}
+
+// CollectDictionaries gathers the dictionaries found in batch, each paired
+// with the id that mapper assigns to its field path. Whatever was collected
+// is returned even when err is non-nil. Every returned dictionary has been
+// retained on behalf of the caller.
+func CollectDictionaries(batch arrow.Record, mapper *Mapper) (out []dictpair, err error) {
+	c := dictCollector{mapper: mapper}
+	err = c.collect(batch)
+	return c.dictionaries, err
+}
+
+// ResolveFieldDict attaches dictionaries from memo to data (located at pos)
+// and to everything nested below it. Extension types are looked through to
+// their storage type. For dictionary-typed data the id is looked up by field
+// path, the dictionary is fetched from memo and set on data, and the
+// dictionary itself is resolved at the same position. In all cases the
+// children of data are resolved afterwards. data must be an *array.Data when
+// it is dictionary-typed. The first error encountered is returned.
+func ResolveFieldDict(memo *Memo, data arrow.ArrayData, pos FieldPos, mem memory.Allocator) error {
+	typ := data.DataType()
+	if typ.ID() == arrow.EXTENSION {
+		typ = typ.(arrow.ExtensionType).StorageType()
+	}
+
+	// nothing to attach here: just descend into the children
+	if typ.ID() != arrow.DICTIONARY {
+		return ResolveDictionaries(memo, data.Children(), pos, mem)
+	}
+
+	id, err := memo.Mapper.GetFieldID(pos.Path())
+	if err != nil {
+		return err
+	}
+
+	dictData, err := memo.Dict(id, mem)
+	if err != nil {
+		return err
+	}
+
+	data.(*array.Data).SetDictionary(dictData)
+	// the dictionary may itself contain dictionaries that need resolving
+	if err = ResolveFieldDict(memo, dictData, pos, mem); err != nil {
+		return err
+	}
+	return ResolveDictionaries(memo, data.Children(), pos, mem)
+}
+
+// ResolveDictionaries resolves the dictionaries of every non-nil entry of
+// cols, treating the i-th entry as child i of parentPos. It stops at and
+// returns the first error.
+func ResolveDictionaries(memo *Memo, cols []arrow.ArrayData, parentPos FieldPos, mem memory.Allocator) error {
+	for idx := range cols {
+		col := cols[idx]
+		if col != nil {
+			if err := ResolveFieldDict(memo, col, parentPos.Child(int32(idx)), mem); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
diff --git a/go/arrow/ipc/dict_test.go b/go/arrow/internal/dictutils/dict_test.go
similarity index 77%
rename from go/arrow/ipc/dict_test.go
rename to go/arrow/internal/dictutils/dict_test.go
index 5e42ae7f0f..d5b89db2cc 100644
--- a/go/arrow/ipc/dict_test.go
+++ b/go/arrow/internal/dictutils/dict_test.go
@@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package ipc
+package dictutils_test
import (
"fmt"
@@ -22,6 +22,7 @@ import (
"github.com/apache/arrow/go/v8/arrow"
"github.com/apache/arrow/go/v8/arrow/array"
+ "github.com/apache/arrow/go/v8/arrow/internal/dictutils"
"github.com/apache/arrow/go/v8/arrow/memory"
)
@@ -44,15 +45,15 @@ func TestDictMemo(t *testing.T) {
f2 := bldr.NewFloat64Array()
defer f2.Release()
- memo := newMemo()
- defer memo.delete()
+ memo := dictutils.NewMemo()
+ defer memo.Clear()
if got, want := memo.Len(), 0; got != want {
t.Fatalf("invalid length: got=%d, want=%d", got, want)
}
- memo.Add(0, f0)
- memo.Add(1, f1)
+ memo.Add(0, f0.Data())
+ memo.Add(1, f1.Data())
if !memo.HasID(0) {
t.Fatalf("could not find id=0")
@@ -69,17 +70,17 @@ func TestDictMemo(t *testing.T) {
var ff arrow.Array
ff = f0
- if !memo.HasDict(ff) {
+ if !memo.HasDict(ff.Data()) {
t.Fatalf("failed to find f0 through interface")
}
ff = f1
- if !memo.HasDict(ff) {
+ if !memo.HasDict(ff.Data()) {
t.Fatalf("failed to find f1 through interface")
}
ff = f2
- if memo.HasDict(ff) {
+ if memo.HasDict(ff.Data()) {
t.Fatalf("should not have found f2")
}
@@ -87,59 +88,42 @@ func TestDictMemo(t *testing.T) {
return v
}
- if !memo.HasDict(fct(f1)) {
+ if !memo.HasDict(fct(f1).Data()) {
t.Fatalf("failed to find dict through func through interface")
}
- if memo.HasDict(f2) {
+ if memo.HasDict(f2.Data()) {
t.Fatalf("should not have found f2")
}
ff = f0
for i, f := range []arrow.Array{f0, f1, ff, fct(f0), fct(f1)} {
- if !memo.HasDict(f) {
+ if !memo.HasDict(f.Data()) {
t.Fatalf("failed to find dict %d", i)
}
}
- v, ok := memo.Dict(0)
- if !ok {
+ v, err := memo.Dict(0, mem)
+ if err != nil {
t.Fatalf("expected to find id=0")
}
- if v != f0 {
+ if v != f0.Data() {
t.Fatalf("expected fo find id=0 array")
}
- _, ok = memo.Dict(2)
- if ok {
+ _, err = memo.Dict(2, mem)
+ if err == nil {
t.Fatalf("should not have found id=2")
}
- _, ok = memo.Dict(-2)
- if ok {
+ _, err = memo.Dict(-2, mem)
+ if err == nil {
t.Fatalf("should not have found id=-2")
}
- if got, want := memo.ID(f0), int64(0); got != want {
- t.Fatalf("found invalid id. got=%d, want=%d", got, want)
- }
-
- if got, want := memo.ID(f2), int64(2); got != want {
- t.Fatalf("found invalid id. got=%d, want=%d", got, want)
- }
- if !memo.HasDict(f2) {
- t.Fatalf("should have found f2")
- }
-
// test we don't leak nor "double-delete" when adding an array multiple times.
- memo.Add(42, f2)
- if got, want := memo.ID(f2), int64(42); got != want {
- t.Fatalf("found invalid id. got=%d, want=%d", got, want)
- }
- memo.Add(43, f2)
- if got, want := memo.ID(f2), int64(43); got != want {
- t.Fatalf("found invalid id. got=%d, want=%d", got, want)
- }
- if got, want := memo.Len(), 5; got != want {
+ memo.Add(42, f2.Data())
+ memo.Add(43, f2.Data())
+ if got, want := memo.Len(), 4; got != want {
t.Fatalf("invalid length. got=%d, want=%d", got, want)
}
}
@@ -183,15 +167,15 @@ func TestDictMemoPanics(t *testing.T) {
}
}()
- memo := newMemo()
- defer memo.delete()
+ memo := dictutils.NewMemo()
+ defer memo.Clear()
if got, want := memo.Len(), 0; got != want {
t.Fatalf("invalid length: got=%d, want=%d", got, want)
}
- memo.Add(tc.ids[0], tc.vs[0])
- memo.Add(tc.ids[1], tc.vs[1])
+ memo.Add(tc.ids[0], tc.vs[0].Data())
+ memo.Add(tc.ids[1], tc.vs[1].Data())
})
}
}
diff --git a/go/arrow/ipc/dict.go b/go/arrow/ipc/dict.go
deleted file mode 100644
index e3d4637bc2..0000000000
--- a/go/arrow/ipc/dict.go
+++ /dev/null
@@ -1,85 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package ipc
-
-import (
- "fmt"
-
- "github.com/apache/arrow/go/v8/arrow"
-)
-
-type dictMap map[int64]arrow.Array
-type dictTypeMap map[int64]arrow.Field
-
-type dictMemo struct {
- dict2id map[arrow.Array]int64
- id2dict dictMap // map of dictionary ID to dictionary array
-}
-
-func newMemo() dictMemo {
- return dictMemo{
- dict2id: make(map[arrow.Array]int64),
- id2dict: make(dictMap),
- }
-}
-
-func (memo *dictMemo) Len() int { return len(memo.id2dict) }
-
-func (memo *dictMemo) delete() {
- for id, v := range memo.id2dict {
- delete(memo.id2dict, id)
- delete(memo.dict2id, v)
- v.Release()
- }
-}
-
-func (memo dictMemo) Dict(id int64) (arrow.Array, bool) {
- v, ok := memo.id2dict[id]
- return v, ok
-}
-
-func (memo *dictMemo) ID(v arrow.Array) int64 {
- id, ok := memo.dict2id[v]
- if ok {
- return id
- }
-
- v.Retain()
- id = int64(len(memo.dict2id))
- memo.dict2id[v] = id
- memo.id2dict[id] = v
- return id
-}
-
-func (memo dictMemo) HasDict(v arrow.Array) bool {
- _, ok := memo.dict2id[v]
- return ok
-}
-
-func (memo dictMemo) HasID(id int64) bool {
- _, ok := memo.id2dict[id]
- return ok
-}
-
-func (memo *dictMemo) Add(id int64, v arrow.Array) {
- if _, dup := memo.id2dict[id]; dup {
- panic(fmt.Errorf("arrow/ipc: duplicate id=%d", id))
- }
- v.Retain()
- memo.id2dict[id] = v
- memo.dict2id[v] = id
-}
diff --git a/go/arrow/ipc/file_reader.go b/go/arrow/ipc/file_reader.go
index ff562e4878..95abc0c940 100644
--- a/go/arrow/ipc/file_reader.go
+++ b/go/arrow/ipc/file_reader.go
@@ -19,12 +19,14 @@ package ipc
import (
"bytes"
"encoding/binary"
+ "errors"
"fmt"
"io"
"github.com/apache/arrow/go/v8/arrow"
"github.com/apache/arrow/go/v8/arrow/array"
"github.com/apache/arrow/go/v8/arrow/bitutil"
+ "github.com/apache/arrow/go/v8/arrow/internal/dictutils"
"github.com/apache/arrow/go/v8/arrow/internal/flatbuf"
"github.com/apache/arrow/go/v8/arrow/memory"
)
@@ -39,8 +41,8 @@ type FileReader struct {
data *flatbuf.Footer
}
- fields dictTypeMap
- memo dictMemo
+ // fields dictTypeMap
+ memo dictutils.Memo
schema *arrow.Schema
record arrow.Record
@@ -58,10 +60,9 @@ func NewFileReader(r ReadAtSeeker, opts ...Option) (*FileReader, error) {
err error
f = FileReader{
- r: r,
- fields: make(dictTypeMap),
- memo: newMemo(),
- mem: cfg.alloc,
+ r: r,
+ memo: dictutils.NewMemo(),
+ mem: cfg.alloc,
}
)
@@ -131,17 +132,24 @@ func (f *FileReader) readFooter() error {
}
func (f *FileReader) readSchema() error {
- var err error
- f.fields, err = dictTypesFromFB(f.footer.data.Schema(nil))
+ var (
+ err error
+ kind dictutils.Kind
+ )
+
+ schema := f.footer.data.Schema(nil)
+ if schema == nil {
+ return fmt.Errorf("arrow/ipc: could not load schema from flatbuffer data")
+ }
+ f.schema, err = schemaFromFB(schema, &f.memo)
if err != nil {
- return fmt.Errorf("arrow/ipc: could not load dictionary types from file: %w", err)
+ return fmt.Errorf("arrow/ipc: could not read schema: %w", err)
}
- //lint:ignore SA4008 readDictionary always panics currently. ignore lint until DictionaryArray is implemented.
for i := 0; i < f.NumDictionaries(); i++ {
blk, err := f.dict(i)
if err != nil {
- return fmt.Errorf("arrow/ipc: could read dictionary[%d]: %w", i, err)
+ return fmt.Errorf("arrow/ipc: could not read dictionary[%d]: %w", i, err)
}
switch {
case !bitutil.IsMultipleOf8(blk.Offset):
@@ -157,22 +165,13 @@ func (f *FileReader) readSchema() error {
return err
}
- id, dict, err := readDictionary(msg.meta, f.fields, f.r)
- msg.Release()
+ kind, err = readDictionary(&f.memo, msg.meta, bytes.NewReader(msg.body.Bytes()), f.mem)
if err != nil {
- return fmt.Errorf("arrow/ipc: could not read dictionary %d from file: %w", i, err)
+ return err
+ }
+ if kind == dictutils.KindReplacement {
+ return errors.New("arrow/ipc: unsupported dictionary replacement in IPC file")
}
- f.memo.Add(id, dict)
- dict.Release() // memo.Add increases ref-count of dict.
- }
-
- schema := f.footer.data.Schema(nil)
- if schema == nil {
- return fmt.Errorf("arrow/ipc: could not load schema from flatbuffer data")
- }
- f.schema, err = schemaFromFB(schema, &f.memo)
- if err != nil {
- return fmt.Errorf("arrow/ipc: could not read schema: %w", err)
}
return err
@@ -292,7 +291,7 @@ func (f *FileReader) RecordAt(i int) (arrow.Record, error) {
return nil, fmt.Errorf("arrow/ipc: message %d is not a Record", i)
}
- return newRecord(f.schema, msg.meta, bytes.NewReader(msg.body.Bytes()), f.mem), nil
+ return newRecord(f.schema, &f.memo, msg.meta, bytes.NewReader(msg.body.Bytes()), f.mem), nil
}
// Read reads the current record from the underlying stream and an error, if any.
@@ -314,7 +313,7 @@ func (f *FileReader) ReadAt(i int64) (arrow.Record, error) {
return f.Record(int(i))
}
-func newRecord(schema *arrow.Schema, meta *memory.Buffer, body ReadAtSeeker, mem memory.Allocator) arrow.Record {
+func newRecord(schema *arrow.Schema, memo *dictutils.Memo, meta *memory.Buffer, body ReadAtSeeker, mem memory.Allocator) arrow.Record {
var (
msg = flatbuf.GetRootAsMessage(meta.Bytes(), 0)
md flatbuf.RecordBatch
@@ -336,12 +335,21 @@ func newRecord(schema *arrow.Schema, meta *memory.Buffer, body ReadAtSeeker, mem
codec: codec,
mem: mem,
},
- max: kMaxNestingDepth,
+ memo: memo,
+ max: kMaxNestingDepth,
}
+ pos := dictutils.NewFieldPos()
cols := make([]arrow.Array, len(schema.Fields()))
for i, field := range schema.Fields() {
- cols[i] = ctx.loadArray(field.Type)
+ data := ctx.loadArray(field.Type)
+ defer data.Release()
+
+ if err := dictutils.ResolveFieldDict(memo, data, pos.Child(int32(i)), mem); err != nil {
+ panic(err)
+ }
+
+ cols[i] = array.MakeFromData(data)
defer cols[i].Release()
}
@@ -411,6 +419,7 @@ type arrayLoaderContext struct {
ifield int
ibuffer int
max int
+ memo *dictutils.Memo
}
func (ctx *arrayLoaderContext) field() *flatbuf.FieldNode {
@@ -425,11 +434,16 @@ func (ctx *arrayLoaderContext) buffer() *memory.Buffer {
return buf
}
-func (ctx *arrayLoaderContext) loadArray(dt arrow.DataType) arrow.Array {
+func (ctx *arrayLoaderContext) loadArray(dt arrow.DataType) arrow.ArrayData {
switch dt := dt.(type) {
case *arrow.NullType:
return ctx.loadNull()
+ case *arrow.DictionaryType:
+ indices := ctx.loadPrimitive(dt.IndexType)
+ defer indices.Release()
+ return array.NewData(dt, indices.Len(), indices.Buffers(), indices.Children(), indices.NullN(), indices.Offset())
+
case *arrow.BooleanType,
*arrow.Int8Type, *arrow.Int16Type, *arrow.Int32Type, *arrow.Int64Type,
*arrow.Uint8Type, *arrow.Uint16Type, *arrow.Uint32Type, *arrow.Uint64Type,
@@ -463,7 +477,7 @@ func (ctx *arrayLoaderContext) loadArray(dt arrow.DataType) arrow.Array {
case arrow.ExtensionType:
storage := ctx.loadArray(dt.StorageType())
defer storage.Release()
- return array.NewExtensionArrayWithStorage(dt, storage)
+ return array.NewData(dt, storage.Len(), storage.Buffers(), storage.Children(), storage.NullN(), storage.Offset())
default:
panic(fmt.Errorf("array type %T not handled yet", dt))
@@ -486,7 +500,7 @@ func (ctx *arrayLoaderContext) loadCommon(nbufs int) (*flatbuf.FieldNode, []*mem
return field, buffers
}
-func (ctx *arrayLoaderContext) loadChild(dt arrow.DataType) arrow.Array {
+func (ctx *arrayLoaderContext) loadChild(dt arrow.DataType) arrow.ArrayData {
if ctx.max == 0 {
panic("arrow/ipc: nested type limit reached")
}
@@ -496,15 +510,12 @@ func (ctx *arrayLoaderContext) loadChild(dt arrow.DataType) arrow.Array {
return sub
}
-func (ctx *arrayLoaderContext) loadNull() arrow.Array {
+func (ctx *arrayLoaderContext) loadNull() arrow.ArrayData {
field := ctx.field()
- data := array.NewData(arrow.Null, int(field.Length()), nil, nil, int(field.NullCount()), 0)
- defer data.Release()
-
- return array.MakeFromData(data)
+ return array.NewData(arrow.Null, int(field.Length()), nil, nil, int(field.NullCount()), 0)
}
-func (ctx *arrayLoaderContext) loadPrimitive(dt arrow.DataType) arrow.Array {
+func (ctx *arrayLoaderContext) loadPrimitive(dt arrow.DataType) arrow.ArrayData {
field, buffers := ctx.loadCommon(2)
switch field.Length() {
@@ -517,35 +528,26 @@ func (ctx *arrayLoaderContext) loadPrimitive(dt arrow.DataType) arrow.Array {
defer releaseBuffers(buffers)
- data := array.NewData(dt, int(field.Length()), buffers, nil, int(field.NullCount()), 0)
- defer data.Release()
-
- return array.MakeFromData(data)
+ return array.NewData(dt, int(field.Length()), buffers, nil, int(field.NullCount()), 0)
}
-func (ctx *arrayLoaderContext) loadBinary(dt arrow.DataType) arrow.Array {
+func (ctx *arrayLoaderContext) loadBinary(dt arrow.DataType) arrow.ArrayData {
field, buffers := ctx.loadCommon(3)
buffers = append(buffers, ctx.buffer(), ctx.buffer())
defer releaseBuffers(buffers)
- data := array.NewData(dt, int(field.Length()), buffers, nil, int(field.NullCount()), 0)
- defer data.Release()
-
- return array.MakeFromData(data)
+ return array.NewData(dt, int(field.Length()), buffers, nil, int(field.NullCount()), 0)
}
-func (ctx *arrayLoaderContext) loadFixedSizeBinary(dt *arrow.FixedSizeBinaryType) arrow.Array {
+func (ctx *arrayLoaderContext) loadFixedSizeBinary(dt *arrow.FixedSizeBinaryType) arrow.ArrayData {
field, buffers := ctx.loadCommon(2)
buffers = append(buffers, ctx.buffer())
defer releaseBuffers(buffers)
- data := array.NewData(dt, int(field.Length()), buffers, nil, int(field.NullCount()), 0)
- defer data.Release()
-
- return array.MakeFromData(data)
+ return array.NewData(dt, int(field.Length()), buffers, nil, int(field.NullCount()), 0)
}
-func (ctx *arrayLoaderContext) loadMap(dt *arrow.MapType) arrow.Array {
+func (ctx *arrayLoaderContext) loadMap(dt *arrow.MapType) arrow.ArrayData {
field, buffers := ctx.loadCommon(2)
buffers = append(buffers, ctx.buffer())
defer releaseBuffers(buffers)
@@ -553,13 +555,10 @@ func (ctx *arrayLoaderContext) loadMap(dt *arrow.MapType) arrow.Array {
sub := ctx.loadChild(dt.ValueType())
defer sub.Release()
- data := array.NewData(dt, int(field.Length()), buffers, []arrow.ArrayData{sub.Data()}, int(field.NullCount()), 0)
- defer data.Release()
-
- return array.NewMapData(data)
+ return array.NewData(dt, int(field.Length()), buffers, []arrow.ArrayData{sub}, int(field.NullCount()), 0)
}
-func (ctx *arrayLoaderContext) loadList(dt *arrow.ListType) arrow.Array {
+func (ctx *arrayLoaderContext) loadList(dt *arrow.ListType) arrow.ArrayData {
field, buffers := ctx.loadCommon(2)
buffers = append(buffers, ctx.buffer())
defer releaseBuffers(buffers)
@@ -567,88 +566,81 @@ func (ctx *arrayLoaderContext) loadList(dt *arrow.ListType) arrow.Array {
sub := ctx.loadChild(dt.Elem())
defer sub.Release()
- data := array.NewData(dt, int(field.Length()), buffers, []arrow.ArrayData{sub.Data()}, int(field.NullCount()), 0)
- defer data.Release()
-
- return array.NewListData(data)
+ return array.NewData(dt, int(field.Length()), buffers, []arrow.ArrayData{sub}, int(field.NullCount()), 0)
}
-func (ctx *arrayLoaderContext) loadFixedSizeList(dt *arrow.FixedSizeListType) arrow.Array {
+func (ctx *arrayLoaderContext) loadFixedSizeList(dt *arrow.FixedSizeListType) arrow.ArrayData {
field, buffers := ctx.loadCommon(1)
defer releaseBuffers(buffers)
sub := ctx.loadChild(dt.Elem())
defer sub.Release()
- data := array.NewData(dt, int(field.Length()), buffers, []arrow.ArrayData{sub.Data()}, int(field.NullCount()), 0)
- defer data.Release()
-
- return array.NewFixedSizeListData(data)
+ return array.NewData(dt, int(field.Length()), buffers, []arrow.ArrayData{sub}, int(field.NullCount()), 0)
}
-func (ctx *arrayLoaderContext) loadStruct(dt *arrow.StructType) arrow.Array {
+func (ctx *arrayLoaderContext) loadStruct(dt *arrow.StructType) arrow.ArrayData {
field, buffers := ctx.loadCommon(1)
defer releaseBuffers(buffers)
- arrs := make([]arrow.Array, len(dt.Fields()))
subs := make([]arrow.ArrayData, len(dt.Fields()))
for i, f := range dt.Fields() {
- arrs[i] = ctx.loadChild(f.Type)
- subs[i] = arrs[i].Data()
+ subs[i] = ctx.loadChild(f.Type)
}
defer func() {
- for i := range arrs {
- arrs[i].Release()
+ for i := range subs {
+ subs[i].Release()
}
}()
- data := array.NewData(dt, int(field.Length()), buffers, subs, int(field.NullCount()), 0)
- defer data.Release()
-
- return array.NewStructData(data)
-}
-
-func readDictionary(meta *memory.Buffer, types dictTypeMap, r ReadAtSeeker) (int64, arrow.Array, error) {
- // msg := flatbuf.GetRootAsMessage(meta.Bytes(), 0)
- // var dictBatch flatbuf.DictionaryBatch
- // initFB(&dictBatch, msg.Header)
- //
- // id := dictBatch.Id()
- // v, ok := types[id]
- // if !ok {
- // return id, nil, errors.Errorf("arrow/ipc: no type metadata for dictionary with ID=%d", id)
- // }
- //
- // fields := []arrow.Field{v}
- //
- // // we need a schema for the record batch.
- // schema := arrow.NewSchema(fields, nil)
- //
- // // the dictionary is embedded in a record batch with a single column.
- // recBatch := dictBatch.Data(nil)
- //
- // var (
- // batchMeta *memory.Buffer
- // body *memory.Buffer
- // )
- //
- //
- // ctx := &arrayLoaderContext{
- // src: ipcSource{
- // meta: &md,
- // r: bytes.NewReader(body.Bytes()),
- // },
- // max: kMaxNestingDepth,
- // }
- //
- // cols := make([]arrow.Array, len(schema.Fields()))
- // for i, field := range schema.Fields() {
- // cols[i] = ctx.loadArray(field.Type)
- // }
- //
- // batch := array.NewRecord(schema, cols, rows)
-
- panic("not implemented")
+ return array.NewData(dt, int(field.Length()), buffers, subs, int(field.NullCount()), 0)
+}
+
+func readDictionary(memo *dictutils.Memo, meta *memory.Buffer, body ReadAtSeeker, mem memory.Allocator) (dictutils.Kind, error) {
+ var (
+ msg = flatbuf.GetRootAsMessage(meta.Bytes(), 0)
+ md flatbuf.DictionaryBatch
+ data flatbuf.RecordBatch
+ codec decompressor
+ )
+ initFB(&md, msg.Header)
+
+ md.Data(&data)
+ bodyCompress := data.Compression(nil)
+ if bodyCompress != nil {
+ codec = getDecompressor(bodyCompress.Codec())
+ }
+
+ id := md.Id()
+ // look up the dictionary value type, which must have been added to the
+ // memo already before calling this function
+ valueType, ok := memo.Type(id)
+ if !ok {
+ return 0, fmt.Errorf("arrow/ipc: no dictionary type found with id: %d", id)
+ }
+
+ ctx := &arrayLoaderContext{
+ src: ipcSource{
+ meta: &data,
+ codec: codec,
+ r: body,
+ mem: mem,
+ },
+ memo: memo,
+ max: kMaxNestingDepth,
+ }
+
+ dict := ctx.loadArray(valueType)
+ defer dict.Release()
+
+ if md.IsDelta() {
+ memo.AddDelta(id, dict)
+ return dictutils.KindDelta, nil
+ }
+ if memo.AddOrReplace(id, dict) {
+ return dictutils.KindNew, nil
+ }
+ return dictutils.KindReplacement, nil
}
func releaseBuffers(buffers []*memory.Buffer) {
diff --git a/go/arrow/ipc/file_writer.go b/go/arrow/ipc/file_writer.go
index 80d7bd9eb5..f1082e8ffb 100644
--- a/go/arrow/ipc/file_writer.go
+++ b/go/arrow/ipc/file_writer.go
@@ -23,6 +23,7 @@ import (
"github.com/apache/arrow/go/v8/arrow"
"github.com/apache/arrow/go/v8/arrow/bitutil"
+ "github.com/apache/arrow/go/v8/arrow/internal/dictutils"
"github.com/apache/arrow/go/v8/arrow/internal/flatbuf"
"github.com/apache/arrow/go/v8/arrow/memory"
)
@@ -274,8 +275,15 @@ type FileWriter struct {
pw PayloadWriter
schema *arrow.Schema
+ mapper dictutils.Mapper
codec flatbuf.CompressionType
compressNP int
+
+ // map of the last written dictionaries by id
+ // so we can avoid writing the same dictionary over and over
+ // also needed for correctness when writing IPC format which
+ // does not allow replacements or deltas.
+ lastWrittenDicts map[int64]arrow.Array
}
// NewFileWriter opens an Arrow file using the provided writer w.
@@ -339,6 +347,12 @@ func (f *FileWriter) Write(rec arrow.Record) error {
)
defer data.Release()
+ err := writeDictionaryPayloads(f.mem, rec, true, false, &f.mapper, f.lastWrittenDicts, f.pw, enc)
+ if err != nil {
+ return fmt.Errorf("arrow/ipc: failure writing dictionary batches: %w", err)
+ }
+
+ enc.reset()
if err := enc.Encode(&data, rec); err != nil {
return fmt.Errorf("arrow/ipc: could not encode record to payload: %w", err)
}
@@ -360,8 +374,11 @@ func (f *FileWriter) start() error {
return err
}
+ f.mapper.ImportSchema(f.schema)
+ f.lastWrittenDicts = make(map[int64]arrow.Array)
+
// write out schema payloads
- ps := payloadsFromSchema(f.schema, f.mem, nil)
+ ps := payloadFromSchema(f.schema, f.mem, &f.mapper)
defer ps.Release()
for _, data := range ps {
diff --git a/go/arrow/ipc/ipc_test.go b/go/arrow/ipc/ipc_test.go
index 8bcf591590..bf6e3de9b1 100644
--- a/go/arrow/ipc/ipc_test.go
+++ b/go/arrow/ipc/ipc_test.go
@@ -274,3 +274,59 @@ func TestWriteColumnWithOffset(t *testing.T) {
}
})
}
+
+func TestIPCTable(t *testing.T) {
+ pool := memory.NewGoAllocator()
+ schema := arrow.NewSchema([]arrow.Field{{Name: "f1", Type: arrow.PrimitiveTypes.Int32}}, nil)
+ b := array.NewRecordBuilder(pool, schema)
+ defer b.Release()
+ b.Field(0).(*array.Int32Builder).AppendValues([]int32{1, 2, 3, 4}, []bool{true, true, false, true})
+
+ rec1 := b.NewRecord()
+ defer rec1.Release()
+
+ tbl := array.NewTableFromRecords(schema, []arrow.Record{rec1})
+ defer tbl.Release()
+
+ var buf bytes.Buffer
+ ipcWriter := ipc.NewWriter(&buf, ipc.WithAllocator(pool), ipc.WithSchema(schema))
+ defer func(ipcWriter *ipc.Writer) {
+ err := ipcWriter.Close()
+ if err != nil {
+ t.Fatalf("error closing ipc writer: %s", err.Error())
+ }
+ }(ipcWriter)
+
+ t.Log("Reading data before")
+ tr := array.NewTableReader(tbl, 2)
+ defer tr.Release()
+
+ n := 0
+ for tr.Next() {
+ rec := tr.Record()
+ for i, col := range rec.Columns() {
+ t.Logf("rec[%d][%q]: %v nulls:%v\n", n,
+ rec.ColumnName(i), col, col.NullBitmapBytes())
+ }
+ n++
+ err := ipcWriter.Write(rec)
+ if err != nil {
+ panic(err)
+ }
+ }
+
+ t.Log("Reading data after")
+ ipcReader, err := ipc.NewReader(bytes.NewReader(buf.Bytes()), ipc.WithAllocator(pool))
+ if err != nil {
+ panic(err)
+ }
+ n = 0
+ for ipcReader.Next() {
+ rec := ipcReader.Record()
+ for i, col := range rec.Columns() {
+ t.Logf("rec[%d][%q]: %v nulls:%v\n", n,
+ rec.ColumnName(i), col, col.NullBitmapBytes())
+ }
+ n++
+ }
+}
diff --git a/go/arrow/ipc/metadata.go b/go/arrow/ipc/metadata.go
index f05d970ac4..59dbb610c1 100644
--- a/go/arrow/ipc/metadata.go
+++ b/go/arrow/ipc/metadata.go
@@ -23,6 +23,7 @@ import (
"sort"
"github.com/apache/arrow/go/v8/arrow"
+ "github.com/apache/arrow/go/v8/arrow/internal/dictutils"
"github.com/apache/arrow/go/v8/arrow/internal/flatbuf"
"github.com/apache/arrow/go/v8/arrow/memory"
flatbuffers "github.com/google/flatbuffers/go"
@@ -159,7 +160,7 @@ func initFB(t interface {
t.Init(tbl.Bytes, tbl.Pos)
}
-func fieldFromFB(field *flatbuf.Field, memo *dictMemo) (arrow.Field, error) {
+func fieldFromFB(field *flatbuf.Field, pos dictutils.FieldPos, memo *dictutils.Memo) (arrow.Field, error) {
var (
err error
o arrow.Field
@@ -172,42 +173,38 @@ func fieldFromFB(field *flatbuf.Field, memo *dictMemo) (arrow.Field, error) {
return o, err
}
- encoding := field.Dictionary(nil)
- switch encoding {
- case nil:
- n := field.ChildrenLength()
- children := make([]arrow.Field, n)
- for i := range children {
- var childFB flatbuf.Field
- if !field.Children(&childFB, i) {
- return o, fmt.Errorf("arrow/ipc: could not load field child %d", i)
- }
- child, err := fieldFromFB(&childFB, memo)
- if err != nil {
- return o, fmt.Errorf("arrow/ipc: could not convert field child %d: %w", i, err)
- }
- children[i] = child
- }
+ n := field.ChildrenLength()
+ children := make([]arrow.Field, n)
+ for i := range children {
+ var childFB flatbuf.Field
+ if !field.Children(&childFB, i) {
+ return o, fmt.Errorf("arrow/ipc: could not load field child %d", i)
- o.Type, err = typeFromFB(field, children, &o.Metadata)
+ }
+ child, err := fieldFromFB(&childFB, pos.Child(int32(i)), memo)
if err != nil {
- return o, fmt.Errorf("arrow/ipc: could not convert field type: %w", err)
+ return o, fmt.Errorf("arrow/ipc: could not convert field child %d: %w", i, err)
}
- default:
- panic("not implemented") // FIXME(sbinet)
+ children[i] = child
+ }
+
+ o.Type, err = typeFromFB(field, pos, children, &o.Metadata, memo)
+ if err != nil {
+ return o, fmt.Errorf("arrow/ipc: could not convert field type: %w", err)
}
return o, nil
}
-func fieldToFB(b *flatbuffers.Builder, field arrow.Field, memo *dictMemo) flatbuffers.UOffsetT {
- var visitor = fieldVisitor{b: b, memo: memo, meta: make(map[string]string)}
+func fieldToFB(b *flatbuffers.Builder, pos dictutils.FieldPos, field arrow.Field, memo *dictutils.Mapper) flatbuffers.UOffsetT {
+ var visitor = fieldVisitor{b: b, memo: memo, pos: pos, meta: make(map[string]string)}
return visitor.result(field)
}
type fieldVisitor struct {
b *flatbuffers.Builder
- memo *dictMemo
+ memo *dictutils.Mapper
+ pos dictutils.FieldPos
dtype flatbuf.Type
offset flatbuffers.UOffsetT
kids []flatbuffers.UOffsetT
@@ -336,7 +333,7 @@ func (fv *fieldVisitor) visit(field arrow.Field) {
fv.dtype = flatbuf.TypeStruct_
offsets := make([]flatbuffers.UOffsetT, len(dt.Fields()))
for i, field := range dt.Fields() {
- offsets[i] = fieldToFB(fv.b, field, fv.memo)
+ offsets[i] = fieldToFB(fv.b, fv.pos.Child(int32(i)), field, fv.memo)
}
flatbuf.Struct_Start(fv.b)
for i := len(offsets) - 1; i >= 0; i-- {
@@ -347,13 +344,13 @@ func (fv *fieldVisitor) visit(field arrow.Field) {
case *arrow.ListType:
fv.dtype = flatbuf.TypeList
- fv.kids = append(fv.kids, fieldToFB(fv.b, dt.ElemField(), fv.memo))
+ fv.kids = append(fv.kids, fieldToFB(fv.b, fv.pos.Child(0), dt.ElemField(), fv.memo))
flatbuf.ListStart(fv.b)
fv.offset = flatbuf.ListEnd(fv.b)
case *arrow.FixedSizeListType:
fv.dtype = flatbuf.TypeFixedSizeList
- fv.kids = append(fv.kids, fieldToFB(fv.b, dt.ElemField(), fv.memo))
+ fv.kids = append(fv.kids, fieldToFB(fv.b, fv.pos.Child(0), dt.ElemField(), fv.memo))
flatbuf.FixedSizeListStart(fv.b)
flatbuf.FixedSizeListAddListSize(fv.b, dt.Len())
fv.offset = flatbuf.FixedSizeListEnd(fv.b)
@@ -385,7 +382,7 @@ func (fv *fieldVisitor) visit(field arrow.Field) {
case *arrow.MapType:
fv.dtype = flatbuf.TypeMap
- fv.kids = append(fv.kids, fieldToFB(fv.b, dt.ValueField(), fv.memo))
+ fv.kids = append(fv.kids, fieldToFB(fv.b, fv.pos.Child(0), dt.ValueField(), fv.memo))
flatbuf.MapStart(fv.b)
flatbuf.MapAddKeysSorted(fv.b, dt.KeysSorted)
fv.offset = flatbuf.MapEnd(fv.b)
@@ -396,6 +393,10 @@ func (fv *fieldVisitor) visit(field arrow.Field) {
fv.meta[ExtensionTypeKeyName] = dt.ExtensionName()
fv.meta[ExtensionMetadataKeyName] = string(dt.Serialize())
+ case *arrow.DictionaryType:
+ field.Type = dt.ValueType
+ fv.visit(field)
+
default:
err := fmt.Errorf("arrow/ipc: invalid data type %v", dt)
panic(err) // FIXME(sbinet): implement all data-types.
@@ -413,9 +414,32 @@ func (fv *fieldVisitor) result(field arrow.Field) flatbuffers.UOffsetT {
}
kidsFB := fv.b.EndVector(len(fv.kids))
+ storageType := field.Type
+ if storageType.ID() == arrow.EXTENSION {
+ storageType = storageType.(arrow.ExtensionType).StorageType()
+ }
+
var dictFB flatbuffers.UOffsetT
- if field.Type.ID() == arrow.DICTIONARY {
- panic("not implemented") // FIXME(sbinet)
+ if storageType.ID() == arrow.DICTIONARY {
+ idxType := field.Type.(*arrow.DictionaryType).IndexType.(arrow.FixedWidthDataType)
+
+ dictID, err := fv.memo.GetFieldID(fv.pos.Path())
+ if err != nil {
+ panic(err)
+ }
+ var signed bool
+ switch idxType.ID() {
+ case arrow.UINT8, arrow.UINT16, arrow.UINT32, arrow.UINT64:
+ signed = false
+ case arrow.INT8, arrow.INT16, arrow.INT32, arrow.INT64:
+ signed = true
+ }
+ indexTypeOffset := intToFB(fv.b, int32(idxType.BitWidth()), signed)
+ flatbuf.DictionaryEncodingStart(fv.b)
+ flatbuf.DictionaryEncodingAddId(fv.b, dictID)
+ flatbuf.DictionaryEncodingAddIndexType(fv.b, indexTypeOffset)
+ flatbuf.DictionaryEncodingAddIsOrdered(fv.b, field.Type.(*arrow.DictionaryType).Ordered)
+ dictFB = flatbuf.DictionaryEncodingEnd(fv.b)
}
var (
@@ -469,44 +493,7 @@ func (fv *fieldVisitor) result(field arrow.Field) flatbuffers.UOffsetT {
return offset
}
-func fieldFromFBDict(field *flatbuf.Field) (arrow.Field, error) {
- var (
- o = arrow.Field{
- Name: string(field.Name()),
- Nullable: field.Nullable(),
- }
- err error
- memo = newMemo()
- )
-
- // any DictionaryEncoding set is ignored here.
-
- kids := make([]arrow.Field, field.ChildrenLength())
- for i := range kids {
- var kid flatbuf.Field
- if !field.Children(&kid, i) {
- return o, fmt.Errorf("arrow/ipc: could not load field child %d", i)
- }
- kids[i], err = fieldFromFB(&kid, &memo)
- if err != nil {
- return o, fmt.Errorf("arrow/ipc: field from dict: %w", err)
- }
- }
-
- meta, err := metadataFromFB(field)
- if err != nil {
- return o, fmt.Errorf("arrow/ipc: metadata for field from dict: %w", err)
- }
-
- o.Type, err = typeFromFB(field, kids, &meta)
- if err != nil {
- return o, fmt.Errorf("arrow/ipc: type for field from dict: %w", err)
- }
-
- return o, nil
-}
-
-func typeFromFB(field *flatbuf.Field, children []arrow.Field, md *arrow.Metadata) (arrow.DataType, error) {
+func typeFromFB(field *flatbuf.Field, pos dictutils.FieldPos, children []arrow.Field, md *arrow.Metadata, memo *dictutils.Memo) (arrow.DataType, error) {
var data flatbuffers.Table
if !field.Type(&data) {
return nil, fmt.Errorf("arrow/ipc: could not load field type data")
@@ -517,6 +504,32 @@ func typeFromFB(field *flatbuf.Field, children []arrow.Field, md *arrow.Metadata
return dt, err
}
+ var (
+ dictID = int64(-1)
+ dictValueType arrow.DataType
+ encoding = field.Dictionary(nil)
+ )
+ if encoding != nil {
+ var idt flatbuf.Int
+ encoding.IndexType(&idt)
+ idxType, err := intFromFB(idt)
+ if err != nil {
+ return nil, err
+ }
+
+ dictValueType = dt
+ dt = &arrow.DictionaryType{IndexType: idxType, ValueType: dictValueType, Ordered: encoding.IsOrdered()}
+ dictID = encoding.Id()
+
+ if err = memo.Mapper.AddField(dictID, pos.Path()); err != nil {
+ return dt, err
+ }
+ if err = memo.AddType(dictID, dictValueType); err != nil {
+ return dt, err
+ }
+
+ }
+
// look for extension metadata in custom metadata field.
if md.Len() > 0 {
i := md.FindKey(ExtensionTypeKeyName)
@@ -874,10 +887,11 @@ func metadataToFB(b *flatbuffers.Builder, meta arrow.Metadata, start startVecFun
return b.EndVector(n)
}
-func schemaFromFB(schema *flatbuf.Schema, memo *dictMemo) (*arrow.Schema, error) {
+func schemaFromFB(schema *flatbuf.Schema, memo *dictutils.Memo) (*arrow.Schema, error) {
var (
err error
fields = make([]arrow.Field, schema.FieldsLength())
+ pos = dictutils.NewFieldPos()
)
for i := range fields {
@@ -886,7 +900,7 @@ func schemaFromFB(schema *flatbuf.Schema, memo *dictMemo) (*arrow.Schema, error)
return nil, fmt.Errorf("arrow/ipc: could not read field %d from schema", i)
}
- fields[i], err = fieldFromFB(&field, memo)
+ fields[i], err = fieldFromFB(&field, pos.Child(int32(i)), memo)
if err != nil {
return nil, fmt.Errorf("arrow/ipc: could not convert field %d from flatbuf: %w", i, err)
}
@@ -900,10 +914,11 @@ func schemaFromFB(schema *flatbuf.Schema, memo *dictMemo) (*arrow.Schema, error)
return arrow.NewSchema(fields, &md), nil
}
-func schemaToFB(b *flatbuffers.Builder, schema *arrow.Schema, memo *dictMemo) flatbuffers.UOffsetT {
+func schemaToFB(b *flatbuffers.Builder, schema *arrow.Schema, memo *dictutils.Mapper) flatbuffers.UOffsetT {
fields := make([]flatbuffers.UOffsetT, len(schema.Fields()))
+ pos := dictutils.NewFieldPos()
for i, field := range schema.Fields() {
- fields[i] = fieldToFB(b, field, memo)
+ fields[i] = fieldToFB(b, pos.Child(int32(i)), field, memo)
}
flatbuf.SchemaStartFieldsVector(b, len(fields))
@@ -923,73 +938,12 @@ func schemaToFB(b *flatbuffers.Builder, schema *arrow.Schema, memo *dictMemo) fl
return offset
}
-func dictTypesFromFB(schema *flatbuf.Schema) (dictTypeMap, error) {
- var (
- err error
- fields = make(dictTypeMap, schema.FieldsLength())
- )
- for i := 0; i < schema.FieldsLength(); i++ {
- var field flatbuf.Field
- if !schema.Fields(&field, i) {
- return nil, fmt.Errorf("arrow/ipc: could not load field %d from schema", i)
- }
- fields, err = visitField(&field, fields)
- if err != nil {
- return nil, fmt.Errorf("arrow/ipc: could not visit field %d from schema: %w", i, err)
- }
- }
- return fields, err
-}
-
-func visitField(field *flatbuf.Field, dict dictTypeMap) (dictTypeMap, error) {
- var err error
- meta := field.Dictionary(nil)
- switch meta {
- case nil:
- // field is not dictionary encoded.
- // => visit children.
- for i := 0; i < field.ChildrenLength(); i++ {
- var child flatbuf.Field
- if !field.Children(&child, i) {
- return nil, fmt.Errorf("arrow/ipc: could not visit child %d from field", i)
- }
- dict, err = visitField(&child, dict)
- if err != nil {
- return nil, err
- }
- }
- default:
- // field is dictionary encoded.
- // construct the data type for the dictionary: no descendants can be dict-encoded.
- dfield, err := fieldFromFBDict(field)
- if err != nil {
- return nil, fmt.Errorf("arrow/ipc: could not create data type for dictionary: %w", err)
- }
- dict[meta.Id()] = dfield
- }
- return dict, err
-}
-
-// payloadsFromSchema returns a slice of payloads corresponding to the given schema.
-// Callers of payloadsFromSchema will need to call Release after use.
-func payloadsFromSchema(schema *arrow.Schema, mem memory.Allocator, memo *dictMemo) payloads {
- dict := newMemo()
-
- ps := make(payloads, 1, dict.Len()+1)
+// payloadFromSchema returns a slice of payloads corresponding to the given schema.
+// Callers of payloadFromSchema will need to call Release after use.
+func payloadFromSchema(schema *arrow.Schema, mem memory.Allocator, memo *dictutils.Mapper) payloads {
+ ps := make(payloads, 1)
ps[0].msg = MessageSchema
- ps[0].meta = writeSchemaMessage(schema, mem, &dict)
-
- // append dictionaries.
- if dict.Len() > 0 {
- panic("payloads-from-schema: not-implemented")
- // for id, arr := range dict.id2dict {
- // // GetSchemaPayloads: writer.cc:535
- // }
- }
-
- if memo != nil {
- *memo = dict
- }
+ ps[0].meta = writeSchemaMessage(schema, mem, memo)
return ps
}
@@ -1015,7 +969,7 @@ func writeMessageFB(b *flatbuffers.Builder, mem memory.Allocator, hdrType flatbu
return writeFBBuilder(b, mem)
}
-func writeSchemaMessage(schema *arrow.Schema, mem memory.Allocator, dict *dictMemo) *memory.Buffer {
+func writeSchemaMessage(schema *arrow.Schema, mem memory.Allocator, dict *dictutils.Mapper) *memory.Buffer {
b := flatbuffers.NewBuilder(1024)
schemaFB := schemaToFB(b, schema, dict)
return writeMessageFB(b, mem, flatbuf.MessageHeaderSchema, schemaFB, 0)
@@ -1024,8 +978,9 @@ func writeSchemaMessage(schema *arrow.Schema, mem memory.Allocator, dict *dictMe
func writeFileFooter(schema *arrow.Schema, dicts, recs []fileBlock, w io.Writer) error {
var (
b = flatbuffers.NewBuilder(1024)
- memo = newMemo()
+ memo dictutils.Mapper
)
+ memo.ImportSchema(schema)
schemaFB := schemaToFB(b, schema, &memo)
dictsFB := fileBlocksToFB(b, dicts, flatbuf.FooterStartDictionariesVector)
@@ -1050,6 +1005,18 @@ func writeRecordMessage(mem memory.Allocator, size, bodyLength int64, fields []f
return writeMessageFB(b, mem, flatbuf.MessageHeaderRecordBatch, recFB, bodyLength)
}
+func writeDictionaryMessage(mem memory.Allocator, id int64, isDelta bool, size, bodyLength int64, fields []fieldMetadata, meta []bufferMetadata, codec flatbuf.CompressionType) *memory.Buffer {
+ b := flatbuffers.NewBuilder(0)
+ recFB := recordToFB(b, size, bodyLength, fields, meta, codec)
+
+ flatbuf.DictionaryBatchStart(b)
+ flatbuf.DictionaryBatchAddId(b, id)
+ flatbuf.DictionaryBatchAddData(b, recFB)
+ flatbuf.DictionaryBatchAddIsDelta(b, isDelta)
+ dictFB := flatbuf.DictionaryBatchEnd(b)
+ return writeMessageFB(b, mem, flatbuf.MessageHeaderDictionaryBatch, dictFB, bodyLength)
+}
+
func recordToFB(b *flatbuffers.Builder, size, bodyLength int64, fields []fieldMetadata, meta []bufferMetadata, codec flatbuf.CompressionType) flatbuffers.UOffsetT {
fieldsFB := writeFieldNodes(b, fields, flatbuf.RecordBatchStartNodesVector)
metaFB := writeBuffers(b, meta, flatbuf.RecordBatchStartBuffersVector)
diff --git a/go/arrow/ipc/metadata_test.go b/go/arrow/ipc/metadata_test.go
index 4de17766a6..cc8e613b4c 100644
--- a/go/arrow/ipc/metadata_test.go
+++ b/go/arrow/ipc/metadata_test.go
@@ -23,6 +23,7 @@ import (
"github.com/apache/arrow/go/v8/arrow"
"github.com/apache/arrow/go/v8/arrow/array"
+ "github.com/apache/arrow/go/v8/arrow/internal/dictutils"
"github.com/apache/arrow/go/v8/arrow/internal/flatbuf"
"github.com/apache/arrow/go/v8/arrow/internal/testing/types"
"github.com/apache/arrow/go/v8/arrow/memory"
@@ -34,7 +35,7 @@ func TestRWSchema(t *testing.T) {
meta := arrow.NewMetadata([]string{"k1", "k2", "k3"}, []string{"v1", "v2", "v3"})
for _, tc := range []struct {
schema *arrow.Schema
- memo dictMemo
+ memo dictutils.Memo
}{
{
schema: arrow.NewSchema([]arrow.Field{
@@ -42,13 +43,14 @@ func TestRWSchema(t *testing.T) {
{Name: "f2", Type: arrow.PrimitiveTypes.Uint16},
{Name: "f3", Type: arrow.PrimitiveTypes.Float64},
}, &meta),
- memo: newMemo(),
+ memo: dictutils.Memo{},
},
} {
t.Run("", func(t *testing.T) {
b := flatbuffers.NewBuilder(0)
- offset := schemaToFB(b, tc.schema, &tc.memo)
+ tc.memo.Mapper.ImportSchema(tc.schema)
+ offset := schemaToFB(b, tc.schema, &tc.memo.Mapper)
b.Finish(offset)
buf := b.FinishedBytes()
diff --git a/go/arrow/ipc/reader.go b/go/arrow/ipc/reader.go
index e3275359b0..dc3a6d6db8 100644
--- a/go/arrow/ipc/reader.go
+++ b/go/arrow/ipc/reader.go
@@ -26,6 +26,7 @@ import (
"github.com/apache/arrow/go/v8/arrow"
"github.com/apache/arrow/go/v8/arrow/array"
"github.com/apache/arrow/go/v8/arrow/internal/debug"
+ "github.com/apache/arrow/go/v8/arrow/internal/dictutils"
"github.com/apache/arrow/go/v8/arrow/internal/flatbuf"
"github.com/apache/arrow/go/v8/arrow/memory"
)
@@ -41,8 +42,9 @@ type Reader struct {
rec arrow.Record
err error
- types dictTypeMap
- memo dictMemo
+ // types dictTypeMap
+ memo dictutils.Memo
+ readInitialDicts bool
mem memory.Allocator
@@ -61,9 +63,9 @@ func NewReaderFromMessageReader(r MessageReader, opts ...Option) (*Reader, error
rr := &Reader{
r: r,
refCount: 1,
- types: make(dictTypeMap),
- memo: newMemo(),
- mem: cfg.alloc,
+ // types: make(dictTypeMap),
+ memo: dictutils.NewMemo(),
+ mem: cfg.alloc,
}
err := rr.readSchema(cfg.schema)
@@ -99,17 +101,6 @@ func (r *Reader) readSchema(schema *arrow.Schema) error {
var schemaFB flatbuf.Schema
initFB(&schemaFB, msg.msg.Header)
- r.types, err = dictTypesFromFB(&schemaFB)
- if err != nil {
- return fmt.Errorf("arrow/ipc: could read dictionary types from message schema: %w", err)
- }
-
- // TODO(sbinet): in the future, we may want to reconcile IDs in the stream with
- // those found in the schema.
- for range r.types {
- panic("not implemented") // FIXME(sbinet): ReadNextDictionary
- }
-
r.schema, err = schemaFromFB(&schemaFB, &r.memo)
if err != nil {
return fmt.Errorf("arrow/ipc: could not decode schema from message schema: %w", err)
@@ -161,9 +152,54 @@ func (r *Reader) Next() bool {
return r.next()
}
+func (r *Reader) getInitialDicts() bool {
+ var msg *Message
+ // we have to get all dictionaries before reconstructing the first
+ // record. subsequent deltas and replacements modify the memo
+ numDicts := r.memo.Mapper.NumDicts()
+ // there should be numDicts dictionary messages
+ for i := 0; i < numDicts; i++ {
+ msg, r.err = r.r.Message()
+ if r.err != nil {
+ r.done = true
+ if r.err == io.EOF {
+ if i == 0 {
+ r.err = nil
+ } else {
+ r.err = fmt.Errorf("arrow/ipc: IPC stream ended without reading the expected (%d) dictionaries", numDicts)
+ }
+ }
+ return false
+ }
+
+ if msg.Type() != MessageDictionaryBatch {
+ r.err = fmt.Errorf("arrow/ipc: IPC stream did not have the expected (%d) dictionaries at the start of the stream", numDicts)
+ }
+ if _, err := readDictionary(&r.memo, msg.meta, bytes.NewReader(msg.body.Bytes()), r.mem); err != nil {
+ r.done = true
+ r.err = err
+ return false
+ }
+ }
+ r.readInitialDicts = true
+ return true
+}
+
func (r *Reader) next() bool {
+ if !r.readInitialDicts && !r.getInitialDicts() {
+ return false
+ }
+
var msg *Message
msg, r.err = r.r.Message()
+
+ for msg != nil && msg.Type() == MessageDictionaryBatch {
+ if _, r.err = readDictionary(&r.memo, msg.meta, bytes.NewReader(msg.body.Bytes()), r.mem); r.err != nil {
+ r.done = true
+ return false
+ }
+ msg, r.err = r.r.Message()
+ }
if r.err != nil {
r.done = true
if errors.Is(r.err, io.EOF) {
@@ -177,7 +213,7 @@ func (r *Reader) next() bool {
return false
}
- r.rec = newRecord(r.schema, msg.meta, bytes.NewReader(msg.body.Bytes()), r.mem)
+ r.rec = newRecord(r.schema, &r.memo, msg.meta, bytes.NewReader(msg.body.Bytes()), r.mem)
return true
}
diff --git a/go/arrow/ipc/writer.go b/go/arrow/ipc/writer.go
index bf96952789..f2496bf1d7 100644
--- a/go/arrow/ipc/writer.go
+++ b/go/arrow/ipc/writer.go
@@ -20,6 +20,7 @@ import (
"bytes"
"context"
"encoding/binary"
+ "errors"
"fmt"
"io"
"math"
@@ -28,6 +29,7 @@ import (
"github.com/apache/arrow/go/v8/arrow"
"github.com/apache/arrow/go/v8/arrow/array"
"github.com/apache/arrow/go/v8/arrow/bitutil"
+ "github.com/apache/arrow/go/v8/arrow/internal/dictutils"
"github.com/apache/arrow/go/v8/arrow/internal/flatbuf"
"github.com/apache/arrow/go/v8/arrow/memory"
)
@@ -57,6 +59,18 @@ func (w *swriter) Write(p []byte) (int, error) {
return n, err
}
+func hasNestedDict(data arrow.ArrayData) bool {
+ if data.DataType().ID() == arrow.DICTIONARY {
+ return true
+ }
+ for _, c := range data.Children() {
+ if hasNestedDict(c) {
+ return true
+ }
+ }
+ return false
+}
+
// Writer is an Arrow stream writer.
type Writer struct {
w io.Writer
@@ -66,8 +80,14 @@ type Writer struct {
started bool
schema *arrow.Schema
+ mapper dictutils.Mapper
codec flatbuf.CompressionType
compressNP int
+
+ // map of the last written dictionaries by id
+ // so we can avoid writing the same dictionary over and over
+ lastWrittenDicts map[int64]arrow.Array
+ emitDictDeltas bool
}
// NewWriterWithPayloadWriter constructs a writer with the provided payload writer
@@ -114,6 +134,10 @@ func (w *Writer) Close() error {
}
w.pw = nil
+ for _, d := range w.lastWrittenDicts {
+ d.Release()
+ }
+
return nil
}
@@ -137,6 +161,12 @@ func (w *Writer) Write(rec arrow.Record) error {
)
defer data.Release()
+ err := writeDictionaryPayloads(w.mem, rec, false, w.emitDictDeltas, &w.mapper, w.lastWrittenDicts, w.pw, enc)
+ if err != nil {
+ return fmt.Errorf("arrow/ipc: failure writing dictionary batches: %w", err)
+ }
+
+ enc.reset()
if err := enc.Encode(&data, rec); err != nil {
return fmt.Errorf("arrow/ipc: could not encode record to payload: %w", err)
}
@@ -144,11 +174,80 @@ func (w *Writer) Write(rec arrow.Record) error {
return w.pw.WritePayload(data)
}
+func writeDictionaryPayloads(mem memory.Allocator, batch arrow.Record, isFileFormat bool, emitDictDeltas bool, mapper *dictutils.Mapper, lastWrittenDicts map[int64]arrow.Array, pw PayloadWriter, encoder *recordEncoder) error {
+ dictionaries, err := dictutils.CollectDictionaries(batch, mapper)
+ if err != nil {
+ return err
+ }
+ defer func() {
+ for _, d := range dictionaries {
+ d.Dict.Release()
+ }
+ }()
+
+ eqopt := array.WithNaNsEqual(true)
+ for _, pair := range dictionaries {
+ encoder.reset()
+ var (
+ deltaStart int64
+ enc = dictEncoder{encoder}
+ )
+ lastDict, exists := lastWrittenDicts[pair.ID]
+ if exists {
+ if lastDict.Data() == pair.Dict.Data() {
+ continue
+ }
+ newLen, lastLen := pair.Dict.Len(), lastDict.Len()
+ if lastLen == newLen && array.ArrayApproxEqual(lastDict, pair.Dict, eqopt) {
+ // same dictionary by value
+ // might cost CPU, but required for IPC file format
+ continue
+ }
+ if isFileFormat {
+ return errors.New("arrow/ipc: Dictionary replacement detected when writing IPC file format. Arrow IPC File only supports single dictionary per field")
+ }
+
+ if newLen > lastLen &&
+ emitDictDeltas &&
+ !hasNestedDict(pair.Dict.Data()) &&
+ (array.ArraySliceApproxEqual(lastDict, 0, int64(lastLen), pair.Dict, 0, int64(lastLen), eqopt)) {
+ deltaStart = int64(lastLen)
+ }
+ }
+
+ var data = Payload{msg: MessageDictionaryBatch}
+ defer data.Release()
+
+ dict := pair.Dict
+ if deltaStart > 0 {
+ dict = array.NewSlice(dict, deltaStart, int64(dict.Len()))
+ defer dict.Release()
+ }
+ if err := enc.Encode(&data, pair.ID, deltaStart > 0, dict); err != nil {
+ return err
+ }
+
+ if err := pw.WritePayload(data); err != nil {
+ return err
+ }
+
+ lastWrittenDicts[pair.ID] = pair.Dict
+ if lastDict != nil {
+ lastDict.Release()
+ }
+ pair.Dict.Retain()
+ }
+ return nil
+}
+
func (w *Writer) start() error {
w.started = true
+ w.mapper.ImportSchema(w.schema)
+ w.lastWrittenDicts = make(map[int64]arrow.Array)
+
// write out schema payloads
- ps := payloadsFromSchema(w.schema, w.mem, nil)
+ ps := payloadFromSchema(w.schema, w.mem, &w.mapper)
defer ps.Release()
for _, data := range ps {
@@ -161,6 +260,31 @@ func (w *Writer) start() error {
return nil
}
+type dictEncoder struct {
+ *recordEncoder
+}
+
+func (d *dictEncoder) encodeMetadata(p *Payload, isDelta bool, id, nrows int64) error {
+ p.meta = writeDictionaryMessage(d.mem, id, isDelta, nrows, p.size, d.fields, d.meta, d.codec)
+ return nil
+}
+
+func (d *dictEncoder) Encode(p *Payload, id int64, isDelta bool, dict arrow.Array) error {
+ d.start = 0
+ defer func() {
+ d.start = 0
+ }()
+
+ schema := arrow.NewSchema([]arrow.Field{{Name: "dictionary", Type: dict.DataType(), Nullable: true}}, nil)
+ batch := array.NewRecord(schema, []arrow.Array{dict}, int64(dict.Len()))
+ defer batch.Release()
+ if err := d.encode(p, batch); err != nil {
+ return err
+ }
+
+ return d.encodeMetadata(p, isDelta, id, batch.NumRows())
+}
+
type recordEncoder struct {
mem memory.Allocator
@@ -185,6 +309,11 @@ func newRecordEncoder(mem memory.Allocator, startOffset, maxDepth int64, allow64
}
}
+func (w *recordEncoder) reset() {
+ w.start = 0
+ w.fields = make([]fieldMetadata, 0)
+}
+
func (w *recordEncoder) compressBodyBuffers(p *Payload) error {
compress := func(idx int, codec compressor) error {
if p.body[idx] == nil || p.body[idx].Len() == 0 {
@@ -261,7 +390,7 @@ func (w *recordEncoder) compressBodyBuffers(p *Payload) error {
return <-errch
}
-func (w *recordEncoder) Encode(p *Payload, rec arrow.Record) error {
+func (w *recordEncoder) encode(p *Payload, rec arrow.Record) error {
// perform depth-first traversal of the row-batch
for i, col := range rec.Columns() {
@@ -305,7 +434,7 @@ func (w *recordEncoder) Encode(p *Payload, rec arrow.Record) error {
panic("not aligned")
}
- return w.encodeMetadata(p, rec.NumRows())
+ return nil
}
func (w *recordEncoder) visit(p *Payload, arr arrow.Array) error {
@@ -326,6 +455,11 @@ func (w *recordEncoder) visit(p *Payload, arr arrow.Array) error {
return nil
}
+ if arr.DataType().ID() == arrow.DICTIONARY {
+ arr := arr.(*array.Dictionary)
+ return w.visit(p, arr.Indices())
+ }
+
// add all common elements
w.fields = append(w.fields, fieldMetadata{
Len: int64(arr.Len()),
@@ -597,6 +731,13 @@ func (w *recordEncoder) getZeroBasedValueOffsets(arr arrow.Array) (*memory.Buffe
return voffsets, nil
}
+// Encode writes the body of rec into p and, when that succeeds, attaches the
+// record-batch metadata for rec.NumRows() rows. The first error encountered
+// is returned.
+func (w *recordEncoder) Encode(p *Payload, rec arrow.Record) error {
+	err := w.encode(p, rec)
+	if err == nil {
+		err = w.encodeMetadata(p, rec.NumRows())
+	}
+	return err
+}
+
func (w *recordEncoder) encodeMetadata(p *Payload, nrows int64) error {
p.meta = writeRecordMessage(w.mem, nrows, p.size, w.fields, w.meta, w.codec)
return nil
diff --git a/go/go.mod b/go/go.mod
index df95ab28e0..54daf232b1 100644
--- a/go/go.mod
+++ b/go/go.mod
@@ -23,7 +23,7 @@ require (
github.com/andybalholm/brotli v1.0.4
github.com/apache/thrift v0.15.0
github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815
- github.com/goccy/go-json v0.7.10
+ github.com/goccy/go-json v0.9.6
github.com/golang/snappy v0.0.4
github.com/google/flatbuffers v2.0.5+incompatible
github.com/google/go-cmp v0.5.7 // indirect
diff --git a/go/go.sum b/go/go.sum
index 894b9acac1..dfb88a1215 100644
--- a/go/go.sum
+++ b/go/go.sum
@@ -103,6 +103,8 @@ github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/goccy/go-json v0.7.10 h1:ulhbuNe1JqE68nMRXXTJRrUu0uhouf0VevLINxQq4Ec=
github.com/goccy/go-json v0.7.10/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
+github.com/goccy/go-json v0.9.6 h1:5/4CtRQdtsX0sal8fdVhTaiMN01Ri8BExZZ8iRmHQ6E=
+github.com/goccy/go-json v0.9.6/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
diff --git a/go/parquet/internal/hashing/hashing_test.go b/go/internal/hashing/hashing_test.go
similarity index 100%
rename from go/parquet/internal/hashing/hashing_test.go
rename to go/internal/hashing/hashing_test.go
diff --git a/go/internal/hashing/types.tmpldata b/go/internal/hashing/types.tmpldata
new file mode 100644
index 0000000000..0ba6f765d2
--- /dev/null
+++ b/go/internal/hashing/types.tmpldata
@@ -0,0 +1,42 @@
+[
+ {
+ "Name": "Int8",
+ "name": "int8"
+ },
+ {
+ "Name": "Uint8",
+ "name": "uint8"
+ },
+ {
+ "Name": "Int16",
+ "name": "int16"
+ },
+ {
+ "Name": "Uint16",
+ "name": "uint16"
+ },
+ {
+ "Name": "Int32",
+ "name": "int32"
+ },
+ {
+ "Name": "Int64",
+ "name": "int64"
+ },
+ {
+ "Name": "Uint32",
+ "name": "uint32"
+ },
+ {
+ "Name": "Uint64",
+ "name": "uint64"
+ },
+ {
+ "Name": "Float32",
+ "name": "float32"
+ },
+ {
+ "Name": "Float64",
+ "name": "float64"
+ }
+]
diff --git a/go/internal/hashing/xxh3_memo_table.gen.go b/go/internal/hashing/xxh3_memo_table.gen.go
new file mode 100644
index 0000000000..b20cbf0d27
--- /dev/null
+++ b/go/internal/hashing/xxh3_memo_table.gen.go
@@ -0,0 +1,2783 @@
+// Code generated by xxh3_memo_table.gen.go.tmpl. DO NOT EDIT.
+
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hashing
+
+import (
+ "math"
+
+ "github.com/apache/arrow/go/v8/arrow"
+ "github.com/apache/arrow/go/v8/arrow/bitutil"
+ "github.com/apache/arrow/go/v8/internal/utils"
+)
+
+// payloadInt8 is the data stored in a hash table slot: the int8 value and
+// the memo index that was supplied for it when it was inserted.
+type payloadInt8 struct {
+ val int8
+ memoIdx int32
+}
+
+// entryInt8 is one slot of an Int8HashTable, pairing the stored hash with
+// its payload.
+type entryInt8 struct {
+ h uint64
+ payload payloadInt8
+}
+
+// Valid reports whether the slot is occupied, i.e. whether its stored hash
+// differs from the sentinel value.
+func (e entryInt8) Valid() bool { return e.h != sentinel }
+
+// Int8HashTable is a hashtable specifically for int8 that
+// is utilized with the MemoTable to generalize interactions for easier
+// implementation of dictionaries without losing performance.
+type Int8HashTable struct {
+ // cap is the number of slots in entries; capMask is cap-1 and is used to
+ // reduce a hash to a slot index.
+ cap uint64
+ capMask uint64
+ // size is incremented once per Insert.
+ size uint64
+
+ entries []entryInt8
+}
+
+// NewInt8HashTable constructs an empty hash table for int8 values. The slot
+// count is max(cap, 32) passed through bitutil.NextPowerOf2.
+func NewInt8HashTable(cap uint64) *Int8HashTable {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	return &Int8HashTable{
+		cap:     slots,
+		capMask: slots - 1,
+		entries: make([]entryInt8, slots),
+	}
+}
+
+// Reset discards every stored value and re-initializes the table in place,
+// sizing it exactly as NewInt8HashTable would for the same cap argument.
+func (h *Int8HashTable) Reset(cap uint64) {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	h.cap, h.capMask, h.size = slots, slots-1, 0
+	h.entries = make([]entryInt8, slots)
+}
+
+// CopyValues writes every stored value into out, placing each value at the
+// position given by its memo index.
+func (h *Int8HashTable) CopyValues(out []int8) {
+	h.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset writes the stored values whose memo index is at least
+// start into out; a value with memo index i lands at out[i-start].
+func (h *Int8HashTable) CopyValuesSubset(start int, out []int8) {
+	first := int32(start)
+	h.VisitEntries(func(ent *entryInt8) {
+		if pos := ent.payload.memoIdx - first; pos >= 0 {
+			out[pos] = ent.payload.val
+		}
+	})
+}
+
+// WriteOut writes every stored value into the byte slice out, which is
+// viewed as []int8 through arrow.Int8Traits.CastFromBytes.
+func (h *Int8HashTable) WriteOut(out []byte) {
+	h.WriteOutSubset(0, out)
+}
+
+// WriteOutSubset is WriteOut restricted to the values whose memo index is at
+// least start; a value with memo index i is stored at element i-start.
+func (h *Int8HashTable) WriteOutSubset(start int, out []byte) {
+	dst := arrow.Int8Traits.CastFromBytes(out)
+	first := int32(start)
+	h.VisitEntries(func(ent *entryInt8) {
+		if pos := ent.payload.memoIdx - first; pos >= 0 {
+			dst[pos] = ent.payload.val
+		}
+	})
+}
+
+// needUpsize reports whether size scaled by loadFactor has reached cap.
+func (h *Int8HashTable) needUpsize() bool {
+	return h.cap <= h.size*uint64(loadFactor)
+}
+
+// fixHash maps a hash equal to the sentinel (the value that marks an empty
+// slot) to 42 so a stored hash never equals it; other hashes pass through.
+func (Int8HashTable) fixHash(v uint64) uint64 {
+	if v != sentinel {
+		return v
+	}
+	return 42
+}
+
+// Lookup probes the table for hash v and returns the first entry whose stored
+// hash matches and whose value satisfies cmp, together with true. If the
+// probe reaches an empty slot first, it returns a pointer to that slot and
+// false; the slot holds no value, so it is only useful as the target of
+// Insert.
+func (h *Int8HashTable) Lookup(v uint64, cmp func(int8) bool) (*entryInt8, bool) {
+	slot, found := h.lookup(v, h.capMask, cmp)
+	return &h.entries[slot], found
+}
+
+// lookup probes the entries slice for hash v, using szMask to reduce indices,
+// and returns the slot index at which the probe stopped plus whether it
+// stopped on a match. A match is a slot whose stored hash equals the fixed-up
+// v and whose value satisfies cmp; otherwise the probe stops at the first
+// slot holding the sentinel hash. szMask is a parameter rather than h.capMask
+// so that upsize can probe with the new mask before h.capMask is updated.
+// NOTE(review): the loop exits only on a match or an empty slot, so it
+// presumably relies on the table never being completely full - confirm.
+func (h *Int8HashTable) lookup(v uint64, szMask uint64, cmp func(int8) bool) (uint64, bool) {
+ const perturbShift uint8 = 5
+
+ var (
+ idx uint64
+ perturb uint64
+ e *entryInt8
+ )
+
+ v = h.fixHash(v)
+ idx = v & szMask
+ perturb = (v >> uint64(perturbShift)) + 1
+
+ for {
+ e = &h.entries[idx]
+ if e.h == v && cmp(e.payload.val) {
+ return idx, true
+ }
+
+ if e.h == sentinel {
+ return idx, false
+ }
+
+ // perturbation logic inspired from CPython's set/dict object
+ // the goal is that all 64 bits of unmasked hash value eventually
+ // participate in the probing sequence, to minimize clustering
+ idx = (idx + perturb) & szMask
+ perturb = (perturb >> uint64(perturbShift)) + 1
+ }
+}
+
+// upsize replaces entries with a fresh slice of newcap slots, re-places every
+// occupied slot of the old slice into it using the stored hash, and then
+// records the new cap and capMask. It always returns nil.
+func (h *Int8HashTable) upsize(newcap uint64) error {
+	mask := newcap - 1
+	// Always-false comparator: the probe skips occupied slots and stops at
+	// the first empty one.
+	noMatch := func(int8) bool { return false }
+
+	prev := h.entries
+	h.entries = make([]entryInt8, newcap)
+	for i := range prev {
+		if !prev[i].Valid() {
+			continue
+		}
+		slot, _ := h.lookup(prev[i].h, mask, noMatch)
+		h.entries[slot] = prev[i]
+	}
+
+	h.cap, h.capMask = newcap, mask
+	return nil
+}
+
+// Insert updates the given entry with the provided hash value, payload value and memo index.
+// The entry pointer must have been retrieved via Lookup in order to actually insert properly.
+//
+// If the insertion pushes the table over its load threshold the table is
+// grown, which allocates a new entries slice: e then refers to the old slice
+// and must not be reused. Any error from growing the table is returned
+// instead of being discarded.
+func (h *Int8HashTable) Insert(e *entryInt8, v uint64, val int8, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if h.needUpsize() {
+		return h.upsize(h.cap * uint64(loadFactor) * 2)
+	}
+	return nil
+}
+
+// VisitEntries will call the passed in function on each *valid* entry in the hash table,
+// a valid entry being one which has had a value inserted into it.
+//
+// visit receives a pointer to the slot itself rather than to a copy held in a
+// loop variable, so nothing is copied per entry and the pointer refers to
+// real table storage. The slice is captured up front, so the walk covers the
+// entries that were present when VisitEntries was called.
+func (h *Int8HashTable) VisitEntries(visit func(*entryInt8)) {
+	entries := h.entries
+	for i := range entries {
+		if e := &entries[i]; e.Valid() {
+			visit(e)
+		}
+	}
+}
+
+// Int8MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.
+type Int8MemoTable struct {
+ tbl *Int8HashTable
+ // nullIdx is the memo index assigned to the null entry, or KeyNotFound
+ // while no null has been inserted.
+ nullIdx int32
+}
+
+// NewInt8MemoTable returns an empty memo table whose hash table is created
+// with num as its requested capacity and with no null entry recorded.
+func NewInt8MemoTable(num int64) *Int8MemoTable {
+	memo := &Int8MemoTable{nullIdx: KeyNotFound}
+	memo.tbl = NewInt8HashTable(uint64(num))
+	return memo
+}
+
+// TypeTraits returns arrow.Int8Traits, the traits for the value type held by
+// this table.
+func (Int8MemoTable) TypeTraits() TypeTraits {
+ return arrow.Int8Traits
+}
+
+// Reset empties the table so it can be reused: the null index is cleared and
+// the underlying hash table is re-initialized with a capacity request of 32.
+func (s *Int8MemoTable) Reset() {
+	s.nullIdx = KeyNotFound
+	s.tbl.Reset(32)
+}
+
+// Size returns the number of values held in the hash table, plus one when a
+// null entry has been recorded.
+func (s *Int8MemoTable) Size() int {
+	n := int(s.tbl.size)
+	if s.nullIdx != KeyNotFound {
+		n++
+	}
+	return n
+}
+
+// GetNull returns the memo index recorded for null and true, or KeyNotFound
+// and false when no null has been inserted.
+func (s *Int8MemoTable) GetNull() (int, bool) {
+	present := s.nullIdx != KeyNotFound
+	return int(s.nullIdx), present
+}
+
+// GetOrInsertNull returns the memo index of the null entry. found is true
+// when a null was already recorded; otherwise the null is assigned the next
+// index (the current Size) and found is false.
+func (s *Int8MemoTable) GetOrInsertNull() (idx int, found bool) {
+	if idx, found = s.GetNull(); found {
+		return idx, true
+	}
+	idx = s.Size()
+	s.nullIdx = int32(idx)
+	return idx, false
+}
+
+// CopyValues copies all values in the memo table into out, which must be a
+// []int8.
+func (s *Int8MemoTable) CopyValues(out interface{}) {
+	s.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset copies the values whose memo index is at least start into
+// out, which must be a []int8; any other dynamic type panics.
+func (s *Int8MemoTable) CopyValuesSubset(start int, out interface{}) {
+	dst := out.([]int8)
+	s.tbl.CopyValuesSubset(start, dst)
+}
+
+// WriteOut copies all values into out, viewed as []int8 through
+// arrow.Int8Traits.CastFromBytes.
+func (s *Int8MemoTable) WriteOut(out []byte) {
+	dst := arrow.Int8Traits.CastFromBytes(out)
+	s.tbl.CopyValues(dst)
+}
+
+// WriteOutSubset is WriteOut limited to values whose memo index is at least
+// start.
+func (s *Int8MemoTable) WriteOutSubset(start int, out []byte) {
+	dst := arrow.Int8Traits.CastFromBytes(out)
+	s.tbl.CopyValuesSubset(start, dst)
+}
+
+// WriteOutLE writes all values into out by delegating to the hash table's
+// WriteOut.
+func (s *Int8MemoTable) WriteOutLE(out []byte) {
+ s.tbl.WriteOut(out)
+}
+
+// WriteOutSubsetLE is WriteOutLE limited to values whose memo index is at
+// least start; it delegates to the hash table's WriteOutSubset.
+func (s *Int8MemoTable) WriteOutSubsetLE(start int, out []byte) {
+ s.tbl.WriteOutSubset(start, out)
+}
+
+// Get looks val up and returns its memo index and true, or KeyNotFound and
+// false when it is absent. val must be an int8; any other dynamic type
+// panics on the type assertion.
+func (s *Int8MemoTable) Get(val interface{}) (int, bool) {
+	want := val.(int8)
+	ent, found := s.tbl.Lookup(hashInt(uint64(want), 0), func(got int8) bool { return got == want })
+	if !found {
+		return KeyNotFound, false
+	}
+	return int(ent.payload.memoIdx), true
+}
+
+// GetOrInsert will return the index of the specified value in the table, or insert the
+// value into the table and return the new index. found indicates whether or not it already
+// existed in the table (true) or was inserted by this call (false). err carries any error
+// reported by the underlying hash table insert. val must be an int8; any other dynamic
+// type panics on the type assertion.
+func (s *Int8MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+	v := val.(int8)
+	h := hashInt(uint64(v), 0)
+	e, ok := s.tbl.Lookup(h, func(other int8) bool { return v == other })
+	if ok {
+		return int(e.payload.memoIdx), true, nil
+	}
+
+	// Not present: the new value takes the next memo index.
+	idx = s.Size()
+	err = s.tbl.Insert(e, h, v, int32(idx))
+	return idx, false, err
+}
+
+// payloadUint8 is the data stored in a hash table slot: the uint8 value and
+// the memo index that was supplied for it when it was inserted.
+type payloadUint8 struct {
+ val uint8
+ memoIdx int32
+}
+
+// entryUint8 is one slot of a Uint8HashTable, pairing the stored hash with
+// its payload.
+type entryUint8 struct {
+ h uint64
+ payload payloadUint8
+}
+
+// Valid reports whether the slot is occupied, i.e. whether its stored hash
+// differs from the sentinel value.
+func (e entryUint8) Valid() bool { return e.h != sentinel }
+
+// Uint8HashTable is a hashtable specifically for uint8 that
+// is utilized with the MemoTable to generalize interactions for easier
+// implementation of dictionaries without losing performance.
+type Uint8HashTable struct {
+ // cap is the number of slots in entries; capMask is cap-1 and is used to
+ // reduce a hash to a slot index.
+ cap uint64
+ capMask uint64
+ // size is incremented once per Insert.
+ size uint64
+
+ entries []entryUint8
+}
+
+// NewUint8HashTable constructs an empty hash table for uint8 values. The slot
+// count is max(cap, 32) passed through bitutil.NextPowerOf2.
+func NewUint8HashTable(cap uint64) *Uint8HashTable {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	return &Uint8HashTable{
+		cap:     slots,
+		capMask: slots - 1,
+		entries: make([]entryUint8, slots),
+	}
+}
+
+// Reset discards every stored value and re-initializes the table in place,
+// sizing it exactly as NewUint8HashTable would for the same cap argument.
+func (h *Uint8HashTable) Reset(cap uint64) {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	h.cap, h.capMask, h.size = slots, slots-1, 0
+	h.entries = make([]entryUint8, slots)
+}
+
+// CopyValues writes every stored value into out, placing each value at the
+// position given by its memo index.
+func (h *Uint8HashTable) CopyValues(out []uint8) {
+	h.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset writes the stored values whose memo index is at least
+// start into out; a value with memo index i lands at out[i-start].
+func (h *Uint8HashTable) CopyValuesSubset(start int, out []uint8) {
+	first := int32(start)
+	h.VisitEntries(func(ent *entryUint8) {
+		if pos := ent.payload.memoIdx - first; pos >= 0 {
+			out[pos] = ent.payload.val
+		}
+	})
+}
+
+// WriteOut writes every stored value into the byte slice out, which is
+// viewed as []uint8 through arrow.Uint8Traits.CastFromBytes.
+func (h *Uint8HashTable) WriteOut(out []byte) {
+	h.WriteOutSubset(0, out)
+}
+
+// WriteOutSubset is WriteOut restricted to the values whose memo index is at
+// least start; a value with memo index i is stored at element i-start.
+func (h *Uint8HashTable) WriteOutSubset(start int, out []byte) {
+	dst := arrow.Uint8Traits.CastFromBytes(out)
+	first := int32(start)
+	h.VisitEntries(func(ent *entryUint8) {
+		if pos := ent.payload.memoIdx - first; pos >= 0 {
+			dst[pos] = ent.payload.val
+		}
+	})
+}
+
+// needUpsize reports whether size scaled by loadFactor has reached cap.
+func (h *Uint8HashTable) needUpsize() bool {
+	return h.cap <= h.size*uint64(loadFactor)
+}
+
+// fixHash maps a hash equal to the sentinel (the value that marks an empty
+// slot) to 42 so a stored hash never equals it; other hashes pass through.
+func (Uint8HashTable) fixHash(v uint64) uint64 {
+	if v != sentinel {
+		return v
+	}
+	return 42
+}
+
+// Lookup probes the table for hash v and returns the first entry whose stored
+// hash matches and whose value satisfies cmp, together with true. If the
+// probe reaches an empty slot first, it returns a pointer to that slot and
+// false; the slot holds no value, so it is only useful as the target of
+// Insert.
+func (h *Uint8HashTable) Lookup(v uint64, cmp func(uint8) bool) (*entryUint8, bool) {
+	slot, found := h.lookup(v, h.capMask, cmp)
+	return &h.entries[slot], found
+}
+
+// lookup probes the entries slice for hash v, using szMask to reduce indices,
+// and returns the slot index at which the probe stopped plus whether it
+// stopped on a match. A match is a slot whose stored hash equals the fixed-up
+// v and whose value satisfies cmp; otherwise the probe stops at the first
+// slot holding the sentinel hash. szMask is a parameter rather than h.capMask
+// so that upsize can probe with the new mask before h.capMask is updated.
+// NOTE(review): the loop exits only on a match or an empty slot, so it
+// presumably relies on the table never being completely full - confirm.
+func (h *Uint8HashTable) lookup(v uint64, szMask uint64, cmp func(uint8) bool) (uint64, bool) {
+ const perturbShift uint8 = 5
+
+ var (
+ idx uint64
+ perturb uint64
+ e *entryUint8
+ )
+
+ v = h.fixHash(v)
+ idx = v & szMask
+ perturb = (v >> uint64(perturbShift)) + 1
+
+ for {
+ e = &h.entries[idx]
+ if e.h == v && cmp(e.payload.val) {
+ return idx, true
+ }
+
+ if e.h == sentinel {
+ return idx, false
+ }
+
+ // perturbation logic inspired from CPython's set/dict object
+ // the goal is that all 64 bits of unmasked hash value eventually
+ // participate in the probing sequence, to minimize clustering
+ idx = (idx + perturb) & szMask
+ perturb = (perturb >> uint64(perturbShift)) + 1
+ }
+}
+
+// upsize replaces entries with a fresh slice of newcap slots, re-places every
+// occupied slot of the old slice into it using the stored hash, and then
+// records the new cap and capMask. It always returns nil.
+func (h *Uint8HashTable) upsize(newcap uint64) error {
+	mask := newcap - 1
+	// Always-false comparator: the probe skips occupied slots and stops at
+	// the first empty one.
+	noMatch := func(uint8) bool { return false }
+
+	prev := h.entries
+	h.entries = make([]entryUint8, newcap)
+	for i := range prev {
+		if !prev[i].Valid() {
+			continue
+		}
+		slot, _ := h.lookup(prev[i].h, mask, noMatch)
+		h.entries[slot] = prev[i]
+	}
+
+	h.cap, h.capMask = newcap, mask
+	return nil
+}
+
+// Insert updates the given entry with the provided hash value, payload value and memo index.
+// The entry pointer must have been retrieved via Lookup in order to actually insert properly.
+//
+// If the insertion pushes the table over its load threshold the table is
+// grown, which allocates a new entries slice: e then refers to the old slice
+// and must not be reused. Any error from growing the table is returned
+// instead of being discarded.
+func (h *Uint8HashTable) Insert(e *entryUint8, v uint64, val uint8, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if h.needUpsize() {
+		return h.upsize(h.cap * uint64(loadFactor) * 2)
+	}
+	return nil
+}
+
+// VisitEntries will call the passed in function on each *valid* entry in the hash table,
+// a valid entry being one which has had a value inserted into it.
+//
+// visit receives a pointer to the slot itself rather than to a copy held in a
+// loop variable, so nothing is copied per entry and the pointer refers to
+// real table storage. The slice is captured up front, so the walk covers the
+// entries that were present when VisitEntries was called.
+func (h *Uint8HashTable) VisitEntries(visit func(*entryUint8)) {
+	entries := h.entries
+	for i := range entries {
+		if e := &entries[i]; e.Valid() {
+			visit(e)
+		}
+	}
+}
+
+// Uint8MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.
+type Uint8MemoTable struct {
+ tbl *Uint8HashTable
+ // nullIdx is the memo index assigned to the null entry, or KeyNotFound
+ // while no null has been inserted.
+ nullIdx int32
+}
+
+// NewUint8MemoTable returns an empty memo table whose hash table is created
+// with num as its requested capacity and with no null entry recorded.
+func NewUint8MemoTable(num int64) *Uint8MemoTable {
+	memo := &Uint8MemoTable{nullIdx: KeyNotFound}
+	memo.tbl = NewUint8HashTable(uint64(num))
+	return memo
+}
+
+// TypeTraits returns arrow.Uint8Traits, the traits for the value type held
+// by this table.
+func (Uint8MemoTable) TypeTraits() TypeTraits {
+ return arrow.Uint8Traits
+}
+
+// Reset empties the table so it can be reused: the null index is cleared and
+// the underlying hash table is re-initialized with a capacity request of 32.
+func (s *Uint8MemoTable) Reset() {
+	s.nullIdx = KeyNotFound
+	s.tbl.Reset(32)
+}
+
+// Size returns the number of values held in the hash table, plus one when a
+// null entry has been recorded.
+func (s *Uint8MemoTable) Size() int {
+	n := int(s.tbl.size)
+	if s.nullIdx != KeyNotFound {
+		n++
+	}
+	return n
+}
+
+// GetNull returns the memo index recorded for null and true, or KeyNotFound
+// and false when no null has been inserted.
+func (s *Uint8MemoTable) GetNull() (int, bool) {
+	present := s.nullIdx != KeyNotFound
+	return int(s.nullIdx), present
+}
+
+// GetOrInsertNull returns the memo index of the null entry. found is true
+// when a null was already recorded; otherwise the null is assigned the next
+// index (the current Size) and found is false.
+func (s *Uint8MemoTable) GetOrInsertNull() (idx int, found bool) {
+	if idx, found = s.GetNull(); found {
+		return idx, true
+	}
+	idx = s.Size()
+	s.nullIdx = int32(idx)
+	return idx, false
+}
+
+// CopyValues copies all values in the memo table into out, which must be a
+// []uint8.
+func (s *Uint8MemoTable) CopyValues(out interface{}) {
+	s.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset copies the values whose memo index is at least start into
+// out, which must be a []uint8; any other dynamic type panics.
+func (s *Uint8MemoTable) CopyValuesSubset(start int, out interface{}) {
+	dst := out.([]uint8)
+	s.tbl.CopyValuesSubset(start, dst)
+}
+
+// WriteOut copies all values into out, viewed as []uint8 through
+// arrow.Uint8Traits.CastFromBytes.
+func (s *Uint8MemoTable) WriteOut(out []byte) {
+	dst := arrow.Uint8Traits.CastFromBytes(out)
+	s.tbl.CopyValues(dst)
+}
+
+// WriteOutSubset is WriteOut limited to values whose memo index is at least
+// start.
+func (s *Uint8MemoTable) WriteOutSubset(start int, out []byte) {
+	dst := arrow.Uint8Traits.CastFromBytes(out)
+	s.tbl.CopyValuesSubset(start, dst)
+}
+
+// WriteOutLE writes all values into out by delegating to the hash table's
+// WriteOut.
+func (s *Uint8MemoTable) WriteOutLE(out []byte) {
+ s.tbl.WriteOut(out)
+}
+
+// WriteOutSubsetLE is WriteOutLE limited to values whose memo index is at
+// least start; it delegates to the hash table's WriteOutSubset.
+func (s *Uint8MemoTable) WriteOutSubsetLE(start int, out []byte) {
+ s.tbl.WriteOutSubset(start, out)
+}
+
+// Get looks val up and returns its memo index and true, or KeyNotFound and
+// false when it is absent. val must be a uint8; any other dynamic type
+// panics on the type assertion.
+func (s *Uint8MemoTable) Get(val interface{}) (int, bool) {
+	want := val.(uint8)
+	ent, found := s.tbl.Lookup(hashInt(uint64(want), 0), func(got uint8) bool { return got == want })
+	if !found {
+		return KeyNotFound, false
+	}
+	return int(ent.payload.memoIdx), true
+}
+
+// GetOrInsert will return the index of the specified value in the table, or insert the
+// value into the table and return the new index. found indicates whether or not it already
+// existed in the table (true) or was inserted by this call (false). err carries any error
+// reported by the underlying hash table insert. val must be a uint8; any other dynamic
+// type panics on the type assertion.
+func (s *Uint8MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+	v := val.(uint8)
+	h := hashInt(uint64(v), 0)
+	e, ok := s.tbl.Lookup(h, func(other uint8) bool { return v == other })
+	if ok {
+		return int(e.payload.memoIdx), true, nil
+	}
+
+	// Not present: the new value takes the next memo index.
+	idx = s.Size()
+	err = s.tbl.Insert(e, h, v, int32(idx))
+	return idx, false, err
+}
+
+// payloadInt16 is the data stored in a hash table slot: the int16 value and
+// the memo index that was supplied for it when it was inserted.
+type payloadInt16 struct {
+ val int16
+ memoIdx int32
+}
+
+// entryInt16 is one slot of an Int16HashTable, pairing the stored hash with
+// its payload.
+type entryInt16 struct {
+ h uint64
+ payload payloadInt16
+}
+
+// Valid reports whether the slot is occupied, i.e. whether its stored hash
+// differs from the sentinel value.
+func (e entryInt16) Valid() bool { return e.h != sentinel }
+
+// Int16HashTable is a hashtable specifically for int16 that
+// is utilized with the MemoTable to generalize interactions for easier
+// implementation of dictionaries without losing performance.
+type Int16HashTable struct {
+ // cap is the number of slots in entries; capMask is cap-1 and is used to
+ // reduce a hash to a slot index.
+ cap uint64
+ capMask uint64
+ // size is incremented once per Insert.
+ size uint64
+
+ entries []entryInt16
+}
+
+// NewInt16HashTable constructs an empty hash table for int16 values. The slot
+// count is max(cap, 32) passed through bitutil.NextPowerOf2.
+func NewInt16HashTable(cap uint64) *Int16HashTable {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	return &Int16HashTable{
+		cap:     slots,
+		capMask: slots - 1,
+		entries: make([]entryInt16, slots),
+	}
+}
+
+// Reset discards every stored value and re-initializes the table in place,
+// sizing it exactly as NewInt16HashTable would for the same cap argument.
+func (h *Int16HashTable) Reset(cap uint64) {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	h.cap, h.capMask, h.size = slots, slots-1, 0
+	h.entries = make([]entryInt16, slots)
+}
+
+// CopyValues writes every stored value into out, placing each value at the
+// position given by its memo index.
+func (h *Int16HashTable) CopyValues(out []int16) {
+	h.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset writes the stored values whose memo index is at least
+// start into out; a value with memo index i lands at out[i-start].
+func (h *Int16HashTable) CopyValuesSubset(start int, out []int16) {
+	first := int32(start)
+	h.VisitEntries(func(ent *entryInt16) {
+		if pos := ent.payload.memoIdx - first; pos >= 0 {
+			out[pos] = ent.payload.val
+		}
+	})
+}
+
+// WriteOut writes every stored value into the byte slice out, which is
+// viewed as []int16 through arrow.Int16Traits.CastFromBytes. Each value is
+// passed through utils.ToLEInt16 before being stored.
+func (h *Int16HashTable) WriteOut(out []byte) {
+	h.WriteOutSubset(0, out)
+}
+
+// WriteOutSubset is WriteOut restricted to the values whose memo index is at
+// least start; a value with memo index i is stored at element i-start.
+func (h *Int16HashTable) WriteOutSubset(start int, out []byte) {
+	dst := arrow.Int16Traits.CastFromBytes(out)
+	first := int32(start)
+	h.VisitEntries(func(ent *entryInt16) {
+		if pos := ent.payload.memoIdx - first; pos >= 0 {
+			dst[pos] = utils.ToLEInt16(ent.payload.val)
+		}
+	})
+}
+
+// needUpsize reports whether size scaled by loadFactor has reached cap.
+func (h *Int16HashTable) needUpsize() bool {
+	return h.cap <= h.size*uint64(loadFactor)
+}
+
+// fixHash maps a hash equal to the sentinel (the value that marks an empty
+// slot) to 42 so a stored hash never equals it; other hashes pass through.
+func (Int16HashTable) fixHash(v uint64) uint64 {
+	if v != sentinel {
+		return v
+	}
+	return 42
+}
+
+// Lookup probes the table for hash v and returns the first entry whose stored
+// hash matches and whose value satisfies cmp, together with true. If the
+// probe reaches an empty slot first, it returns a pointer to that slot and
+// false; the slot holds no value, so it is only useful as the target of
+// Insert.
+func (h *Int16HashTable) Lookup(v uint64, cmp func(int16) bool) (*entryInt16, bool) {
+	slot, found := h.lookup(v, h.capMask, cmp)
+	return &h.entries[slot], found
+}
+
+// lookup probes the entries slice for hash v, using szMask to reduce indices,
+// and returns the slot index at which the probe stopped plus whether it
+// stopped on a match. A match is a slot whose stored hash equals the fixed-up
+// v and whose value satisfies cmp; otherwise the probe stops at the first
+// slot holding the sentinel hash. szMask is a parameter rather than h.capMask
+// so that upsize can probe with the new mask before h.capMask is updated.
+// NOTE(review): the loop exits only on a match or an empty slot, so it
+// presumably relies on the table never being completely full - confirm.
+func (h *Int16HashTable) lookup(v uint64, szMask uint64, cmp func(int16) bool) (uint64, bool) {
+ const perturbShift uint8 = 5
+
+ var (
+ idx uint64
+ perturb uint64
+ e *entryInt16
+ )
+
+ v = h.fixHash(v)
+ idx = v & szMask
+ perturb = (v >> uint64(perturbShift)) + 1
+
+ for {
+ e = &h.entries[idx]
+ if e.h == v && cmp(e.payload.val) {
+ return idx, true
+ }
+
+ if e.h == sentinel {
+ return idx, false
+ }
+
+ // perturbation logic inspired from CPython's set/dict object
+ // the goal is that all 64 bits of unmasked hash value eventually
+ // participate in the probing sequence, to minimize clustering
+ idx = (idx + perturb) & szMask
+ perturb = (perturb >> uint64(perturbShift)) + 1
+ }
+}
+
+// upsize replaces entries with a fresh slice of newcap slots, re-places every
+// occupied slot of the old slice into it using the stored hash, and then
+// records the new cap and capMask. It always returns nil.
+func (h *Int16HashTable) upsize(newcap uint64) error {
+	mask := newcap - 1
+	// Always-false comparator: the probe skips occupied slots and stops at
+	// the first empty one.
+	noMatch := func(int16) bool { return false }
+
+	prev := h.entries
+	h.entries = make([]entryInt16, newcap)
+	for i := range prev {
+		if !prev[i].Valid() {
+			continue
+		}
+		slot, _ := h.lookup(prev[i].h, mask, noMatch)
+		h.entries[slot] = prev[i]
+	}
+
+	h.cap, h.capMask = newcap, mask
+	return nil
+}
+
+// Insert updates the given entry with the provided hash value, payload value and memo index.
+// The entry pointer must have been retrieved via Lookup in order to actually insert properly.
+//
+// If the insertion pushes the table over its load threshold the table is
+// grown, which allocates a new entries slice: e then refers to the old slice
+// and must not be reused. Any error from growing the table is returned
+// instead of being discarded.
+func (h *Int16HashTable) Insert(e *entryInt16, v uint64, val int16, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if h.needUpsize() {
+		return h.upsize(h.cap * uint64(loadFactor) * 2)
+	}
+	return nil
+}
+
+// VisitEntries will call the passed in function on each *valid* entry in the hash table,
+// a valid entry being one which has had a value inserted into it.
+//
+// visit receives a pointer to the slot itself rather than to a copy held in a
+// loop variable, so nothing is copied per entry and the pointer refers to
+// real table storage. The slice is captured up front, so the walk covers the
+// entries that were present when VisitEntries was called.
+func (h *Int16HashTable) VisitEntries(visit func(*entryInt16)) {
+	entries := h.entries
+	for i := range entries {
+		if e := &entries[i]; e.Valid() {
+			visit(e)
+		}
+	}
+}
+
+// Int16MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.
+type Int16MemoTable struct {
+ tbl *Int16HashTable
+ // nullIdx is the memo index assigned to the null entry, or KeyNotFound
+ // while no null has been inserted.
+ nullIdx int32
+}
+
+// NewInt16MemoTable returns an empty memo table whose hash table is created
+// with num as its requested capacity and with no null entry recorded.
+func NewInt16MemoTable(num int64) *Int16MemoTable {
+	memo := &Int16MemoTable{nullIdx: KeyNotFound}
+	memo.tbl = NewInt16HashTable(uint64(num))
+	return memo
+}
+
+// TypeTraits returns arrow.Int16Traits, the traits for the value type held
+// by this table.
+func (Int16MemoTable) TypeTraits() TypeTraits {
+ return arrow.Int16Traits
+}
+
+// Reset empties the table so it can be reused: the null index is cleared and
+// the underlying hash table is re-initialized with a capacity request of 32.
+func (s *Int16MemoTable) Reset() {
+	s.nullIdx = KeyNotFound
+	s.tbl.Reset(32)
+}
+
+// Size returns the number of values held in the hash table, plus one when a
+// null entry has been recorded.
+func (s *Int16MemoTable) Size() int {
+	n := int(s.tbl.size)
+	if s.nullIdx != KeyNotFound {
+		n++
+	}
+	return n
+}
+
+// GetNull returns the memo index recorded for null and true, or KeyNotFound
+// and false when no null has been inserted.
+func (s *Int16MemoTable) GetNull() (int, bool) {
+	present := s.nullIdx != KeyNotFound
+	return int(s.nullIdx), present
+}
+
+// GetOrInsertNull returns the memo index of the null entry. found is true
+// when a null was already recorded; otherwise the null is assigned the next
+// index (the current Size) and found is false.
+func (s *Int16MemoTable) GetOrInsertNull() (idx int, found bool) {
+	if idx, found = s.GetNull(); found {
+		return idx, true
+	}
+	idx = s.Size()
+	s.nullIdx = int32(idx)
+	return idx, false
+}
+
+// CopyValues copies all values in the memo table into out, which must be a
+// []int16.
+func (s *Int16MemoTable) CopyValues(out interface{}) {
+	s.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset copies the values whose memo index is at least start into
+// out, which must be a []int16; any other dynamic type panics.
+func (s *Int16MemoTable) CopyValuesSubset(start int, out interface{}) {
+	dst := out.([]int16)
+	s.tbl.CopyValuesSubset(start, dst)
+}
+
+// WriteOut copies all values into out, viewed as []int16 through
+// arrow.Int16Traits.CastFromBytes. Values are stored as held in the table,
+// without the utils.ToLEInt16 conversion that WriteOutLE applies.
+func (s *Int16MemoTable) WriteOut(out []byte) {
+	dst := arrow.Int16Traits.CastFromBytes(out)
+	s.tbl.CopyValues(dst)
+}
+
+// WriteOutSubset is WriteOut limited to values whose memo index is at least
+// start.
+func (s *Int16MemoTable) WriteOutSubset(start int, out []byte) {
+	dst := arrow.Int16Traits.CastFromBytes(out)
+	s.tbl.CopyValuesSubset(start, dst)
+}
+
+// WriteOutLE writes all values into out by delegating to the hash table's
+// WriteOut, which passes each value through utils.ToLEInt16.
+func (s *Int16MemoTable) WriteOutLE(out []byte) {
+ s.tbl.WriteOut(out)
+}
+
+// WriteOutSubsetLE is WriteOutLE limited to values whose memo index is at
+// least start; it delegates to the hash table's WriteOutSubset.
+func (s *Int16MemoTable) WriteOutSubsetLE(start int, out []byte) {
+ s.tbl.WriteOutSubset(start, out)
+}
+
+// Get looks val up and returns its memo index and true, or KeyNotFound and
+// false when it is absent. val must be an int16; any other dynamic type
+// panics on the type assertion.
+func (s *Int16MemoTable) Get(val interface{}) (int, bool) {
+	want := val.(int16)
+	ent, found := s.tbl.Lookup(hashInt(uint64(want), 0), func(got int16) bool { return got == want })
+	if !found {
+		return KeyNotFound, false
+	}
+	return int(ent.payload.memoIdx), true
+}
+
+// GetOrInsert will return the index of the specified value in the table, or insert the
+// value into the table and return the new index. found indicates whether or not it already
+// existed in the table (true) or was inserted by this call (false). err carries any error
+// reported by the underlying hash table insert. val must be an int16; any other dynamic
+// type panics on the type assertion.
+func (s *Int16MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+	v := val.(int16)
+	h := hashInt(uint64(v), 0)
+	e, ok := s.tbl.Lookup(h, func(other int16) bool { return v == other })
+	if ok {
+		return int(e.payload.memoIdx), true, nil
+	}
+
+	// Not present: the new value takes the next memo index.
+	idx = s.Size()
+	err = s.tbl.Insert(e, h, v, int32(idx))
+	return idx, false, err
+}
+
+// payloadUint16 is the data stored in a hash table slot: the uint16 value and
+// the memo index that was supplied for it when it was inserted.
+type payloadUint16 struct {
+ val uint16
+ memoIdx int32
+}
+
+// entryUint16 is one slot of a Uint16HashTable, pairing the stored hash with
+// its payload.
+type entryUint16 struct {
+ h uint64
+ payload payloadUint16
+}
+
+// Valid reports whether the slot is occupied, i.e. whether its stored hash
+// differs from the sentinel value.
+func (e entryUint16) Valid() bool { return e.h != sentinel }
+
+// Uint16HashTable is a hashtable specifically for uint16 that
+// is utilized with the MemoTable to generalize interactions for easier
+// implementation of dictionaries without losing performance.
+type Uint16HashTable struct {
+ // cap is the number of slots in entries; capMask is cap-1 and is used to
+ // reduce a hash to a slot index.
+ cap uint64
+ capMask uint64
+ // size is incremented once per Insert.
+ size uint64
+
+ entries []entryUint16
+}
+
+// NewUint16HashTable constructs an empty hash table for uint16 values. The
+// slot count is max(cap, 32) passed through bitutil.NextPowerOf2.
+func NewUint16HashTable(cap uint64) *Uint16HashTable {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	return &Uint16HashTable{
+		cap:     slots,
+		capMask: slots - 1,
+		entries: make([]entryUint16, slots),
+	}
+}
+
+// Reset discards every stored value and re-initializes the table in place,
+// sizing it exactly as NewUint16HashTable would for the same cap argument.
+func (h *Uint16HashTable) Reset(cap uint64) {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	h.cap, h.capMask, h.size = slots, slots-1, 0
+	h.entries = make([]entryUint16, slots)
+}
+
+// CopyValues writes every stored value into out, placing each value at the
+// position given by its memo index.
+func (h *Uint16HashTable) CopyValues(out []uint16) {
+	h.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset writes the stored values whose memo index is at least
+// start into out; a value with memo index i lands at out[i-start].
+func (h *Uint16HashTable) CopyValuesSubset(start int, out []uint16) {
+	first := int32(start)
+	h.VisitEntries(func(ent *entryUint16) {
+		if pos := ent.payload.memoIdx - first; pos >= 0 {
+			out[pos] = ent.payload.val
+		}
+	})
+}
+
+// WriteOut writes every stored value into the byte slice out, which is
+// viewed as []uint16 through arrow.Uint16Traits.CastFromBytes. Each value is
+// passed through utils.ToLEUint16 before being stored.
+func (h *Uint16HashTable) WriteOut(out []byte) {
+	h.WriteOutSubset(0, out)
+}
+
+// WriteOutSubset is WriteOut restricted to the values whose memo index is at
+// least start; a value with memo index i is stored at element i-start.
+func (h *Uint16HashTable) WriteOutSubset(start int, out []byte) {
+	dst := arrow.Uint16Traits.CastFromBytes(out)
+	first := int32(start)
+	h.VisitEntries(func(ent *entryUint16) {
+		if pos := ent.payload.memoIdx - first; pos >= 0 {
+			dst[pos] = utils.ToLEUint16(ent.payload.val)
+		}
+	})
+}
+
+// needUpsize reports whether size scaled by loadFactor has reached cap.
+func (h *Uint16HashTable) needUpsize() bool {
+	return h.cap <= h.size*uint64(loadFactor)
+}
+
+// fixHash maps a hash equal to the sentinel (the value that marks an empty
+// slot) to 42 so a stored hash never equals it; other hashes pass through.
+func (Uint16HashTable) fixHash(v uint64) uint64 {
+	if v != sentinel {
+		return v
+	}
+	return 42
+}
+
+// Lookup probes the table for hash v and returns the first entry whose stored
+// hash matches and whose value satisfies cmp, together with true. If the
+// probe reaches an empty slot first, it returns a pointer to that slot and
+// false; the slot holds no value, so it is only useful as the target of
+// Insert.
+func (h *Uint16HashTable) Lookup(v uint64, cmp func(uint16) bool) (*entryUint16, bool) {
+	slot, found := h.lookup(v, h.capMask, cmp)
+	return &h.entries[slot], found
+}
+
+// lookup runs the probe sequence for hash v over h.entries with szMask as the
+// index mask. It returns the index of the first slot that either matches (same
+// hash and cmp is true; second result true) or is empty (second result false).
+func (h *Uint16HashTable) lookup(v uint64, szMask uint64, cmp func(uint16) bool) (uint64, bool) {
+	const perturbShift uint8 = 5
+
+	v = h.fixHash(v)
+	pos := v & szMask
+	step := (v >> uint64(perturbShift)) + 1
+
+	for {
+		slot := &h.entries[pos]
+		switch {
+		case slot.h == v && cmp(slot.payload.val):
+			return pos, true
+		case slot.h == sentinel:
+			return pos, false
+		}
+
+		// Perturbation scheme modelled on CPython's set/dict objects: the
+		// step starts from the upper hash bits so that they also take part
+		// in the probing sequence, to minimize clustering.
+		pos = (pos + step) & szMask
+		step = (step >> uint64(perturbShift)) + 1
+	}
+}
+
+// upsize replaces the entries slice with one of newcap slots and re-inserts
+// every valid entry at the slot found by probing with its stored hash. The
+// mask is derived as newcap - 1. It always returns nil.
+func (h *Uint16HashTable) upsize(newcap uint64) error {
+	mask := newcap - 1
+	neverEqual := func(uint16) bool { return false }
+
+	previous := h.entries
+	// lookup reads h.entries, so the new slice has to be in place first.
+	h.entries = make([]entryUint16, newcap)
+	for i := range previous {
+		if !previous[i].Valid() {
+			continue
+		}
+		// cmp never matches, so the probe stops at the first empty slot.
+		slot, _ := h.lookup(previous[i].h, mask, neverEqual)
+		h.entries[slot] = previous[i]
+	}
+	h.cap, h.capMask = newcap, mask
+	return nil
+}
+
+// Insert fills the slot e with the sentinel-adjusted hash v, the value val and
+// its memo index, then increments size. e must be the slot returned by Lookup
+// for v when no match was found. If needUpsize then reports true, the table
+// grows to cap*loadFactor*2 slots and any error from that step is returned
+// instead of being discarded.
+func (h *Uint16HashTable) Insert(e *entryUint16, v uint64, val uint16, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if !h.needUpsize() {
+		return nil
+	}
+	// upsize replaces h.entries, so e must not be used after this point.
+	return h.upsize(h.cap * uint64(loadFactor) * 2)
+}
+
+// VisitEntries calls visit once for every valid entry, i.e. every slot that
+// has had a value inserted. The pointer handed to visit addresses the slot in
+// the entries slice itself rather than the range loop variable, so it is
+// distinct per entry and does not alias a per-iteration copy.
+func (h *Uint16HashTable) VisitEntries(visit func(*entryUint16)) {
+	// Snapshot the slice header so the walk covers the slots present on entry.
+	entries := h.entries
+	for i := range entries {
+		if slot := &entries[i]; slot.Valid() {
+			visit(slot)
+		}
+	}
+}
+
+// Uint16MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.
+type Uint16MemoTable struct {
+ // tbl holds the non-null values together with their memo indexes.
+ tbl *Uint16HashTable
+ // nullIdx is the memo index assigned to null, or KeyNotFound while no null
+ // has been inserted.
+ nullIdx int32
+}
+
+// NewUint16MemoTable creates a memo table backed by
+// NewUint16HashTable(uint64(num)), with no null entry recorded.
+func NewUint16MemoTable(num int64) *Uint16MemoTable {
+	table := NewUint16HashTable(uint64(num))
+	return &Uint16MemoTable{tbl: table, nullIdx: KeyNotFound}
+}
+
+// TypeTraits returns arrow.Uint16Traits.
+func (Uint16MemoTable) TypeTraits() TypeTraits {
+	traits := arrow.Uint16Traits
+	return traits
+}
+
+// Reset allows this table to be re-used by dumping all the data currently in the table.
+func (s *Uint16MemoTable) Reset() {
+ s.tbl.Reset(32)
+ s.nullIdx = KeyNotFound
+}
+
+// Size reports how many entries the table holds: the hash table's size plus
+// one if a null has been inserted.
+func (s *Uint16MemoTable) Size() int {
+	count := int(s.tbl.size)
+	if s.nullIdx != KeyNotFound {
+		count++
+	}
+	return count
+}
+
+// GetNull returns the stored null index and true when a null has been
+// inserted, or KeyNotFound and false otherwise.
+func (s *Uint16MemoTable) GetNull() (int, bool) {
+	hasNull := s.nullIdx != KeyNotFound
+	return int(s.nullIdx), hasNull
+}
+
+// GetOrInsertNull returns the index of the null entry. If none exists yet, the
+// next index (the current Size) is assigned to null and found is false;
+// otherwise the existing index is returned with found set to true.
+func (s *Uint16MemoTable) GetOrInsertNull() (idx int, found bool) {
+	if idx, found = s.GetNull(); found {
+		return idx, true
+	}
+	idx = s.Size()
+	s.nullIdx = int32(idx)
+	return idx, false
+}
+
+// CopyValues will copy the values from the memo table out into the passed in slice
+// which must be of the appropriate type.
+func (s *Uint16MemoTable) CopyValues(out interface{}) {
+ s.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset copies the values whose memo index is at least start into
+// out, which must be a []uint16; any other dynamic type makes the type
+// assertion panic.
+func (s *Uint16MemoTable) CopyValuesSubset(start int, out interface{}) {
+	dst := out.([]uint16)
+	s.tbl.CopyValuesSubset(start, dst)
+}
+
+// WriteOut copies every stored value, unconverted, into out viewed as a
+// []uint16 through arrow.Uint16Traits.CastFromBytes.
+func (s *Uint16MemoTable) WriteOut(out []byte) {
+	dst := arrow.Uint16Traits.CastFromBytes(out)
+	s.tbl.CopyValues(dst)
+}
+
+// WriteOutSubset copies the values whose memo index is at least start,
+// unconverted, into out viewed as a []uint16 through
+// arrow.Uint16Traits.CastFromBytes.
+func (s *Uint16MemoTable) WriteOutSubset(start int, out []byte) {
+	dst := arrow.Uint16Traits.CastFromBytes(out)
+	s.tbl.CopyValuesSubset(start, dst)
+}
+
+// WriteOutLE writes every stored value into out by delegating to the hash
+// table's WriteOut.
+func (s *Uint16MemoTable) WriteOutLE(out []byte) {
+	table := s.tbl
+	table.WriteOut(out)
+}
+
+// WriteOutSubsetLE writes the values starting at memo index start into out by
+// delegating to the hash table's WriteOutSubset.
+func (s *Uint16MemoTable) WriteOutSubsetLE(start int, out []byte) {
+	table := s.tbl
+	table.WriteOutSubset(start, out)
+}
+
+// Get looks val up in the table. val must hold a uint16; any other dynamic
+// type makes the type assertion panic. It returns the value's memo index and
+// true, or KeyNotFound and false if the value is absent.
+func (s *Uint16MemoTable) Get(val interface{}) (int, bool) {
+	key := val.(uint16)
+	matches := func(stored uint16) bool { return stored == key }
+
+	entry, found := s.tbl.Lookup(hashInt(uint64(key), 0), matches)
+	if !found {
+		return KeyNotFound, false
+	}
+	return int(entry.payload.memoIdx), true
+}
+
+// GetOrInsert returns the memo index of val, inserting it with the next index
+// (the current Size) when it is not present yet. val must hold a uint16; any
+// other dynamic type makes the type assertion panic. found is true when the
+// value already existed and false when this call inserted it. err carries any
+// error reported by the hash table's Insert instead of dropping it.
+func (s *Uint16MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+	key := val.(uint16)
+	hash := hashInt(uint64(key), 0)
+
+	entry, ok := s.tbl.Lookup(hash, func(stored uint16) bool { return stored == key })
+	if ok {
+		return int(entry.payload.memoIdx), true, nil
+	}
+
+	idx = s.Size()
+	err = s.tbl.Insert(entry, hash, key, int32(idx))
+	return idx, false, err
+}
+
+// payloadInt32 is the value stored in a hash table slot together with its memo index.
+type payloadInt32 struct {
+ // val is the stored int32 value.
+ val int32
+ // memoIdx is the memo index supplied when the value was inserted.
+ memoIdx int32
+}
+
+// entryInt32 is a single slot of the Int32HashTable.
+type entryInt32 struct {
+ // h is the stored hash; sentinel marks a slot that holds no value.
+ h uint64
+ // payload carries the value and memo index for this slot.
+ payload payloadInt32
+}
+
+// Valid reports whether the slot holds an inserted value, i.e. its hash is not sentinel.
+func (e entryInt32) Valid() bool {
+	return sentinel != e.h
+}
+
+// Int32HashTable is a hashtable specifically for int32 that
+// is utilized with the MemoTable to generalize interactions for easier
+// implementation of dictionaries without losing performance.
+type Int32HashTable struct {
+ // cap is the number of slots in entries.
+ cap uint64
+ // capMask is cap - 1, used to mask a hash down to a slot index.
+ capMask uint64
+ // size is the number of values inserted so far.
+ size uint64
+
+ // entries is the slot storage; a slot whose hash equals sentinel is empty.
+ entries []entryInt32
+}
+
+// NewInt32HashTable builds an empty hash table for int32 values. The slot
+// count is the requested capacity raised to at least 32 and then passed
+// through bitutil.NextPowerOf2; capMask is that count minus one.
+func NewInt32HashTable(cap uint64) *Int32HashTable {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	return &Int32HashTable{
+		cap:     slots,
+		capMask: slots - 1,
+		size:    0,
+		entries: make([]entryInt32, slots),
+	}
+}
+
+// Reset discards every stored value and re-sizes the table for cap using the
+// same rule as NewInt32HashTable. The receiver is reused; a fresh entries
+// slice is allocated.
+func (h *Int32HashTable) Reset(cap uint64) {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	h.cap, h.capMask, h.size = slots, slots-1, 0
+	h.entries = make([]entryInt32, slots)
+}
+
+// CopyValues is used for copying the values out of the hash table into the
+// passed in slice, in the order that they were first inserted
+func (h *Int32HashTable) CopyValues(out []int32) {
+ h.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset writes each stored value whose memo index is at least start
+// into out[memoIdx-start]; values with a smaller memo index are skipped.
+func (h *Int32HashTable) CopyValuesSubset(start int, out []int32) {
+	offset := int32(start)
+	h.VisitEntries(func(ent *entryInt32) {
+		if pos := ent.payload.memoIdx - offset; pos >= 0 {
+			out[pos] = ent.payload.val
+		}
+	})
+}
+
+func (h *Int32HashTable) WriteOut(out []byte) {
+ h.WriteOutSubset(0, out)
+}
+
+// WriteOutSubset views out as a []int32 via arrow.Int32Traits.CastFromBytes
+// and, for each stored value whose memo index is at least start, stores
+// utils.ToLEInt32(value) at position memoIdx-start.
+func (h *Int32HashTable) WriteOutSubset(start int, out []byte) {
+	offset := int32(start)
+	dst := arrow.Int32Traits.CastFromBytes(out)
+	h.VisitEntries(func(ent *entryInt32) {
+		pos := ent.payload.memoIdx - offset
+		if pos < 0 {
+			return
+		}
+		dst[pos] = utils.ToLEInt32(ent.payload.val)
+	})
+}
+
+// needUpsize reports whether size*loadFactor has reached the slot count.
+func (h *Int32HashTable) needUpsize() bool {
+	return h.cap <= h.size*uint64(loadFactor)
+}
+
+// fixHash remaps a hash equal to sentinel to 42, because sentinel in an
+// entry's h field marks an empty slot; every other hash is returned unchanged.
+func (Int32HashTable) fixHash(v uint64) uint64 {
+	if v != sentinel {
+		return v
+	}
+	return 42
+}
+
+// Lookup probes for an entry whose stored hash matches v and whose payload
+// value satisfies cmp. It returns a pointer to the slot where probing stopped
+// and whether a match was found. When the bool is false the pointer refers to
+// an empty slot and must not be read as a stored entry.
+func (h *Int32HashTable) Lookup(v uint64, cmp func(int32) bool) (*entryInt32, bool) {
+	slot, found := h.lookup(v, h.capMask, cmp)
+	return &h.entries[slot], found
+}
+
+// lookup runs the probe sequence for hash v over h.entries with szMask as the
+// index mask. It returns the index of the first slot that either matches (same
+// hash and cmp is true; second result true) or is empty (second result false).
+func (h *Int32HashTable) lookup(v uint64, szMask uint64, cmp func(int32) bool) (uint64, bool) {
+	const perturbShift uint8 = 5
+
+	v = h.fixHash(v)
+	pos := v & szMask
+	step := (v >> uint64(perturbShift)) + 1
+
+	for {
+		slot := &h.entries[pos]
+		switch {
+		case slot.h == v && cmp(slot.payload.val):
+			return pos, true
+		case slot.h == sentinel:
+			return pos, false
+		}
+
+		// Perturbation scheme modelled on CPython's set/dict objects: the
+		// step starts from the upper hash bits so that they also take part
+		// in the probing sequence, to minimize clustering.
+		pos = (pos + step) & szMask
+		step = (step >> uint64(perturbShift)) + 1
+	}
+}
+
+// upsize replaces the entries slice with one of newcap slots and re-inserts
+// every valid entry at the slot found by probing with its stored hash. The
+// mask is derived as newcap - 1. It always returns nil.
+func (h *Int32HashTable) upsize(newcap uint64) error {
+	mask := newcap - 1
+	neverEqual := func(int32) bool { return false }
+
+	previous := h.entries
+	// lookup reads h.entries, so the new slice has to be in place first.
+	h.entries = make([]entryInt32, newcap)
+	for i := range previous {
+		if !previous[i].Valid() {
+			continue
+		}
+		// cmp never matches, so the probe stops at the first empty slot.
+		slot, _ := h.lookup(previous[i].h, mask, neverEqual)
+		h.entries[slot] = previous[i]
+	}
+	h.cap, h.capMask = newcap, mask
+	return nil
+}
+
+// Insert fills the slot e with the sentinel-adjusted hash v, the value val and
+// its memo index, then increments size. e must be the slot returned by Lookup
+// for v when no match was found. If needUpsize then reports true, the table
+// grows to cap*loadFactor*2 slots and any error from that step is returned
+// instead of being discarded.
+func (h *Int32HashTable) Insert(e *entryInt32, v uint64, val int32, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if !h.needUpsize() {
+		return nil
+	}
+	// upsize replaces h.entries, so e must not be used after this point.
+	return h.upsize(h.cap * uint64(loadFactor) * 2)
+}
+
+// VisitEntries calls visit once for every valid entry, i.e. every slot that
+// has had a value inserted. The pointer handed to visit addresses the slot in
+// the entries slice itself rather than the range loop variable, so it is
+// distinct per entry and does not alias a per-iteration copy.
+func (h *Int32HashTable) VisitEntries(visit func(*entryInt32)) {
+	// Snapshot the slice header so the walk covers the slots present on entry.
+	entries := h.entries
+	for i := range entries {
+		if slot := &entries[i]; slot.Valid() {
+			visit(slot)
+		}
+	}
+}
+
+// Int32MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.
+type Int32MemoTable struct {
+ // tbl holds the non-null values together with their memo indexes.
+ tbl *Int32HashTable
+ // nullIdx is the memo index assigned to null, or KeyNotFound while no null
+ // has been inserted.
+ nullIdx int32
+}
+
+// NewInt32MemoTable creates a memo table backed by
+// NewInt32HashTable(uint64(num)), with no null entry recorded.
+func NewInt32MemoTable(num int64) *Int32MemoTable {
+	table := NewInt32HashTable(uint64(num))
+	return &Int32MemoTable{tbl: table, nullIdx: KeyNotFound}
+}
+
+// TypeTraits returns arrow.Int32Traits.
+func (Int32MemoTable) TypeTraits() TypeTraits {
+	traits := arrow.Int32Traits
+	return traits
+}
+
+// Reset allows this table to be re-used by dumping all the data currently in the table.
+func (s *Int32MemoTable) Reset() {
+ s.tbl.Reset(32)
+ s.nullIdx = KeyNotFound
+}
+
+// Size reports how many entries the table holds: the hash table's size plus
+// one if a null has been inserted.
+func (s *Int32MemoTable) Size() int {
+	count := int(s.tbl.size)
+	if s.nullIdx != KeyNotFound {
+		count++
+	}
+	return count
+}
+
+// GetNull returns the stored null index and true when a null has been
+// inserted, or KeyNotFound and false otherwise.
+func (s *Int32MemoTable) GetNull() (int, bool) {
+	hasNull := s.nullIdx != KeyNotFound
+	return int(s.nullIdx), hasNull
+}
+
+// GetOrInsertNull returns the index of the null entry. If none exists yet, the
+// next index (the current Size) is assigned to null and found is false;
+// otherwise the existing index is returned with found set to true.
+func (s *Int32MemoTable) GetOrInsertNull() (idx int, found bool) {
+	if idx, found = s.GetNull(); found {
+		return idx, true
+	}
+	idx = s.Size()
+	s.nullIdx = int32(idx)
+	return idx, false
+}
+
+// CopyValues will copy the values from the memo table out into the passed in slice
+// which must be of the appropriate type.
+func (s *Int32MemoTable) CopyValues(out interface{}) {
+ s.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset copies the values whose memo index is at least start into
+// out, which must be a []int32; any other dynamic type makes the type
+// assertion panic.
+func (s *Int32MemoTable) CopyValuesSubset(start int, out interface{}) {
+	dst := out.([]int32)
+	s.tbl.CopyValuesSubset(start, dst)
+}
+
+// WriteOut copies every stored value, unconverted, into out viewed as a
+// []int32 through arrow.Int32Traits.CastFromBytes.
+func (s *Int32MemoTable) WriteOut(out []byte) {
+	dst := arrow.Int32Traits.CastFromBytes(out)
+	s.tbl.CopyValues(dst)
+}
+
+// WriteOutSubset copies the values whose memo index is at least start,
+// unconverted, into out viewed as a []int32 through
+// arrow.Int32Traits.CastFromBytes.
+func (s *Int32MemoTable) WriteOutSubset(start int, out []byte) {
+	dst := arrow.Int32Traits.CastFromBytes(out)
+	s.tbl.CopyValuesSubset(start, dst)
+}
+
+// WriteOutLE writes every stored value into out by delegating to the hash
+// table's WriteOut.
+func (s *Int32MemoTable) WriteOutLE(out []byte) {
+	table := s.tbl
+	table.WriteOut(out)
+}
+
+// WriteOutSubsetLE writes the values starting at memo index start into out by
+// delegating to the hash table's WriteOutSubset.
+func (s *Int32MemoTable) WriteOutSubsetLE(start int, out []byte) {
+	table := s.tbl
+	table.WriteOutSubset(start, out)
+}
+
+// Get looks val up in the table. val must hold an int32; any other dynamic
+// type makes the type assertion panic. It returns the value's memo index and
+// true, or KeyNotFound and false if the value is absent.
+func (s *Int32MemoTable) Get(val interface{}) (int, bool) {
+	key := val.(int32)
+	matches := func(stored int32) bool { return stored == key }
+
+	entry, found := s.tbl.Lookup(hashInt(uint64(key), 0), matches)
+	if !found {
+		return KeyNotFound, false
+	}
+	return int(entry.payload.memoIdx), true
+}
+
+// GetOrInsert returns the memo index of val, inserting it with the next index
+// (the current Size) when it is not present yet. val must hold an int32; any
+// other dynamic type makes the type assertion panic. found is true when the
+// value already existed and false when this call inserted it. err carries any
+// error reported by the hash table's Insert instead of dropping it.
+func (s *Int32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+	key := val.(int32)
+	hash := hashInt(uint64(key), 0)
+
+	entry, ok := s.tbl.Lookup(hash, func(stored int32) bool { return stored == key })
+	if ok {
+		return int(entry.payload.memoIdx), true, nil
+	}
+
+	idx = s.Size()
+	err = s.tbl.Insert(entry, hash, key, int32(idx))
+	return idx, false, err
+}
+
+// payloadInt64 is the value stored in a hash table slot together with its memo index.
+type payloadInt64 struct {
+ // val is the stored int64 value.
+ val int64
+ // memoIdx is the memo index supplied when the value was inserted.
+ memoIdx int32
+}
+
+// entryInt64 is a single slot of the Int64HashTable.
+type entryInt64 struct {
+ // h is the stored hash; sentinel marks a slot that holds no value.
+ h uint64
+ // payload carries the value and memo index for this slot.
+ payload payloadInt64
+}
+
+// Valid reports whether the slot holds an inserted value, i.e. its hash is not sentinel.
+func (e entryInt64) Valid() bool {
+	return sentinel != e.h
+}
+
+// Int64HashTable is a hashtable specifically for int64 that
+// is utilized with the MemoTable to generalize interactions for easier
+// implementation of dictionaries without losing performance.
+type Int64HashTable struct {
+ // cap is the number of slots in entries.
+ cap uint64
+ // capMask is cap - 1, used to mask a hash down to a slot index.
+ capMask uint64
+ // size is the number of values inserted so far.
+ size uint64
+
+ // entries is the slot storage; a slot whose hash equals sentinel is empty.
+ entries []entryInt64
+}
+
+// NewInt64HashTable builds an empty hash table for int64 values. The slot
+// count is the requested capacity raised to at least 32 and then passed
+// through bitutil.NextPowerOf2; capMask is that count minus one.
+func NewInt64HashTable(cap uint64) *Int64HashTable {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	return &Int64HashTable{
+		cap:     slots,
+		capMask: slots - 1,
+		size:    0,
+		entries: make([]entryInt64, slots),
+	}
+}
+
+// Reset discards every stored value and re-sizes the table for cap using the
+// same rule as NewInt64HashTable. The receiver is reused; a fresh entries
+// slice is allocated.
+func (h *Int64HashTable) Reset(cap uint64) {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	h.cap, h.capMask, h.size = slots, slots-1, 0
+	h.entries = make([]entryInt64, slots)
+}
+
+// CopyValues is used for copying the values out of the hash table into the
+// passed in slice, in the order that they were first inserted
+func (h *Int64HashTable) CopyValues(out []int64) {
+ h.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset writes each stored value whose memo index is at least start
+// into out[memoIdx-start]; values with a smaller memo index are skipped.
+func (h *Int64HashTable) CopyValuesSubset(start int, out []int64) {
+	offset := int32(start)
+	h.VisitEntries(func(ent *entryInt64) {
+		if pos := ent.payload.memoIdx - offset; pos >= 0 {
+			out[pos] = ent.payload.val
+		}
+	})
+}
+
+func (h *Int64HashTable) WriteOut(out []byte) {
+ h.WriteOutSubset(0, out)
+}
+
+// WriteOutSubset views out as a []int64 via arrow.Int64Traits.CastFromBytes
+// and, for each stored value whose memo index is at least start, stores
+// utils.ToLEInt64(value) at position memoIdx-start.
+func (h *Int64HashTable) WriteOutSubset(start int, out []byte) {
+	offset := int32(start)
+	dst := arrow.Int64Traits.CastFromBytes(out)
+	h.VisitEntries(func(ent *entryInt64) {
+		pos := ent.payload.memoIdx - offset
+		if pos < 0 {
+			return
+		}
+		dst[pos] = utils.ToLEInt64(ent.payload.val)
+	})
+}
+
+// needUpsize reports whether size*loadFactor has reached the slot count.
+func (h *Int64HashTable) needUpsize() bool {
+	return h.cap <= h.size*uint64(loadFactor)
+}
+
+// fixHash remaps a hash equal to sentinel to 42, because sentinel in an
+// entry's h field marks an empty slot; every other hash is returned unchanged.
+func (Int64HashTable) fixHash(v uint64) uint64 {
+	if v != sentinel {
+		return v
+	}
+	return 42
+}
+
+// Lookup probes for an entry whose stored hash matches v and whose payload
+// value satisfies cmp. It returns a pointer to the slot where probing stopped
+// and whether a match was found. When the bool is false the pointer refers to
+// an empty slot and must not be read as a stored entry.
+func (h *Int64HashTable) Lookup(v uint64, cmp func(int64) bool) (*entryInt64, bool) {
+	slot, found := h.lookup(v, h.capMask, cmp)
+	return &h.entries[slot], found
+}
+
+// lookup runs the probe sequence for hash v over h.entries with szMask as the
+// index mask. It returns the index of the first slot that either matches (same
+// hash and cmp is true; second result true) or is empty (second result false).
+func (h *Int64HashTable) lookup(v uint64, szMask uint64, cmp func(int64) bool) (uint64, bool) {
+	const perturbShift uint8 = 5
+
+	v = h.fixHash(v)
+	pos := v & szMask
+	step := (v >> uint64(perturbShift)) + 1
+
+	for {
+		slot := &h.entries[pos]
+		switch {
+		case slot.h == v && cmp(slot.payload.val):
+			return pos, true
+		case slot.h == sentinel:
+			return pos, false
+		}
+
+		// Perturbation scheme modelled on CPython's set/dict objects: the
+		// step starts from the upper hash bits so that they also take part
+		// in the probing sequence, to minimize clustering.
+		pos = (pos + step) & szMask
+		step = (step >> uint64(perturbShift)) + 1
+	}
+}
+
+// upsize replaces the entries slice with one of newcap slots and re-inserts
+// every valid entry at the slot found by probing with its stored hash. The
+// mask is derived as newcap - 1. It always returns nil.
+func (h *Int64HashTable) upsize(newcap uint64) error {
+	mask := newcap - 1
+	neverEqual := func(int64) bool { return false }
+
+	previous := h.entries
+	// lookup reads h.entries, so the new slice has to be in place first.
+	h.entries = make([]entryInt64, newcap)
+	for i := range previous {
+		if !previous[i].Valid() {
+			continue
+		}
+		// cmp never matches, so the probe stops at the first empty slot.
+		slot, _ := h.lookup(previous[i].h, mask, neverEqual)
+		h.entries[slot] = previous[i]
+	}
+	h.cap, h.capMask = newcap, mask
+	return nil
+}
+
+// Insert fills the slot e with the sentinel-adjusted hash v, the value val and
+// its memo index, then increments size. e must be the slot returned by Lookup
+// for v when no match was found. If needUpsize then reports true, the table
+// grows to cap*loadFactor*2 slots and any error from that step is returned
+// instead of being discarded.
+func (h *Int64HashTable) Insert(e *entryInt64, v uint64, val int64, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if !h.needUpsize() {
+		return nil
+	}
+	// upsize replaces h.entries, so e must not be used after this point.
+	return h.upsize(h.cap * uint64(loadFactor) * 2)
+}
+
+// VisitEntries calls visit once for every valid entry, i.e. every slot that
+// has had a value inserted. The pointer handed to visit addresses the slot in
+// the entries slice itself rather than the range loop variable, so it is
+// distinct per entry and does not alias a per-iteration copy.
+func (h *Int64HashTable) VisitEntries(visit func(*entryInt64)) {
+	// Snapshot the slice header so the walk covers the slots present on entry.
+	entries := h.entries
+	for i := range entries {
+		if slot := &entries[i]; slot.Valid() {
+			visit(slot)
+		}
+	}
+}
+
+// Int64MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.
+type Int64MemoTable struct {
+ // tbl holds the non-null values together with their memo indexes.
+ tbl *Int64HashTable
+ // nullIdx is the memo index assigned to null, or KeyNotFound while no null
+ // has been inserted.
+ nullIdx int32
+}
+
+// NewInt64MemoTable creates a memo table backed by
+// NewInt64HashTable(uint64(num)), with no null entry recorded.
+func NewInt64MemoTable(num int64) *Int64MemoTable {
+	table := NewInt64HashTable(uint64(num))
+	return &Int64MemoTable{tbl: table, nullIdx: KeyNotFound}
+}
+
+// TypeTraits returns arrow.Int64Traits.
+func (Int64MemoTable) TypeTraits() TypeTraits {
+	traits := arrow.Int64Traits
+	return traits
+}
+
+// Reset allows this table to be re-used by dumping all the data currently in the table.
+func (s *Int64MemoTable) Reset() {
+ s.tbl.Reset(32)
+ s.nullIdx = KeyNotFound
+}
+
+// Size reports how many entries the table holds: the hash table's size plus
+// one if a null has been inserted.
+func (s *Int64MemoTable) Size() int {
+	count := int(s.tbl.size)
+	if s.nullIdx != KeyNotFound {
+		count++
+	}
+	return count
+}
+
+// GetNull returns the stored null index and true when a null has been
+// inserted, or KeyNotFound and false otherwise.
+func (s *Int64MemoTable) GetNull() (int, bool) {
+	hasNull := s.nullIdx != KeyNotFound
+	return int(s.nullIdx), hasNull
+}
+
+// GetOrInsertNull returns the index of the null entry. If none exists yet, the
+// next index (the current Size) is assigned to null and found is false;
+// otherwise the existing index is returned with found set to true.
+func (s *Int64MemoTable) GetOrInsertNull() (idx int, found bool) {
+	if idx, found = s.GetNull(); found {
+		return idx, true
+	}
+	idx = s.Size()
+	s.nullIdx = int32(idx)
+	return idx, false
+}
+
+// CopyValues will copy the values from the memo table out into the passed in slice
+// which must be of the appropriate type.
+func (s *Int64MemoTable) CopyValues(out interface{}) {
+ s.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset copies the values whose memo index is at least start into
+// out, which must be a []int64; any other dynamic type makes the type
+// assertion panic.
+func (s *Int64MemoTable) CopyValuesSubset(start int, out interface{}) {
+	dst := out.([]int64)
+	s.tbl.CopyValuesSubset(start, dst)
+}
+
+// WriteOut copies every stored value, unconverted, into out viewed as a
+// []int64 through arrow.Int64Traits.CastFromBytes.
+func (s *Int64MemoTable) WriteOut(out []byte) {
+	dst := arrow.Int64Traits.CastFromBytes(out)
+	s.tbl.CopyValues(dst)
+}
+
+// WriteOutSubset copies the values whose memo index is at least start,
+// unconverted, into out viewed as a []int64 through
+// arrow.Int64Traits.CastFromBytes.
+func (s *Int64MemoTable) WriteOutSubset(start int, out []byte) {
+	dst := arrow.Int64Traits.CastFromBytes(out)
+	s.tbl.CopyValuesSubset(start, dst)
+}
+
+// WriteOutLE writes every stored value into out by delegating to the hash
+// table's WriteOut.
+func (s *Int64MemoTable) WriteOutLE(out []byte) {
+	table := s.tbl
+	table.WriteOut(out)
+}
+
+// WriteOutSubsetLE writes the values starting at memo index start into out by
+// delegating to the hash table's WriteOutSubset.
+func (s *Int64MemoTable) WriteOutSubsetLE(start int, out []byte) {
+	table := s.tbl
+	table.WriteOutSubset(start, out)
+}
+
+// Get looks val up in the table. val must hold an int64; any other dynamic
+// type makes the type assertion panic. It returns the value's memo index and
+// true, or KeyNotFound and false if the value is absent.
+func (s *Int64MemoTable) Get(val interface{}) (int, bool) {
+	key := val.(int64)
+	matches := func(stored int64) bool { return stored == key }
+
+	entry, found := s.tbl.Lookup(hashInt(uint64(key), 0), matches)
+	if !found {
+		return KeyNotFound, false
+	}
+	return int(entry.payload.memoIdx), true
+}
+
+// GetOrInsert returns the memo index of val, inserting it with the next index
+// (the current Size) when it is not present yet. val must hold an int64; any
+// other dynamic type makes the type assertion panic. found is true when the
+// value already existed and false when this call inserted it. err carries any
+// error reported by the hash table's Insert instead of dropping it.
+func (s *Int64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+	key := val.(int64)
+	hash := hashInt(uint64(key), 0)
+
+	entry, ok := s.tbl.Lookup(hash, func(stored int64) bool { return stored == key })
+	if ok {
+		return int(entry.payload.memoIdx), true, nil
+	}
+
+	idx = s.Size()
+	err = s.tbl.Insert(entry, hash, key, int32(idx))
+	return idx, false, err
+}
+
+// payloadUint32 is the value stored in a hash table slot together with its memo index.
+type payloadUint32 struct {
+ // val is the stored uint32 value.
+ val uint32
+ // memoIdx is the memo index supplied when the value was inserted.
+ memoIdx int32
+}
+
+// entryUint32 is a single slot of the Uint32HashTable.
+type entryUint32 struct {
+ // h is the stored hash; sentinel marks a slot that holds no value.
+ h uint64
+ // payload carries the value and memo index for this slot.
+ payload payloadUint32
+}
+
+// Valid reports whether the slot holds an inserted value, i.e. its hash is not sentinel.
+func (e entryUint32) Valid() bool {
+	return sentinel != e.h
+}
+
+// Uint32HashTable is a hashtable specifically for uint32 that
+// is utilized with the MemoTable to generalize interactions for easier
+// implementation of dictionaries without losing performance.
+type Uint32HashTable struct {
+ // cap is the number of slots in entries.
+ cap uint64
+ // capMask is cap - 1, used to mask a hash down to a slot index.
+ capMask uint64
+ // size is the number of values inserted so far.
+ size uint64
+
+ // entries is the slot storage; a slot whose hash equals sentinel is empty.
+ entries []entryUint32
+}
+
+// NewUint32HashTable builds an empty hash table for uint32 values. The slot
+// count is the requested capacity raised to at least 32 and then passed
+// through bitutil.NextPowerOf2; capMask is that count minus one.
+func NewUint32HashTable(cap uint64) *Uint32HashTable {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	return &Uint32HashTable{
+		cap:     slots,
+		capMask: slots - 1,
+		size:    0,
+		entries: make([]entryUint32, slots),
+	}
+}
+
+// Reset discards every stored value and re-sizes the table for cap using the
+// same rule as NewUint32HashTable. The receiver is reused; a fresh entries
+// slice is allocated.
+func (h *Uint32HashTable) Reset(cap uint64) {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	h.cap, h.capMask, h.size = slots, slots-1, 0
+	h.entries = make([]entryUint32, slots)
+}
+
+// CopyValues is used for copying the values out of the hash table into the
+// passed in slice, in the order that they were first inserted
+func (h *Uint32HashTable) CopyValues(out []uint32) {
+ h.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset writes each stored value whose memo index is at least start
+// into out[memoIdx-start]; values with a smaller memo index are skipped.
+func (h *Uint32HashTable) CopyValuesSubset(start int, out []uint32) {
+	offset := int32(start)
+	h.VisitEntries(func(ent *entryUint32) {
+		if pos := ent.payload.memoIdx - offset; pos >= 0 {
+			out[pos] = ent.payload.val
+		}
+	})
+}
+
+func (h *Uint32HashTable) WriteOut(out []byte) {
+ h.WriteOutSubset(0, out)
+}
+
+// WriteOutSubset views out as a []uint32 via arrow.Uint32Traits.CastFromBytes
+// and, for each stored value whose memo index is at least start, stores
+// utils.ToLEUint32(value) at position memoIdx-start.
+func (h *Uint32HashTable) WriteOutSubset(start int, out []byte) {
+	offset := int32(start)
+	dst := arrow.Uint32Traits.CastFromBytes(out)
+	h.VisitEntries(func(ent *entryUint32) {
+		pos := ent.payload.memoIdx - offset
+		if pos < 0 {
+			return
+		}
+		dst[pos] = utils.ToLEUint32(ent.payload.val)
+	})
+}
+
+// needUpsize reports whether size*loadFactor has reached the slot count.
+func (h *Uint32HashTable) needUpsize() bool {
+	return h.cap <= h.size*uint64(loadFactor)
+}
+
+// fixHash remaps a hash equal to sentinel to 42, because sentinel in an
+// entry's h field marks an empty slot; every other hash is returned unchanged.
+func (Uint32HashTable) fixHash(v uint64) uint64 {
+	if v != sentinel {
+		return v
+	}
+	return 42
+}
+
+// Lookup probes for an entry whose stored hash matches v and whose payload
+// value satisfies cmp. It returns a pointer to the slot where probing stopped
+// and whether a match was found. When the bool is false the pointer refers to
+// an empty slot and must not be read as a stored entry.
+func (h *Uint32HashTable) Lookup(v uint64, cmp func(uint32) bool) (*entryUint32, bool) {
+	slot, found := h.lookup(v, h.capMask, cmp)
+	return &h.entries[slot], found
+}
+
+// lookup runs the probe sequence for hash v over h.entries with szMask as the
+// index mask. It returns the index of the first slot that either matches (same
+// hash and cmp is true; second result true) or is empty (second result false).
+func (h *Uint32HashTable) lookup(v uint64, szMask uint64, cmp func(uint32) bool) (uint64, bool) {
+	const perturbShift uint8 = 5
+
+	v = h.fixHash(v)
+	pos := v & szMask
+	step := (v >> uint64(perturbShift)) + 1
+
+	for {
+		slot := &h.entries[pos]
+		switch {
+		case slot.h == v && cmp(slot.payload.val):
+			return pos, true
+		case slot.h == sentinel:
+			return pos, false
+		}
+
+		// Perturbation scheme modelled on CPython's set/dict objects: the
+		// step starts from the upper hash bits so that they also take part
+		// in the probing sequence, to minimize clustering.
+		pos = (pos + step) & szMask
+		step = (step >> uint64(perturbShift)) + 1
+	}
+}
+
+// upsize replaces the entries slice with one of newcap slots and re-inserts
+// every valid entry at the slot found by probing with its stored hash. The
+// mask is derived as newcap - 1. It always returns nil.
+func (h *Uint32HashTable) upsize(newcap uint64) error {
+	mask := newcap - 1
+	neverEqual := func(uint32) bool { return false }
+
+	previous := h.entries
+	// lookup reads h.entries, so the new slice has to be in place first.
+	h.entries = make([]entryUint32, newcap)
+	for i := range previous {
+		if !previous[i].Valid() {
+			continue
+		}
+		// cmp never matches, so the probe stops at the first empty slot.
+		slot, _ := h.lookup(previous[i].h, mask, neverEqual)
+		h.entries[slot] = previous[i]
+	}
+	h.cap, h.capMask = newcap, mask
+	return nil
+}
+
+// Insert fills the slot e with the sentinel-adjusted hash v, the value val and
+// its memo index, then increments size. e must be the slot returned by Lookup
+// for v when no match was found. If needUpsize then reports true, the table
+// grows to cap*loadFactor*2 slots and any error from that step is returned
+// instead of being discarded.
+func (h *Uint32HashTable) Insert(e *entryUint32, v uint64, val uint32, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if !h.needUpsize() {
+		return nil
+	}
+	// upsize replaces h.entries, so e must not be used after this point.
+	return h.upsize(h.cap * uint64(loadFactor) * 2)
+}
+
+// VisitEntries calls visit once for every valid entry, i.e. every slot that
+// has had a value inserted. The pointer handed to visit addresses the slot in
+// the entries slice itself rather than the range loop variable, so it is
+// distinct per entry and does not alias a per-iteration copy.
+func (h *Uint32HashTable) VisitEntries(visit func(*entryUint32)) {
+	// Snapshot the slice header so the walk covers the slots present on entry.
+	entries := h.entries
+	for i := range entries {
+		if slot := &entries[i]; slot.Valid() {
+			visit(slot)
+		}
+	}
+}
+
+// Uint32MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.
+type Uint32MemoTable struct {
+ // tbl holds the non-null values together with their memo indexes.
+ tbl *Uint32HashTable
+ // nullIdx is the memo index assigned to null, or KeyNotFound while no null
+ // has been inserted.
+ nullIdx int32
+}
+
+// NewUint32MemoTable creates a memo table backed by
+// NewUint32HashTable(uint64(num)), with no null entry recorded.
+func NewUint32MemoTable(num int64) *Uint32MemoTable {
+	table := NewUint32HashTable(uint64(num))
+	return &Uint32MemoTable{tbl: table, nullIdx: KeyNotFound}
+}
+
+// TypeTraits returns arrow.Uint32Traits.
+func (Uint32MemoTable) TypeTraits() TypeTraits {
+	traits := arrow.Uint32Traits
+	return traits
+}
+
+// Reset allows this table to be re-used by dumping all the data currently in the table.
+func (s *Uint32MemoTable) Reset() {
+ s.tbl.Reset(32)
+ s.nullIdx = KeyNotFound
+}
+
+// Size reports how many entries the table holds: the hash table's size plus
+// one if a null has been inserted.
+func (s *Uint32MemoTable) Size() int {
+	count := int(s.tbl.size)
+	if s.nullIdx != KeyNotFound {
+		count++
+	}
+	return count
+}
+
+// GetNull returns the stored null index and true when a null has been
+// inserted, or KeyNotFound and false otherwise.
+func (s *Uint32MemoTable) GetNull() (int, bool) {
+	hasNull := s.nullIdx != KeyNotFound
+	return int(s.nullIdx), hasNull
+}
+
+// GetOrInsertNull returns the index of the null entry. If none exists yet, the
+// next index (the current Size) is assigned to null and found is false;
+// otherwise the existing index is returned with found set to true.
+func (s *Uint32MemoTable) GetOrInsertNull() (idx int, found bool) {
+	if idx, found = s.GetNull(); found {
+		return idx, true
+	}
+	idx = s.Size()
+	s.nullIdx = int32(idx)
+	return idx, false
+}
+
+// CopyValues will copy the values from the memo table out into the passed in slice
+// which must be of the appropriate type.
+func (s *Uint32MemoTable) CopyValues(out interface{}) {
+ s.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset copies the values whose memo index is at least start into
+// out, which must be a []uint32; any other dynamic type makes the type
+// assertion panic.
+func (s *Uint32MemoTable) CopyValuesSubset(start int, out interface{}) {
+	dst := out.([]uint32)
+	s.tbl.CopyValuesSubset(start, dst)
+}
+
+// WriteOut copies every stored value, unconverted, into out viewed as a
+// []uint32 through arrow.Uint32Traits.CastFromBytes.
+func (s *Uint32MemoTable) WriteOut(out []byte) {
+	dst := arrow.Uint32Traits.CastFromBytes(out)
+	s.tbl.CopyValues(dst)
+}
+
+// WriteOutSubset copies the values whose memo index is at least start,
+// unconverted, into out viewed as a []uint32 through
+// arrow.Uint32Traits.CastFromBytes.
+func (s *Uint32MemoTable) WriteOutSubset(start int, out []byte) {
+	dst := arrow.Uint32Traits.CastFromBytes(out)
+	s.tbl.CopyValuesSubset(start, dst)
+}
+
+// WriteOutLE writes every stored value into out by delegating to the hash
+// table's WriteOut.
+func (s *Uint32MemoTable) WriteOutLE(out []byte) {
+	table := s.tbl
+	table.WriteOut(out)
+}
+
+// WriteOutSubsetLE writes the values starting at memo index start into out by
+// delegating to the hash table's WriteOutSubset.
+func (s *Uint32MemoTable) WriteOutSubsetLE(start int, out []byte) {
+	table := s.tbl
+	table.WriteOutSubset(start, out)
+}
+
+// Get looks val up in the table. val must hold a uint32; any other dynamic
+// type makes the type assertion panic. It returns the value's memo index and
+// true, or KeyNotFound and false if the value is absent.
+func (s *Uint32MemoTable) Get(val interface{}) (int, bool) {
+	key := val.(uint32)
+	matches := func(stored uint32) bool { return stored == key }
+
+	entry, found := s.tbl.Lookup(hashInt(uint64(key), 0), matches)
+	if !found {
+		return KeyNotFound, false
+	}
+	return int(entry.payload.memoIdx), true
+}
+
+// GetOrInsert returns the memo index of val, inserting it with the next index
+// (the current Size) when it is not present yet. val must hold a uint32; any
+// other dynamic type makes the type assertion panic. found is true when the
+// value already existed and false when this call inserted it. err carries any
+// error reported by the hash table's Insert instead of dropping it.
+func (s *Uint32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+	key := val.(uint32)
+	hash := hashInt(uint64(key), 0)
+
+	entry, ok := s.tbl.Lookup(hash, func(stored uint32) bool { return stored == key })
+	if ok {
+		return int(entry.payload.memoIdx), true, nil
+	}
+
+	idx = s.Size()
+	err = s.tbl.Insert(entry, hash, key, int32(idx))
+	return idx, false, err
+}
+
+// payloadUint64 is the value stored in a hash table slot together with its memo index.
+type payloadUint64 struct {
+ // val is the stored uint64 value.
+ val uint64
+ // memoIdx is the memo index supplied when the value was inserted.
+ memoIdx int32
+}
+
+// entryUint64 is a single slot of the Uint64HashTable.
+type entryUint64 struct {
+ // h is the stored hash; sentinel marks a slot that holds no value.
+ h uint64
+ // payload carries the value and memo index for this slot.
+ payload payloadUint64
+}
+
+// Valid reports whether the slot holds an inserted value, i.e. its hash is not sentinel.
+func (e entryUint64) Valid() bool {
+	return sentinel != e.h
+}
+
+// Uint64HashTable is a hashtable specifically for uint64 that
+// is utilized with the MemoTable to generalize interactions for easier
+// implementation of dictionaries without losing performance.
+type Uint64HashTable struct {
+ // cap is the number of slots in entries.
+ cap uint64
+ // capMask is cap - 1, used to mask a hash down to a slot index.
+ capMask uint64
+ // size is the number of values inserted so far.
+ size uint64
+
+ // entries is the slot storage; a slot whose hash equals sentinel is empty.
+ entries []entryUint64
+}
+
+// NewUint64HashTable builds an empty hash table for uint64 values. The slot
+// count is the requested capacity raised to at least 32 and then passed
+// through bitutil.NextPowerOf2; capMask is that count minus one.
+func NewUint64HashTable(cap uint64) *Uint64HashTable {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	return &Uint64HashTable{
+		cap:     slots,
+		capMask: slots - 1,
+		size:    0,
+		entries: make([]entryUint64, slots),
+	}
+}
+
+// Reset discards every stored value and re-sizes the table for cap using the
+// same rule as NewUint64HashTable. The receiver is reused; a fresh entries
+// slice is allocated.
+func (h *Uint64HashTable) Reset(cap uint64) {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	h.cap, h.capMask, h.size = slots, slots-1, 0
+	h.entries = make([]entryUint64, slots)
+}
+
+// CopyValues is used for copying the values out of the hash table into the
+// passed in slice, in the order that they were first inserted
+func (h *Uint64HashTable) CopyValues(out []uint64) {
+ h.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset writes each stored value whose memo index is at least start
+// into out[memoIdx-start]; values with a smaller memo index are skipped.
+func (h *Uint64HashTable) CopyValuesSubset(start int, out []uint64) {
+	offset := int32(start)
+	h.VisitEntries(func(ent *entryUint64) {
+		if pos := ent.payload.memoIdx - offset; pos >= 0 {
+			out[pos] = ent.payload.val
+		}
+	})
+}
+
+// WriteOut is WriteOutSubset starting from memo index 0.
+func (h *Uint64HashTable) WriteOut(out []byte) {
+	h.WriteOutSubset(0, out)
+}
+
+// WriteOutSubset views out as a []uint64 (arrow.Uint64Traits.CastFromBytes) and
+// stores utils.ToLEUint64 of every value whose memo index is at least start at
+// position (memo index - start).
+func (h *Uint64HashTable) WriteOutSubset(start int, out []byte) {
+	dst := arrow.Uint64Traits.CastFromBytes(out)
+	first := int32(start)
+	h.VisitEntries(func(ent *entryUint64) {
+		if pos := ent.payload.memoIdx - first; pos >= 0 {
+			dst[pos] = utils.ToLEUint64(ent.payload.val)
+		}
+	})
+}
+
+// needUpsize reports whether size*loadFactor has reached the slot count.
+func (h *Uint64HashTable) needUpsize() bool {
+	return h.cap <= h.size*uint64(loadFactor)
+}
+
+// fixHash remaps a hash equal to sentinel to 42, since sentinel is the value
+// that marks an empty slot. Every other hash is returned unchanged.
+func (Uint64HashTable) fixHash(v uint64) uint64 {
+	if v != sentinel {
+		return v
+	}
+	return 42
+}
+
+// Lookup probes for an entry whose hash matches v and whose payload value
+// satisfies cmp. On a hit it returns the entry and true. On a miss it returns
+// false and a pointer to the empty slot the probe stopped at; that slot's
+// payload is not meaningful and the pointer is only useful as the target of Insert.
+func (h *Uint64HashTable) Lookup(v uint64, cmp func(uint64) bool) (*entryUint64, bool) {
+	slot, found := h.lookup(v, h.capMask, cmp)
+	return &h.entries[slot], found
+}
+
+// lookup probes h.entries for hash v using szMask to wrap slot indexes. It
+// returns the slot index and true when a slot holds the same (fixed) hash and
+// cmp accepts its value, or the index of the first empty slot reached and false.
+func (h *Uint64HashTable) lookup(v uint64, szMask uint64, cmp func(uint64) bool) (uint64, bool) {
+ const perturbShift uint8 = 5
+
+ var (
+ idx uint64
+ perturb uint64
+ e *entryUint64
+ )
+
+ v = h.fixHash(v)
+ idx = v & szMask
+ perturb = (v >> uint64(perturbShift)) + 1
+
+ for {
+ e = &h.entries[idx]
+ if e.h == v && cmp(e.payload.val) {
+ return idx, true
+ }
+
+ if e.h == sentinel {
+ return idx, false
+ }
+
+ // perturbation logic inspired by CPython's set/dict object:
+ // the goal is that all 64 bits of the unmasked hash value eventually
+ // participate in the probing sequence, to minimize clustering
+ idx = (idx + perturb) & szMask
+ perturb = (perturb >> uint64(perturbShift)) + 1
+ }
+}
+
+// upsize replaces entries with a slice of newcap slots and re-places every
+// occupied slot by probing with its stored hash. It always returns nil.
+func (h *Uint64HashTable) upsize(newcap uint64) error {
+	mask := newcap - 1
+	// no stored value may match while re-placing, so each probe ends on an empty slot
+	neverEqual := func(uint64) bool { return false }
+
+	previous := h.entries
+	h.entries = make([]entryUint64, newcap)
+	for i := range previous {
+		if !previous[i].Valid() {
+			continue
+		}
+		slot, _ := h.lookup(previous[i].h, mask, neverEqual)
+		h.entries[slot] = previous[i]
+	}
+	h.cap, h.capMask = newcap, mask
+	return nil
+}
+
+// Insert fills the given entry with the (fixed) hash v, the value val and the
+// memo index, then grows the table if the load threshold has been reached.
+// e must be the pointer returned by a Lookup that missed.
+//
+// It returns the error from upsize when growing, otherwise nil.
+func (h *Uint64HashTable) Insert(e *entryUint64, v uint64, val uint64, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if h.needUpsize() {
+		// upsize swaps out h.entries, so e no longer points into the table
+		// afterwards; propagate its error rather than dropping it
+		return h.upsize(h.cap * uint64(loadFactor) * 2)
+	}
+	return nil
+}
+
+// VisitEntries calls visit once for every occupied slot, in slot order. visit
+// receives a pointer to a copy of the entry, so changes made through it do not
+// reach the table.
+func (h *Uint64HashTable) VisitEntries(visit func(*entryUint64)) {
+	for i := range h.entries {
+		if ent := h.entries[i]; ent.Valid() {
+			visit(&ent)
+		}
+	}
+}
+
+// Uint64MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.
+type Uint64MemoTable struct {
+ // tbl holds the non-null values
+ tbl *Uint64HashTable
+ // nullIdx is the memo index given to the null entry, or KeyNotFound if none was inserted
+ nullIdx int32
+}
+
+// NewUint64MemoTable creates an empty memo table with no null entry, whose hash
+// table is pre-sized for num values to limit later growth.
+func NewUint64MemoTable(num int64) *Uint64MemoTable {
+	memo := &Uint64MemoTable{nullIdx: KeyNotFound}
+	memo.tbl = NewUint64HashTable(uint64(num))
+	return memo
+}
+
+// TypeTraits returns arrow.Uint64Traits, the traits of this table's value type.
+func (Uint64MemoTable) TypeTraits() TypeTraits {
+ return arrow.Uint64Traits
+}
+
+// Reset discards all stored values and the null entry so the table can be
+// reused; the hash table is reset to 32 slots.
+func (s *Uint64MemoTable) Reset() {
+	s.nullIdx = KeyNotFound
+	s.tbl.Reset(32)
+}
+
+// Size returns how many entries the table holds: the values in the hash table
+// plus one if a null has been inserted.
+func (s *Uint64MemoTable) Size() int {
+	count := int(s.tbl.size)
+	if s.nullIdx != KeyNotFound {
+		count++
+	}
+	return count
+}
+
+// GetNull returns the memo index of the null entry and true, or KeyNotFound
+// and false when no null has been inserted.
+func (s *Uint64MemoTable) GetNull() (int, bool) {
+	present := s.nullIdx != KeyNotFound
+	return int(s.nullIdx), present
+}
+
+// GetOrInsertNull returns the memo index of the null entry. If there is none
+// yet, the null takes the next memo index (the current Size) and found is
+// false; otherwise found is true.
+func (s *Uint64MemoTable) GetOrInsertNull() (idx int, found bool) {
+	if existing, ok := s.GetNull(); ok {
+		return existing, true
+	}
+	idx = s.Size()
+	s.nullIdx = int32(idx)
+	return idx, false
+}
+
+// CopyValues will copy the values from the memo table out into the passed in slice
+// which must be of the appropriate type. It is CopyValuesSubset with a start of 0.
+func (s *Uint64MemoTable) CopyValues(out interface{}) {
+ s.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset is like CopyValues but only copies a subset of values starting
+// at the provided start index. out must be a []uint64; any other type makes the
+// unchecked type assertion panic. The null entry is not stored in the hash
+// table, so its position in out is left unwritten.
+func (s *Uint64MemoTable) CopyValuesSubset(start int, out interface{}) {
+ s.tbl.CopyValuesSubset(start, out.([]uint64))
+}
+
+// WriteOut views out as a []uint64 via arrow.Uint64Traits.CastFromBytes and
+// copies the stored values into it unchanged, positioned by memo index.
+func (s *Uint64MemoTable) WriteOut(out []byte) {
+ s.tbl.CopyValues(arrow.Uint64Traits.CastFromBytes(out))
+}
+
+// WriteOutSubset is like WriteOut but only copies values whose memo index is
+// at least start, placing each at (memo index - start).
+func (s *Uint64MemoTable) WriteOutSubset(start int, out []byte) {
+ s.tbl.CopyValuesSubset(start, arrow.Uint64Traits.CastFromBytes(out))
+}
+
+// WriteOutLE is like WriteOut, but the hash table passes each value through
+// utils.ToLEUint64 before storing it.
+func (s *Uint64MemoTable) WriteOutLE(out []byte) {
+ s.tbl.WriteOut(out)
+}
+
+// WriteOutSubsetLE is the WriteOutLE counterpart of WriteOutSubset.
+func (s *Uint64MemoTable) WriteOutSubsetLE(start int, out []byte) {
+ s.tbl.WriteOutSubset(start, out)
+}
+
+// Get looks val up without inserting it. It returns the value's memo index and
+// true, or KeyNotFound and false. val must be a uint64; any other type panics
+// on the type assertion.
+func (s *Uint64MemoTable) Get(val interface{}) (int, bool) {
+	want := val.(uint64)
+	matches := func(stored uint64) bool { return stored == want }
+
+	ent, found := s.tbl.Lookup(hashInt(want, 0), matches)
+	if !found {
+		return KeyNotFound, false
+	}
+	return int(ent.payload.memoIdx), true
+}
+
+// GetOrInsert returns the memo index of val, inserting it first if it is not
+// yet in the table. found is true when val was already present and false when
+// this call inserted it. val must be a uint64; any other type panics on the
+// type assertion. err carries the result of the hash table's Insert.
+func (s *Uint64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+	// assert the type once instead of on every probe comparison
+	key := val.(uint64)
+	hash := hashInt(key, 0)
+	slot, exists := s.tbl.Lookup(hash, func(stored uint64) bool { return stored == key })
+	if exists {
+		return int(slot.payload.memoIdx), true, nil
+	}
+
+	// not present: the value takes the next memo index and slot is the empty
+	// slot the probe stopped at. Insert's error is reported instead of dropped.
+	idx = s.Size()
+	err = s.tbl.Insert(slot, hash, key, int32(idx))
+	return idx, false, err
+}
+
+// payloadFloat32 is the data kept in an occupied slot: the float32 value and
+// the memo index it was inserted with.
+type payloadFloat32 struct {
+ val float32
+ memoIdx int32
+}
+
+// entryFloat32 is one slot of the hash table: the stored hash h and its payload.
+type entryFloat32 struct {
+ h uint64
+ payload payloadFloat32
+}
+
+// Valid reports whether the slot is occupied, i.e. its hash differs from sentinel.
+func (e entryFloat32) Valid() bool { return e.h != sentinel }
+
+// Float32HashTable is a hashtable specifically for float32 that
+// is utilized with the MemoTable to generalize interactions for easier
+// implementation of dictionaries without losing performance.
+type Float32HashTable struct {
+ // cap is the number of slots in entries
+ cap uint64
+ // capMask is cap - 1, used to mask a hash into a slot index
+ capMask uint64
+ // size is the number of values inserted so far
+ size uint64
+
+ // entries holds the slots; a slot whose h equals sentinel is empty
+ entries []entryFloat32
+}
+
+// NewFloat32HashTable builds an empty hash table for float32 values. The slot
+// count is bitutil.NextPowerOf2 applied to the larger of cap and 32.
+func NewFloat32HashTable(cap uint64) *Float32HashTable {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	return &Float32HashTable{
+		cap:     slots,
+		capMask: slots - 1,
+		size:    0,
+		entries: make([]entryFloat32, slots),
+	}
+}
+
+// Reset empties the table and sizes it for cap (or 32 if larger), exactly as
+// NewFloat32HashTable would. The table object itself is reused; only a fresh
+// entries slice is allocated.
+func (h *Float32HashTable) Reset(cap uint64) {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	h.cap, h.capMask, h.size = slots, slots-1, 0
+	h.entries = make([]entryFloat32, slots)
+}
+
+// CopyValues writes every stored value into out at the position given by its
+// memo index, i.e. in first-insertion order.
+func (h *Float32HashTable) CopyValues(out []float32) {
+	h.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset writes the values whose memo index is at least start into
+// out, each at position (memo index - start). Earlier values are skipped.
+func (h *Float32HashTable) CopyValuesSubset(start int, out []float32) {
+	first := int32(start)
+	h.VisitEntries(func(ent *entryFloat32) {
+		if pos := ent.payload.memoIdx - first; pos >= 0 {
+			out[pos] = ent.payload.val
+		}
+	})
+}
+
+// WriteOut is WriteOutSubset starting from memo index 0.
+func (h *Float32HashTable) WriteOut(out []byte) {
+	h.WriteOutSubset(0, out)
+}
+
+// WriteOutSubset views out as a []float32 (arrow.Float32Traits.CastFromBytes)
+// and stores utils.ToLEFloat32 of every value whose memo index is at least
+// start at position (memo index - start).
+func (h *Float32HashTable) WriteOutSubset(start int, out []byte) {
+	dst := arrow.Float32Traits.CastFromBytes(out)
+	first := int32(start)
+	h.VisitEntries(func(ent *entryFloat32) {
+		if pos := ent.payload.memoIdx - first; pos >= 0 {
+			dst[pos] = utils.ToLEFloat32(ent.payload.val)
+		}
+	})
+}
+
+// needUpsize reports whether size*loadFactor has reached the slot count.
+func (h *Float32HashTable) needUpsize() bool {
+	return h.cap <= h.size*uint64(loadFactor)
+}
+
+// fixHash remaps a hash equal to sentinel to 42, since sentinel is the value
+// that marks an empty slot. Every other hash is returned unchanged.
+func (Float32HashTable) fixHash(v uint64) uint64 {
+	if v != sentinel {
+		return v
+	}
+	return 42
+}
+
+// Lookup probes for an entry whose hash matches v and whose payload value
+// satisfies cmp. On a hit it returns the entry and true. On a miss it returns
+// false and a pointer to the empty slot the probe stopped at; that slot's
+// payload is not meaningful and the pointer is only useful as the target of Insert.
+func (h *Float32HashTable) Lookup(v uint64, cmp func(float32) bool) (*entryFloat32, bool) {
+	slot, found := h.lookup(v, h.capMask, cmp)
+	return &h.entries[slot], found
+}
+
+// lookup probes h.entries for hash v using szMask to wrap slot indexes. It
+// returns the slot index and true when a slot holds the same (fixed) hash and
+// cmp accepts its value, or the index of the first empty slot reached and false.
+func (h *Float32HashTable) lookup(v uint64, szMask uint64, cmp func(float32) bool) (uint64, bool) {
+ const perturbShift uint8 = 5
+
+ var (
+ idx uint64
+ perturb uint64
+ e *entryFloat32
+ )
+
+ v = h.fixHash(v)
+ idx = v & szMask
+ perturb = (v >> uint64(perturbShift)) + 1
+
+ for {
+ e = &h.entries[idx]
+ if e.h == v && cmp(e.payload.val) {
+ return idx, true
+ }
+
+ if e.h == sentinel {
+ return idx, false
+ }
+
+ // perturbation logic inspired by CPython's set/dict object:
+ // the goal is that all 64 bits of the unmasked hash value eventually
+ // participate in the probing sequence, to minimize clustering
+ idx = (idx + perturb) & szMask
+ perturb = (perturb >> uint64(perturbShift)) + 1
+ }
+}
+
+// upsize replaces entries with a slice of newcap slots and re-places every
+// occupied slot by probing with its stored hash. It always returns nil.
+func (h *Float32HashTable) upsize(newcap uint64) error {
+	mask := newcap - 1
+	// no stored value may match while re-placing, so each probe ends on an empty slot
+	neverEqual := func(float32) bool { return false }
+
+	previous := h.entries
+	h.entries = make([]entryFloat32, newcap)
+	for i := range previous {
+		if !previous[i].Valid() {
+			continue
+		}
+		slot, _ := h.lookup(previous[i].h, mask, neverEqual)
+		h.entries[slot] = previous[i]
+	}
+	h.cap, h.capMask = newcap, mask
+	return nil
+}
+
+// Insert fills the given entry with the (fixed) hash v, the value val and the
+// memo index, then grows the table if the load threshold has been reached.
+// e must be the pointer returned by a Lookup that missed.
+//
+// It returns the error from upsize when growing, otherwise nil.
+func (h *Float32HashTable) Insert(e *entryFloat32, v uint64, val float32, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if h.needUpsize() {
+		// upsize swaps out h.entries, so e no longer points into the table
+		// afterwards; propagate its error rather than dropping it
+		return h.upsize(h.cap * uint64(loadFactor) * 2)
+	}
+	return nil
+}
+
+// VisitEntries calls visit once for every occupied slot, in slot order. visit
+// receives a pointer to a copy of the entry, so changes made through it do not
+// reach the table.
+func (h *Float32HashTable) VisitEntries(visit func(*entryFloat32)) {
+	for i := range h.entries {
+		if ent := h.entries[i]; ent.Valid() {
+			visit(&ent)
+		}
+	}
+}
+
+// Float32MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.
+type Float32MemoTable struct {
+ // tbl holds the non-null values
+ tbl *Float32HashTable
+ // nullIdx is the memo index given to the null entry, or KeyNotFound if none was inserted
+ nullIdx int32
+}
+
+// NewFloat32MemoTable creates an empty memo table with no null entry, whose
+// hash table is pre-sized for num values to limit later growth.
+func NewFloat32MemoTable(num int64) *Float32MemoTable {
+	memo := &Float32MemoTable{nullIdx: KeyNotFound}
+	memo.tbl = NewFloat32HashTable(uint64(num))
+	return memo
+}
+
+// TypeTraits returns arrow.Float32Traits, the traits of this table's value type.
+func (Float32MemoTable) TypeTraits() TypeTraits {
+ return arrow.Float32Traits
+}
+
+// Reset discards all stored values and the null entry so the table can be
+// reused; the hash table is reset to 32 slots.
+func (s *Float32MemoTable) Reset() {
+	s.nullIdx = KeyNotFound
+	s.tbl.Reset(32)
+}
+
+// Size returns how many entries the table holds: the values in the hash table
+// plus one if a null has been inserted.
+func (s *Float32MemoTable) Size() int {
+	count := int(s.tbl.size)
+	if s.nullIdx != KeyNotFound {
+		count++
+	}
+	return count
+}
+
+// GetNull returns the memo index of the null entry and true, or KeyNotFound
+// and false when no null has been inserted.
+func (s *Float32MemoTable) GetNull() (int, bool) {
+	present := s.nullIdx != KeyNotFound
+	return int(s.nullIdx), present
+}
+
+// GetOrInsertNull returns the memo index of the null entry. If there is none
+// yet, the null takes the next memo index (the current Size) and found is
+// false; otherwise found is true.
+func (s *Float32MemoTable) GetOrInsertNull() (idx int, found bool) {
+	if existing, ok := s.GetNull(); ok {
+		return existing, true
+	}
+	idx = s.Size()
+	s.nullIdx = int32(idx)
+	return idx, false
+}
+
+// CopyValues will copy the values from the memo table out into the passed in slice
+// which must be of the appropriate type. It is CopyValuesSubset with a start of 0.
+func (s *Float32MemoTable) CopyValues(out interface{}) {
+ s.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset is like CopyValues but only copies a subset of values starting
+// at the provided start index. out must be a []float32; any other type makes the
+// unchecked type assertion panic. The null entry is not stored in the hash
+// table, so its position in out is left unwritten.
+func (s *Float32MemoTable) CopyValuesSubset(start int, out interface{}) {
+ s.tbl.CopyValuesSubset(start, out.([]float32))
+}
+
+// WriteOut views out as a []float32 via arrow.Float32Traits.CastFromBytes and
+// copies the stored values into it unchanged, positioned by memo index.
+func (s *Float32MemoTable) WriteOut(out []byte) {
+ s.tbl.CopyValues(arrow.Float32Traits.CastFromBytes(out))
+}
+
+// WriteOutSubset is like WriteOut but only copies values whose memo index is
+// at least start, placing each at (memo index - start).
+func (s *Float32MemoTable) WriteOutSubset(start int, out []byte) {
+ s.tbl.CopyValuesSubset(start, arrow.Float32Traits.CastFromBytes(out))
+}
+
+// WriteOutLE is like WriteOut, but the hash table passes each value through
+// utils.ToLEFloat32 before storing it.
+func (s *Float32MemoTable) WriteOutLE(out []byte) {
+ s.tbl.WriteOut(out)
+}
+
+// WriteOutSubsetLE is the WriteOutLE counterpart of WriteOutSubset.
+func (s *Float32MemoTable) WriteOutSubsetLE(start int, out []byte) {
+ s.tbl.WriteOutSubset(start, out)
+}
+
+// Get returns the index of the requested value in the hash table or KeyNotFound
+// along with a boolean indicating if it was found or not.
+func (s *Float32MemoTable) Get(val interface{}) (int, bool) {
+ var cmp func(float32) bool
+
+ if math.IsNaN(float64(val.(float32))) {
+ cmp = isNan32Cmp
+ // use consistent internal bit pattern for NaN regardless of the pattern
+ // that is passed to us. NaN is NaN is NaN
+ val = float32(math.NaN())
+ } else {
+ cmp = func(v float32) bool { return val.(float32) == v }
+ }
+
+ h := hashFloat32(val.(float32), 0)
+ if e, ok := s.tbl.Lookup(h, cmp); ok {
+ return int(e.payload.memoIdx), ok
+ }
+ return KeyNotFound, false
+}
+
+// GetOrInsert will return the index of the specified value in the table, or insert the
+// value into the table and return the new index. found indicates whether or not it already
+// existed in the table (true) or was inserted by this call (false).
+func (s *Float32MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+
+ var cmp func(float32) bool
+
+ if math.IsNaN(float64(val.(float32))) {
+ cmp = isNan32Cmp
+ // use consistent internal bit pattern for NaN regardless of the pattern
+ // that is passed to us. NaN is NaN is NaN
+ val = float32(math.NaN())
+ } else {
+ cmp = func(v float32) bool { return val.(float32) == v }
+ }
+
+ h := hashFloat32(val.(float32), 0)
+ e, ok := s.tbl.Lookup(h, cmp)
+
+ if ok {
+ idx = int(e.payload.memoIdx)
+ found = true
+ } else {
+ idx = s.Size()
+ s.tbl.Insert(e, h, val.(float32), int32(idx))
+ }
+ return
+}
+
+// payloadFloat64 is the data kept in an occupied slot: the float64 value and
+// the memo index it was inserted with.
+type payloadFloat64 struct {
+ val float64
+ memoIdx int32
+}
+
+// entryFloat64 is one slot of the hash table: the stored hash h and its payload.
+type entryFloat64 struct {
+ h uint64
+ payload payloadFloat64
+}
+
+// Valid reports whether the slot is occupied, i.e. its hash differs from sentinel.
+func (e entryFloat64) Valid() bool { return e.h != sentinel }
+
+// Float64HashTable is a hashtable specifically for float64 that
+// is utilized with the MemoTable to generalize interactions for easier
+// implementation of dictionaries without losing performance.
+type Float64HashTable struct {
+ // cap is the number of slots in entries
+ cap uint64
+ // capMask is cap - 1, used to mask a hash into a slot index
+ capMask uint64
+ // size is the number of values inserted so far
+ size uint64
+
+ // entries holds the slots; a slot whose h equals sentinel is empty
+ entries []entryFloat64
+}
+
+// NewFloat64HashTable builds an empty hash table for float64 values. The slot
+// count is bitutil.NextPowerOf2 applied to the larger of cap and 32.
+func NewFloat64HashTable(cap uint64) *Float64HashTable {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	return &Float64HashTable{
+		cap:     slots,
+		capMask: slots - 1,
+		size:    0,
+		entries: make([]entryFloat64, slots),
+	}
+}
+
+// Reset empties the table and sizes it for cap (or 32 if larger), exactly as
+// NewFloat64HashTable would. The table object itself is reused; only a fresh
+// entries slice is allocated.
+func (h *Float64HashTable) Reset(cap uint64) {
+	slots := uint64(bitutil.NextPowerOf2(int(max(cap, 32))))
+	h.cap, h.capMask, h.size = slots, slots-1, 0
+	h.entries = make([]entryFloat64, slots)
+}
+
+// CopyValues writes every stored value into out at the position given by its
+// memo index, i.e. in first-insertion order.
+func (h *Float64HashTable) CopyValues(out []float64) {
+	h.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset writes the values whose memo index is at least start into
+// out, each at position (memo index - start). Earlier values are skipped.
+func (h *Float64HashTable) CopyValuesSubset(start int, out []float64) {
+	first := int32(start)
+	h.VisitEntries(func(ent *entryFloat64) {
+		if pos := ent.payload.memoIdx - first; pos >= 0 {
+			out[pos] = ent.payload.val
+		}
+	})
+}
+
+// WriteOut is WriteOutSubset starting from memo index 0.
+func (h *Float64HashTable) WriteOut(out []byte) {
+	h.WriteOutSubset(0, out)
+}
+
+// WriteOutSubset views out as a []float64 (arrow.Float64Traits.CastFromBytes)
+// and stores utils.ToLEFloat64 of every value whose memo index is at least
+// start at position (memo index - start).
+func (h *Float64HashTable) WriteOutSubset(start int, out []byte) {
+	dst := arrow.Float64Traits.CastFromBytes(out)
+	first := int32(start)
+	h.VisitEntries(func(ent *entryFloat64) {
+		if pos := ent.payload.memoIdx - first; pos >= 0 {
+			dst[pos] = utils.ToLEFloat64(ent.payload.val)
+		}
+	})
+}
+
+// needUpsize reports whether size*loadFactor has reached the slot count.
+func (h *Float64HashTable) needUpsize() bool {
+	return h.cap <= h.size*uint64(loadFactor)
+}
+
+// fixHash remaps a hash equal to sentinel to 42, since sentinel is the value
+// that marks an empty slot. Every other hash is returned unchanged.
+func (Float64HashTable) fixHash(v uint64) uint64 {
+	if v != sentinel {
+		return v
+	}
+	return 42
+}
+
+// Lookup probes for an entry whose hash matches v and whose payload value
+// satisfies cmp. On a hit it returns the entry and true. On a miss it returns
+// false and a pointer to the empty slot the probe stopped at; that slot's
+// payload is not meaningful and the pointer is only useful as the target of Insert.
+func (h *Float64HashTable) Lookup(v uint64, cmp func(float64) bool) (*entryFloat64, bool) {
+	slot, found := h.lookup(v, h.capMask, cmp)
+	return &h.entries[slot], found
+}
+
+// lookup probes h.entries for hash v using szMask to wrap slot indexes. It
+// returns the slot index and true when a slot holds the same (fixed) hash and
+// cmp accepts its value, or the index of the first empty slot reached and false.
+func (h *Float64HashTable) lookup(v uint64, szMask uint64, cmp func(float64) bool) (uint64, bool) {
+ const perturbShift uint8 = 5
+
+ var (
+ idx uint64
+ perturb uint64
+ e *entryFloat64
+ )
+
+ v = h.fixHash(v)
+ idx = v & szMask
+ perturb = (v >> uint64(perturbShift)) + 1
+
+ for {
+ e = &h.entries[idx]
+ if e.h == v && cmp(e.payload.val) {
+ return idx, true
+ }
+
+ if e.h == sentinel {
+ return idx, false
+ }
+
+ // perturbation logic inspired by CPython's set/dict object:
+ // the goal is that all 64 bits of the unmasked hash value eventually
+ // participate in the probing sequence, to minimize clustering
+ idx = (idx + perturb) & szMask
+ perturb = (perturb >> uint64(perturbShift)) + 1
+ }
+}
+
+// upsize replaces entries with a slice of newcap slots and re-places every
+// occupied slot by probing with its stored hash. It always returns nil.
+func (h *Float64HashTable) upsize(newcap uint64) error {
+	mask := newcap - 1
+	// no stored value may match while re-placing, so each probe ends on an empty slot
+	neverEqual := func(float64) bool { return false }
+
+	previous := h.entries
+	h.entries = make([]entryFloat64, newcap)
+	for i := range previous {
+		if !previous[i].Valid() {
+			continue
+		}
+		slot, _ := h.lookup(previous[i].h, mask, neverEqual)
+		h.entries[slot] = previous[i]
+	}
+	h.cap, h.capMask = newcap, mask
+	return nil
+}
+
+// Insert fills the given entry with the (fixed) hash v, the value val and the
+// memo index, then grows the table if the load threshold has been reached.
+// e must be the pointer returned by a Lookup that missed.
+//
+// It returns the error from upsize when growing, otherwise nil.
+func (h *Float64HashTable) Insert(e *entryFloat64, v uint64, val float64, memoIdx int32) error {
+	e.h = h.fixHash(v)
+	e.payload.val = val
+	e.payload.memoIdx = memoIdx
+	h.size++
+
+	if h.needUpsize() {
+		// upsize swaps out h.entries, so e no longer points into the table
+		// afterwards; propagate its error rather than dropping it
+		return h.upsize(h.cap * uint64(loadFactor) * 2)
+	}
+	return nil
+}
+
+// VisitEntries calls visit once for every occupied slot, in slot order. visit
+// receives a pointer to a copy of the entry, so changes made through it do not
+// reach the table.
+func (h *Float64HashTable) VisitEntries(visit func(*entryFloat64)) {
+	for i := range h.entries {
+		if ent := h.entries[i]; ent.Valid() {
+			visit(&ent)
+		}
+	}
+}
+
+// Float64MemoTable is a wrapper over the appropriate hashtable to provide an interface
+// conforming to the MemoTable interface defined in the encoding package for general interactions
+// regarding dictionaries.
+type Float64MemoTable struct {
+ // tbl holds the non-null values
+ tbl *Float64HashTable
+ // nullIdx is the memo index given to the null entry, or KeyNotFound if none was inserted
+ nullIdx int32
+}
+
+// NewFloat64MemoTable creates an empty memo table with no null entry, whose
+// hash table is pre-sized for num values to limit later growth.
+func NewFloat64MemoTable(num int64) *Float64MemoTable {
+	memo := &Float64MemoTable{nullIdx: KeyNotFound}
+	memo.tbl = NewFloat64HashTable(uint64(num))
+	return memo
+}
+
+// TypeTraits returns arrow.Float64Traits, the traits of this table's value type.
+func (Float64MemoTable) TypeTraits() TypeTraits {
+ return arrow.Float64Traits
+}
+
+// Reset discards all stored values and the null entry so the table can be
+// reused; the hash table is reset to 32 slots.
+func (s *Float64MemoTable) Reset() {
+	s.nullIdx = KeyNotFound
+	s.tbl.Reset(32)
+}
+
+// Size returns how many entries the table holds: the values in the hash table
+// plus one if a null has been inserted.
+func (s *Float64MemoTable) Size() int {
+	count := int(s.tbl.size)
+	if s.nullIdx != KeyNotFound {
+		count++
+	}
+	return count
+}
+
+// GetNull returns the memo index of the null entry and true, or KeyNotFound
+// and false when no null has been inserted.
+func (s *Float64MemoTable) GetNull() (int, bool) {
+	present := s.nullIdx != KeyNotFound
+	return int(s.nullIdx), present
+}
+
+// GetOrInsertNull returns the memo index of the null entry. If there is none
+// yet, the null takes the next memo index (the current Size) and found is
+// false; otherwise found is true.
+func (s *Float64MemoTable) GetOrInsertNull() (idx int, found bool) {
+	if existing, ok := s.GetNull(); ok {
+		return existing, true
+	}
+	idx = s.Size()
+	s.nullIdx = int32(idx)
+	return idx, false
+}
+
+// CopyValues will copy the values from the memo table out into the passed in slice
+// which must be of the appropriate type. It is CopyValuesSubset with a start of 0.
+func (s *Float64MemoTable) CopyValues(out interface{}) {
+ s.CopyValuesSubset(0, out)
+}
+
+// CopyValuesSubset is like CopyValues but only copies a subset of values starting
+// at the provided start index. out must be a []float64; any other type makes the
+// unchecked type assertion panic. The null entry is not stored in the hash
+// table, so its position in out is left unwritten.
+func (s *Float64MemoTable) CopyValuesSubset(start int, out interface{}) {
+ s.tbl.CopyValuesSubset(start, out.([]float64))
+}
+
+// WriteOut views out as a []float64 via arrow.Float64Traits.CastFromBytes and
+// copies the stored values into it unchanged, positioned by memo index.
+func (s *Float64MemoTable) WriteOut(out []byte) {
+ s.tbl.CopyValues(arrow.Float64Traits.CastFromBytes(out))
+}
+
+// WriteOutSubset is like WriteOut but only copies values whose memo index is
+// at least start, placing each at (memo index - start).
+func (s *Float64MemoTable) WriteOutSubset(start int, out []byte) {
+ s.tbl.CopyValuesSubset(start, arrow.Float64Traits.CastFromBytes(out))
+}
+
+// WriteOutLE is like WriteOut, but the hash table passes each value through
+// utils.ToLEFloat64 before storing it.
+func (s *Float64MemoTable) WriteOutLE(out []byte) {
+ s.tbl.WriteOut(out)
+}
+
+// WriteOutSubsetLE is the WriteOutLE counterpart of WriteOutSubset.
+func (s *Float64MemoTable) WriteOutSubsetLE(start int, out []byte) {
+ s.tbl.WriteOutSubset(start, out)
+}
+
+// Get returns the index of the requested value in the hash table or KeyNotFound
+// along with a boolean indicating if it was found or not.
+func (s *Float64MemoTable) Get(val interface{}) (int, bool) {
+ var cmp func(float64) bool
+ if math.IsNaN(val.(float64)) {
+ cmp = math.IsNaN
+ // use consistent internal bit pattern for NaN regardless of the pattern
+ // that is passed to us. NaN is NaN is NaN
+ val = math.NaN()
+ } else {
+ cmp = func(v float64) bool { return val.(float64) == v }
+ }
+
+ h := hashFloat64(val.(float64), 0)
+ if e, ok := s.tbl.Lookup(h, cmp); ok {
+ return int(e.payload.memoIdx), ok
+ }
+ return KeyNotFound, false
+}
+
+// GetOrInsert will return the index of the specified value in the table, or insert the
+// value into the table and return the new index. found indicates whether or not it already
+// existed in the table (true) or was inserted by this call (false).
+func (s *Float64MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) {
+
+ var cmp func(float64) bool
+ if math.IsNaN(val.(float64)) {
+ cmp = math.IsNaN
+ // use consistent internal bit pattern for NaN regardless of the pattern
+ // that is passed to us. NaN is NaN is NaN
+ val = math.NaN()
+ } else {
+ cmp = func(v float64) bool { return val.(float64) == v }
+ }
+
+ h := hashFloat64(val.(float64), 0)
+ e, ok := s.tbl.Lookup(h, cmp)
+
+ if ok {
+ idx = int(e.payload.memoIdx)
+ found = true
+ } else {
+ idx = s.Size()
+ s.tbl.Insert(e, h, val.(float64), int32(idx))
+ }
+ return
+}
diff --git a/go/parquet/internal/hashing/xxh3_memo_table.gen.go.tmpl b/go/internal/hashing/xxh3_memo_table.gen.go.tmpl
similarity index 94%
rename from go/parquet/internal/hashing/xxh3_memo_table.gen.go.tmpl
rename to go/internal/hashing/xxh3_memo_table.gen.go.tmpl
index 3920732028..1355469d2c 100644
--- a/go/parquet/internal/hashing/xxh3_memo_table.gen.go.tmpl
+++ b/go/internal/hashing/xxh3_memo_table.gen.go.tmpl
@@ -18,7 +18,7 @@ package hashing
import (
"github.com/apache/arrow/go/v8/arrow/bitutil"
- "github.com/apache/arrow/go/v8/parquet/internal/utils"
+ "github.com/apache/arrow/go/v8/internal/utils"
)
{{range .In}}
@@ -90,7 +90,11 @@ func (h *{{.Name}}HashTable) WriteOutSubset(start int, out []byte) {
h.VisitEntries(func(e *entry{{.Name}}) {
idx := e.payload.memoIdx - int32(start)
if idx >= 0 {
+{{if and (ne .Name "Int8") (ne .Name "Uint8") -}}
data[idx] = utils.ToLE{{.Name}}(e.payload.val)
+{{else -}}
+ data[idx] = e.payload.val
+{{end -}}
}
})
}
@@ -197,6 +201,10 @@ func New{{.Name}}MemoTable(num int64) *{{.Name}}MemoTable {
return &{{.Name}}MemoTable{tbl: New{{.Name}}HashTable(uint64(num)), nullIdx: KeyNotFound}
}
+func ({{.Name}}MemoTable) TypeTraits() TypeTraits {
+ return arrow.{{.Name}}Traits
+}
+
// Reset allows this table to be re-used by dumping all the data currently in the table.
func (s *{{.Name}}MemoTable) Reset() {
s.tbl.Reset(32)
@@ -244,17 +252,25 @@ func (s *{{.Name}}MemoTable) CopyValuesSubset(start int, out interface{}) {
}
func (s *{{.Name}}MemoTable) WriteOut(out []byte) {
- s.tbl.WriteOut(out)
+ s.tbl.CopyValues(arrow.{{.Name}}Traits.CastFromBytes(out))
}
func (s *{{.Name}}MemoTable) WriteOutSubset(start int, out []byte) {
+ s.tbl.CopyValuesSubset(start, arrow.{{.Name}}Traits.CastFromBytes(out))
+}
+
+func (s *{{.Name}}MemoTable) WriteOutLE(out []byte) {
+ s.tbl.WriteOut(out)
+}
+
+func (s *{{.Name}}MemoTable) WriteOutSubsetLE(start int, out []byte) {
s.tbl.WriteOutSubset(start, out)
}
// Get returns the index of the requested value in the hash table or KeyNotFound
// along with a boolean indicating if it was found or not.
func (s *{{.Name}}MemoTable) Get(val interface{}) (int, bool) {
-{{if or (eq .Name "Int32") (eq .Name "Int64") }}
+{{if and (ne .Name "Float32") (ne .Name "Float64") }}
h := hashInt(uint64(val.({{.name}})), 0)
if e, ok := s.tbl.Lookup(h, func(v {{.name}}) bool { return val.({{.name}}) == v }); ok {
{{ else -}}
@@ -288,7 +304,7 @@ func (s *{{.Name}}MemoTable) Get(val interface{}) (int, bool) {
// value into the table and return the new index. found indicates whether or not it already
// existed in the table (true) or was inserted by this call (false).
func (s *{{.Name}}MemoTable) GetOrInsert(val interface{}) (idx int, found bool, err error) {
- {{if or (eq .Name "Int32") (eq .Name "Int64") }}
+ {{if and (ne .Name "Float32") (ne .Name "Float64") }}
h := hashInt(uint64(val.({{.name}})), 0)
e, ok := s.tbl.Lookup(h, func(v {{.name}}) bool {
return val.({{.name}}) == v
diff --git a/go/parquet/internal/hashing/xxh3_memo_table.go b/go/internal/hashing/xxh3_memo_table.go
similarity index 82%
rename from go/parquet/internal/hashing/xxh3_memo_table.go
rename to go/internal/hashing/xxh3_memo_table.go
index fbda0757bc..3fb52170c1 100644
--- a/go/parquet/internal/hashing/xxh3_memo_table.go
+++ b/go/internal/hashing/xxh3_memo_table.go
@@ -26,15 +26,55 @@ import (
"reflect"
"unsafe"
- "github.com/apache/arrow/go/v8/arrow"
- "github.com/apache/arrow/go/v8/arrow/array"
- "github.com/apache/arrow/go/v8/arrow/memory"
"github.com/apache/arrow/go/v8/parquet"
"github.com/zeebo/xxh3"
)
-//go:generate go run ../../../arrow/_tools/tmpl/main.go -i -data=types.tmpldata xxh3_memo_table.gen.go.tmpl
+//go:generate go run ../../arrow/_tools/tmpl/main.go -i -data=types.tmpldata xxh3_memo_table.gen.go.tmpl
+
+// TypeTraits is the part of a value type's traits that a MemoTable exposes
+// through its TypeTraits method.
+// NOTE(review): BytesRequired presumably returns the number of bytes needed to
+// hold n values of the type — confirm against the arrow traits it mirrors.
+type TypeTraits interface {
+ BytesRequired(n int) int
+}
+
+// MemoTable interface for hash tables and dictionary encoding.
+//
+// Values will remember the order they are inserted to generate a valid
+// dictionary.
+type MemoTable interface {
+ TypeTraits() TypeTraits
+ // Reset drops everything in the table allowing it to be reused
+ Reset()
+ // Size returns the current number of unique values stored in
+ // the table, including whether or not a null value has been
+ // inserted via GetOrInsertNull.
+ Size() int
+ // GetOrInsert returns the index of the table the specified value is,
+ // and a boolean indicating whether or not the value was found in
+ // the table (if false, the value was inserted). An error is returned
+ // if val is not the appropriate type for the table.
+ GetOrInsert(val interface{}) (idx int, existed bool, err error)
+ // GetOrInsertNull returns the index of the null value in the table,
+ // inserting one if it hasn't already been inserted. It returns a boolean
+ // indicating if the null value already existed or not in the table.
+ GetOrInsertNull() (idx int, existed bool)
+ // GetNull returns the index of the null value in the table, but does not
+ // insert one if it doesn't already exist. Will return -1 if it doesn't exist
+ // indicated by a false value for the boolean.
+ GetNull() (idx int, exists bool)
+ // WriteOut copies the unique values of the memotable out to the byte slice
+ // provided. Must have allocated enough bytes for all the values.
+ WriteOut(out []byte)
+ // WriteOutSubset is like WriteOut, but only writes a subset of values
+ // starting with the index offset.
+ WriteOutSubset(offset int, out []byte)
+}
+
+// NumericMemoTable is a MemoTable that additionally provides WriteOutLE and
+// WriteOutSubsetLE, the counterparts of WriteOut and WriteOutSubset.
+type NumericMemoTable interface {
+ MemoTable
+ WriteOutLE(out []byte)
+ WriteOutSubsetLE(offset int, out []byte)
+}
func hashInt(val uint64, alg uint64) uint64 {
// Two of xxhash's prime multipliers (which are chosen for their
@@ -125,13 +165,27 @@ var isNan32Cmp = func(v float32) bool { return math.IsNaN(float64(v)) }
// KeyNotFound is the constant returned by memo table functions when a key isn't found in the table
const KeyNotFound = -1
+// BinaryBuilderIFace lists the builder operations required of the builder
+// held by BinaryMemoTable.
+type BinaryBuilderIFace interface {
+ Reserve(int)
+ ReserveData(int)
+ Retain()
+ Resize(int)
+ Release()
+ DataLen() int
+ Value(int) []byte
+ Len() int
+ AppendNull()
+ AppendString(string)
+ Append([]byte)
+}
+
// BinaryMemoTable is our hashtable for binary data using the BinaryBuilder
// to construct the actual data in an easy to pass around way with minimal copies
// while using a hash table to keep track of the indexes into the dictionary that
// is created as we go.
type BinaryMemoTable struct {
tbl *Int32HashTable
- builder *array.BinaryBuilder
+ builder BinaryBuilderIFace
nullIdx int
}
@@ -140,11 +194,7 @@ type BinaryMemoTable struct {
// initial and valuesize can be used to pre-allocate the table to reduce allocations. With
// initial being the initial number of entries to allocate for and valuesize being the starting
// amount of space allocated for writing the actual binary data.
-func NewBinaryMemoTable(mem memory.Allocator, initial, valuesize int) *BinaryMemoTable {
- if mem == nil {
- mem = memory.DefaultAllocator
- }
- bldr := array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary)
+func NewBinaryMemoTable(initial, valuesize int, bldr BinaryBuilderIFace) *BinaryMemoTable {
bldr.Reserve(int(initial))
datasize := valuesize
if datasize <= 0 {
@@ -154,10 +204,18 @@ func NewBinaryMemoTable(mem memory.Allocator, initial, valuesize int) *BinaryMem
return &BinaryMemoTable{tbl: NewInt32HashTable(uint64(initial)), builder: bldr, nullIdx: KeyNotFound}
}
+// unimplementedtraits is a placeholder TypeTraits whose BytesRequired always panics.
+type unimplementedtraits struct{}
+
+// BytesRequired panics with "unimplemented"; it never returns.
+func (unimplementedtraits) BytesRequired(int) int { panic("unimplemented") }
+
+// TypeTraits returns the unimplementedtraits placeholder, so calling
+// BytesRequired on the result panics.
+func (BinaryMemoTable) TypeTraits() TypeTraits {
+ return unimplementedtraits{}
+}
+
// Reset dumps all of the data in the table allowing it to be reutilized.
func (s *BinaryMemoTable) Reset() {
s.tbl.Reset(32)
- s.builder.NewArray().Release()
+ s.builder.Resize(0)
s.builder.Reserve(int(32))
s.builder.ReserveData(int(32) * 4)
s.nullIdx = KeyNotFound
@@ -299,13 +357,13 @@ func (b *BinaryMemoTable) findOffset(idx int) uintptr {
// CopyOffsets copies the list of offsets into the passed in slice, the offsets
// being the start and end values of the underlying allocated bytes in the builder
// for the individual values of the table. out should be at least sized to Size()+1
-func (b *BinaryMemoTable) CopyOffsets(out []int8) {
+func (b *BinaryMemoTable) CopyOffsets(out []int32) {
b.CopyOffsetsSubset(0, out)
}
// CopyOffsetsSubset is like CopyOffsets but instead of copying all of the offsets,
// it gets a subset of the offsets in the table starting at the index provided by "start".
-func (b *BinaryMemoTable) CopyOffsetsSubset(start int, out []int8) {
+func (b *BinaryMemoTable) CopyOffsetsSubset(start int, out []int32) {
if b.builder.Len() <= start {
return
}
@@ -313,11 +371,11 @@ func (b *BinaryMemoTable) CopyOffsetsSubset(start int, out []int8) {
first := b.findOffset(0)
delta := b.findOffset(start)
for i := start; i < b.Size(); i++ {
- offset := int8(b.findOffset(i) - delta)
+ offset := int32(b.findOffset(i) - delta)
out[i-start] = offset
}
- out[b.Size()-start] = int8(b.builder.DataLen() - int(delta) - int(first))
+ out[b.Size()-start] = int32(b.builder.DataLen() - (int(delta) - int(first)))
}
// CopyValues copies the raw binary data bytes out, out should be a []byte
@@ -329,6 +387,10 @@ func (b *BinaryMemoTable) CopyValues(out interface{}) {
// CopyValuesSubset copies the raw binary data bytes out starting with the value
// at the index start, out should be a []byte with at least ValuesSize bytes allocated
func (b *BinaryMemoTable) CopyValuesSubset(start int, out interface{}) {
+ if b.builder.Len() <= start {
+ return
+ }
+
var (
first = b.findOffset(0)
offset = b.findOffset(int(start))
diff --git a/go/parquet/internal/utils/Makefile b/go/internal/utils/Makefile
similarity index 70%
copy from go/parquet/internal/utils/Makefile
copy to go/internal/utils/Makefile
index 2d6153fe31..f56e4b65c3 100644
--- a/go/parquet/internal/utils/Makefile
+++ b/go/internal/utils/Makefile
@@ -35,8 +35,7 @@ ALL_SOURCES := $(shell find . -path ./_lib -prune -o -name '*.go' -name '*.s' -n
.PHONEY: assembly
INTEL_SOURCES := \
- bit_packing_avx2.s min_max_avx2.s min_max_sse4.s \
- unpack_bool_avx2.s unpack_bool_sse4.s
+ min_max_avx2_amd64.s min_max_sse4_amd64.s
#
# ARROW-15336: DO NOT add the assembly target for Arm64 (ARM_SOURCES) until c2goasm added the Arm64 support.
@@ -46,40 +45,19 @@ INTEL_SOURCES := \
assembly: $(INTEL_SOURCES)
-_lib/bit_packing_avx2.s: _lib/bit_packing_avx2.c
- $(CC) -S $(C_FLAGS) $(ASM_FLAGS_AVX2) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@; perl -i -pe 's/mem(cpy|set)/clib·_mem\1(SB)/' $@
-
-_lib/min_max_avx2.s: _lib/min_max.c
+_lib/min_max_avx2_amd64.s: _lib/min_max.c
$(CC) -S $(C_FLAGS) $(ASM_FLAGS_AVX2) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@
-_lib/min_max_sse4.s: _lib/min_max.c
+_lib/min_max_sse4_amd64.s: _lib/min_max.c
$(CC) -S $(C_FLAGS) $(ASM_FLAGS_SSE4) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@
_lib/min_max_neon.s: _lib/min_max.c
$(CC) -S $(C_FLAGS_NEON) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@
-_lib/unpack_bool_avx2.s: _lib/unpack_bool.c
- $(CC) -S $(C_FLAGS) $(ASM_FLAGS_AVX2) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@
-
-_lib/unpack_bool_sse4.s: _lib/unpack_bool.c
- $(CC) -S $(C_FLAGS) $(ASM_FLAGS_SSE4) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@
-
-_lib/unpack_bool_neon.s: _lib/unpack_bool.c
- $(CC) -S $(C_FLAGS_NEON) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@
-
-bit_packing_avx2.s: _lib/bit_packing_avx2.s
- $(C2GOASM) -a -f $^ $@
-
-min_max_avx2.s: _lib/min_max_avx2.s
- $(C2GOASM) -a -f $^ $@
-
-min_max_sse4.s: _lib/min_max_sse4.s
- $(C2GOASM) -a -f $^ $@
-
-unpack_bool_avx2.s: _lib/unpack_bool_avx2.s
+min_max_avx2_amd64.s: _lib/min_max_avx2_amd64.s
$(C2GOASM) -a -f $^ $@
-unpack_bool_sse4.s: _lib/unpack_bool_sse4.s
+min_max_sse4_amd64.s: _lib/min_max_sse4_amd64.s
$(C2GOASM) -a -f $^ $@
clean:
diff --git a/go/parquet/internal/utils/min_max_noasm.go b/go/internal/utils/_lib/arch.h
similarity index 71%
copy from go/parquet/internal/utils/min_max_noasm.go
copy to go/internal/utils/_lib/arch.h
index 1ef1adc6fd..7c75cd2f60 100644
--- a/go/parquet/internal/utils/min_max_noasm.go
+++ b/go/internal/utils/_lib/arch.h
@@ -14,14 +14,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build noasm
+#undef FULL_NAME
-package utils
-
-// if building with the 'noasm' tag, then point to the pure go implementations
-func init() {
- minmaxFuncs.i32 = int32MinMax
- minmaxFuncs.ui32 = uint32MinMax
- minmaxFuncs.i64 = int64MinMax
- minmaxFuncs.ui64 = uint64MinMax
-}
+#if defined(__AVX2__)
+ #define FULL_NAME(x) x##_avx2
+#elif __SSE4_2__ == 1
+ #define FULL_NAME(x) x##_sse4
+#elif __SSE3__ == 1
+ #define FULL_NAME(x) x##_sse3
+#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
+ #define FULL_NAME(x) x##_neon
+#else
+ #define FULL_NAME(x) x##_x86
+#endif
diff --git a/go/parquet/internal/utils/_lib/min_max.c b/go/internal/utils/_lib/min_max.c
similarity index 63%
rename from go/parquet/internal/utils/_lib/min_max.c
rename to go/internal/utils/_lib/min_max.c
index 83c189fc24..d876f31a11 100644
--- a/go/parquet/internal/utils/_lib/min_max.c
+++ b/go/internal/utils/_lib/min_max.c
@@ -20,6 +20,58 @@
#include <math.h>
#include <float.h>
+// Scans values[0..len) once and stores the smallest element in *minout and
+// the largest in *maxout. If len <= 0 the loop body never runs, so *minout is
+// INT8_MAX and *maxout is INT8_MIN on return.
+// FULL_NAME (defined in arch.h) appends an instruction-set suffix such as
+// _avx2 or _sse4 to the symbol name.
+void FULL_NAME(int8_max_min)(int8_t values[], int len, int8_t* minout, int8_t* maxout) {
+ // Seed each accumulator with the opposite extreme so any element replaces it.
+ int8_t max = INT8_MIN;
+ int8_t min = INT8_MAX;
+
+ for (int i = 0; i < len; ++i) {
+ min = min < values[i] ? min : values[i];
+ max = max > values[i] ? max : values[i];
+ }
+
+ *maxout = max;
+ *minout = min;
+}
+
+// Scans values[0..len) once and stores the smallest element in *minout and
+// the largest in *maxout. If len <= 0 the loop body never runs, so *minout is
+// UINT8_MAX and *maxout is 0 on return.
+// FULL_NAME (defined in arch.h) appends an instruction-set suffix such as
+// _avx2 or _sse4 to the symbol name.
+void FULL_NAME(uint8_max_min)(uint8_t values[], int len, uint8_t* minout, uint8_t* maxout) {
+ // Seed each accumulator with the opposite extreme so any element replaces it.
+ uint8_t max = 0;
+ uint8_t min = UINT8_MAX;
+
+ for (int i = 0; i < len; ++i) {
+ min = min < values[i] ? min : values[i];
+ max = max > values[i] ? max : values[i];
+ }
+
+ *maxout = max;
+ *minout = min;
+}
+
+// Scans values[0..len) once and stores the smallest element in *minout and
+// the largest in *maxout. If len <= 0 the loop body never runs, so *minout is
+// INT16_MAX and *maxout is INT16_MIN on return.
+// FULL_NAME (defined in arch.h) appends an instruction-set suffix such as
+// _avx2 or _sse4 to the symbol name.
+void FULL_NAME(int16_max_min)(int16_t values[], int len, int16_t* minout, int16_t* maxout) {
+ // Seed each accumulator with the opposite extreme so any element replaces it.
+ int16_t max = INT16_MIN;
+ int16_t min = INT16_MAX;
+
+ for (int i = 0; i < len; ++i) {
+ min = min < values[i] ? min : values[i];
+ max = max > values[i] ? max : values[i];
+ }
+
+ *maxout = max;
+ *minout = min;
+}
+
+// Scans values[0..len) once and stores the smallest element in *minout and
+// the largest in *maxout. If len <= 0 the loop body never runs, so *minout is
+// UINT16_MAX and *maxout is 0 on return.
+// FULL_NAME (defined in arch.h) appends an instruction-set suffix such as
+// _avx2 or _sse4 to the symbol name.
+void FULL_NAME(uint16_max_min)(uint16_t values[], int len, uint16_t* minout, uint16_t* maxout) {
+ // Seed each accumulator with the opposite extreme so any element replaces it.
+ uint16_t max = 0;
+ uint16_t min = UINT16_MAX;
+
+ for (int i = 0; i < len; ++i) {
+ min = min < values[i] ? min : values[i];
+ max = max > values[i] ? max : values[i];
+ }
+
+ *maxout = max;
+ *minout = min;
+}
+
void FULL_NAME(int32_max_min)(int32_t values[], int len, int32_t* minout, int32_t* maxout) {
int32_t max = INT32_MIN;
int32_t min = INT32_MAX;
diff --git a/go/internal/utils/_lib/min_max_avx2_amd64.s b/go/internal/utils/_lib/min_max_avx2_amd64.s
new file mode 100644
index 0000000000..e4e73fd414
--- /dev/null
+++ b/go/internal/utils/_lib/min_max_avx2_amd64.s
@@ -0,0 +1,1009 @@
+ .text
+ .intel_syntax noprefix
+ .file "min_max.c"
+ .section .rodata.cst32,"aM",@progbits,32
+ .p2align 5 # -- Begin function int8_max_min_avx2
+.LCPI0_0:
+ .zero 32,128
+.LCPI0_1:
+ .zero 32,127
+ .section .rodata.cst16,"aM",@progbits,16
+ .p2align 4
+.LCPI0_2:
+ .zero 16,127
+.LCPI0_3:
+ .zero 16,128
+ .text
+# -----------------------------------------------------------------------
+# int8_max_min_avx2 - compiler output for int8_max_min in min_max.c
+#   void int8_max_min_avx2(int8_t values[], int len,
+#                          int8_t* minout, int8_t* maxout)
+# Syntax: GAS, .intel_syntax noprefix
+# In:     rdi = values, esi = len, rdx = minout, rcx = maxout
+#         (System V AMD64 argument order)
+# Out:    *minout = smallest element, *maxout = largest element;
+#         len <= 0 stores 127 / -128 (the initial accumulator values)
+# Scalar loop roles: sil = running min, r8b = running max,
+#         r9 = len, r10 = element index
+# Written in this listing: rax, rsi, r8, r9, r10, ymm0-ymm7, flags
+# -----------------------------------------------------------------------
+ .globl int8_max_min_avx2
+ .p2align 4, 0x90
+ .type int8_max_min_avx2,@function
+int8_max_min_avx2: # @int8_max_min_avx2
+# %bb.0:
+# Frame setup only: no stack slot is read or written in this function.
+ push rbp
+ mov rbp, rsp
+ and rsp, -8
+# len <= 0: nothing to scan, go store the initial values.
+ test esi, esi
+ jle .LBB0_1
+# %bb.2:
+# r9 = len (zero-extended). 64 or more elements take the vector path.
+ mov r9d, esi
+ cmp esi, 63
+ ja .LBB0_4
+# %bb.3:
+# Fewer than 64 elements: seed max/min and run the scalar loop from index 0.
+ mov r8b, -128
+ mov sil, 127
+ xor r10d, r10d
+ jmp .LBB0_11
+.LBB0_1:
+ mov sil, 127
+ mov r8b, -128
+ jmp .LBB0_12
+.LBB0_4:
+# r10 = len rounded down to a multiple of 64 (elements the vector code covers).
+# r8  = r10 / 64 = number of 64-byte blocks. rax == 0 means exactly one block.
+ mov r10d, r9d
+ and r10d, -64
+ lea rax, [r10 - 64]
+ mov r8, rax
+ shr r8, 6
+ add r8, 1
+ test rax, rax
+ je .LBB0_5
+# %bb.6:
+# rsi = -(r8 & ~1): negated count of blocks the 2x-unrolled loop consumes;
+# it is stepped by +2 until it reaches zero.
+# ymm1/ymm3 = max accumulators (every byte 0x80 = -128),
+# ymm0/ymm2 = min accumulators (every byte 127). rax = byte offset into values.
+ mov rsi, r8
+ and rsi, -2
+ neg rsi
+ vmovdqa ymm1, ymmword ptr [rip + .LCPI0_0] # ymm1 = [128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128]
+ vmovdqa ymm0, ymmword ptr [rip + .LCPI0_1] # ymm0 = [127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127]
+ xor eax, eax
+ vmovdqa ymm2, ymm0
+ vmovdqa ymm3, ymm1
+ .p2align 4, 0x90
+# Main loop: four 32-byte loads = 128 elements (two 64-byte blocks) per pass.
+.LBB0_7: # =>This Inner Loop Header: Depth=1
+ vmovdqu ymm4, ymmword ptr [rdi + rax]
+ vmovdqu ymm5, ymmword ptr [rdi + rax + 32]
+ vmovdqu ymm6, ymmword ptr [rdi + rax + 64]
+ vmovdqu ymm7, ymmword ptr [rdi + rax + 96]
+ vpminsb ymm0, ymm0, ymm4
+ vpminsb ymm2, ymm2, ymm5
+ vpmaxsb ymm1, ymm1, ymm4
+ vpmaxsb ymm3, ymm3, ymm5
+ vpminsb ymm0, ymm0, ymm6
+ vpminsb ymm2, ymm2, ymm7
+ vpmaxsb ymm1, ymm1, ymm6
+ vpmaxsb ymm3, ymm3, ymm7
+# rax += 128, written as a subtraction of -128 (-128 fits a sign-extended
+# 8-bit immediate, +128 does not).
+ sub rax, -128
+ add rsi, 2
+ jne .LBB0_7
+# %bb.8:
+# An odd block count leaves one 64-byte block for .LBB0_9.
+ test r8b, 1
+ je .LBB0_10
+.LBB0_9:
+ vmovdqu ymm4, ymmword ptr [rdi + rax]
+ vmovdqu ymm5, ymmword ptr [rdi + rax + 32]
+ vpmaxsb ymm3, ymm3, ymm5
+ vpmaxsb ymm1, ymm1, ymm4
+ vpminsb ymm2, ymm2, ymm5
+ vpminsb ymm0, ymm0, ymm4
+.LBB0_10:
+# Horizontal reduction. vphminposuw only finds an unsigned 16-bit minimum, so:
+#  - max: XOR with 0x7f maps the largest signed byte to the smallest unsigned
+#    byte; vpsrlw/vpminub folds each byte pair into one zero-extended word;
+#    vphminposuw picks the smallest word; XOR with 127 maps it back.
+#  - min: XOR with 0x80 maps signed order onto unsigned order; the same fold
+#    is applied; XOR with -128 maps it back.
+# Results: r8b = max, sil = min over the first r10 elements.
+ vpmaxsb ymm1, ymm1, ymm3
+ vextracti128 xmm3, ymm1, 1
+ vpmaxsb xmm1, xmm1, xmm3
+ vpxor xmm1, xmm1, xmmword ptr [rip + .LCPI0_2]
+ vpminsb ymm0, ymm0, ymm2
+ vpsrlw xmm2, xmm1, 8
+ vpminub xmm1, xmm1, xmm2
+ vphminposuw xmm1, xmm1
+ vmovd r8d, xmm1
+ xor r8b, 127
+ vextracti128 xmm1, ymm0, 1
+ vpminsb xmm0, xmm0, xmm1
+ vpxor xmm0, xmm0, xmmword ptr [rip + .LCPI0_3]
+ vpsrlw xmm1, xmm0, 8
+ vpminub xmm0, xmm0, xmm1
+ vphminposuw xmm0, xmm0
+ vmovd esi, xmm0
+ xor sil, -128
+# No remainder when len is a multiple of 64.
+ cmp r10, r9
+ je .LBB0_12
+ .p2align 4, 0x90
+# Scalar loop: handles elements r10 .. len-1 with signed compares.
+.LBB0_11: # =>This Inner Loop Header: Depth=1
+ movzx eax, byte ptr [rdi + r10]
+ cmp sil, al
+ movzx esi, sil
+ cmovg esi, eax
+ cmp r8b, al
+ movzx r8d, r8b
+ cmovl r8d, eax
+ add r10, 1
+ cmp r9, r10
+ jne .LBB0_11
+.LBB0_12:
+# *maxout = r8b, *minout = sil; vzeroupper is issued before returning.
+ mov byte ptr [rcx], r8b
+ mov byte ptr [rdx], sil
+ mov rsp, rbp
+ pop rbp
+ vzeroupper
+ ret
+.LBB0_5:
+# Exactly one 64-byte block: skip the unrolled loop, seed the accumulators and
+# fall into the single-block code (r8 == 1 here, so the jne is taken).
+ vmovdqa ymm1, ymmword ptr [rip + .LCPI0_0] # ymm1 = [128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128,128]
+ vmovdqa ymm0, ymmword ptr [rip + .LCPI0_1] # ymm0 = [127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127,127]
+ xor eax, eax
+ vmovdqa ymm2, ymm0
+ vmovdqa ymm3, ymm1
+ test r8b, 1
+ jne .LBB0_9
+ jmp .LBB0_10
+.Lfunc_end0:
+ .size int8_max_min_avx2, .Lfunc_end0-int8_max_min_avx2
+ # -- End function
+# -----------------------------------------------------------------------
+# uint8_max_min_avx2 - compiler output for uint8_max_min in min_max.c
+#   void uint8_max_min_avx2(uint8_t values[], int len,
+#                           uint8_t* minout, uint8_t* maxout)
+# Syntax: GAS, .intel_syntax noprefix
+# In:     rdi = values, esi = len, rdx = minout, rcx = maxout
+#         (System V AMD64 argument order)
+# Out:    *minout = smallest element, *maxout = largest element;
+#         len <= 0 stores 0xff / 0 (the initial accumulator values)
+# Scalar loop roles: sil = running min, al = running max,
+#         r9 = len, r10 = element index
+# Written in this listing: rax, rsi, r8, r9, r10, ymm0-ymm7, flags
+# -----------------------------------------------------------------------
+ .globl uint8_max_min_avx2 # -- Begin function uint8_max_min_avx2
+ .p2align 4, 0x90
+ .type uint8_max_min_avx2,@function
+uint8_max_min_avx2: # @uint8_max_min_avx2
+# %bb.0:
+# Frame setup only: no stack slot is read or written in this function.
+ push rbp
+ mov rbp, rsp
+ and rsp, -8
+# len <= 0: nothing to scan, go store the initial values.
+ test esi, esi
+ jle .LBB1_1
+# %bb.2:
+# r9 = len (zero-extended). 64 or more elements take the vector path.
+ mov r9d, esi
+ cmp esi, 63
+ ja .LBB1_4
+# %bb.3:
+# Fewer than 64 elements: min = 0xff, index = 0, max = 0, then scalar loop.
+ mov sil, -1
+ xor r10d, r10d
+ xor eax, eax
+ jmp .LBB1_11
+.LBB1_1:
+ mov sil, -1
+ xor eax, eax
+ jmp .LBB1_12
+.LBB1_4:
+# r10 = len rounded down to a multiple of 64 (elements the vector code covers).
+# r8  = r10 / 64 = number of 64-byte blocks. rax == 0 means exactly one block.
+ mov r10d, r9d
+ and r10d, -64
+ lea rax, [r10 - 64]
+ mov r8, rax
+ shr r8, 6
+ add r8, 1
+ test rax, rax
+ je .LBB1_5
+# %bb.6:
+# rsi = -(r8 & ~1): negated count of blocks the 2x-unrolled loop consumes;
+# it is stepped by +2 until it reaches zero.
+# ymm0/ymm3 = max accumulators (all zero; the VEX-encoded xmm write also
+# clears the upper half of the ymm register),
+# ymm1/ymm2 = min accumulators (all ones). rax = byte offset into values.
+ mov rsi, r8
+ and rsi, -2
+ neg rsi
+ vpxor xmm0, xmm0, xmm0
+ vpcmpeqd ymm1, ymm1, ymm1
+ xor eax, eax
+ vpcmpeqd ymm2, ymm2, ymm2
+ vpxor xmm3, xmm3, xmm3
+ .p2align 4, 0x90
+# Main loop: four 32-byte loads = 128 elements (two 64-byte blocks) per pass.
+.LBB1_7: # =>This Inner Loop Header: Depth=1
+ vmovdqu ymm4, ymmword ptr [rdi + rax]
+ vmovdqu ymm5, ymmword ptr [rdi + rax + 32]
+ vmovdqu ymm6, ymmword ptr [rdi + rax + 64]
+ vmovdqu ymm7, ymmword ptr [rdi + rax + 96]
+ vpminub ymm1, ymm1, ymm4
+ vpminub ymm2, ymm2, ymm5
+ vpmaxub ymm0, ymm0, ymm4
+ vpmaxub ymm3, ymm3, ymm5
+ vpminub ymm1, ymm1, ymm6
+ vpminub ymm2, ymm2, ymm7
+ vpmaxub ymm0, ymm0, ymm6
+ vpmaxub ymm3, ymm3, ymm7
+# rax += 128, written as a subtraction of -128 (-128 fits a sign-extended
+# 8-bit immediate, +128 does not).
+ sub rax, -128
+ add rsi, 2
+ jne .LBB1_7
+# %bb.8:
+# An odd block count leaves one 64-byte block for .LBB1_9.
+ test r8b, 1
+ je .LBB1_10
+.LBB1_9:
+ vmovdqu ymm4, ymmword ptr [rdi + rax]
+ vmovdqu ymm5, ymmword ptr [rdi + rax + 32]
+ vpmaxub ymm3, ymm3, ymm5
+ vpmaxub ymm0, ymm0, ymm4
+ vpminub ymm2, ymm2, ymm5
+ vpminub ymm1, ymm1, ymm4
+.LBB1_10:
+# Horizontal reduction. vphminposuw only finds an unsigned 16-bit minimum, so:
+#  - max: a bitwise NOT (XOR with all ones) turns the largest byte into the
+#    smallest; vpsrlw/vpminub folds each byte pair into one zero-extended
+#    word; vphminposuw picks the smallest word; `not al` maps it back.
+#  - min: the same fold and vphminposuw are applied directly.
+# Results: al = max, sil = min over the first r10 elements.
+ vpminub ymm1, ymm1, ymm2
+ vpmaxub ymm0, ymm0, ymm3
+ vextracti128 xmm2, ymm0, 1
+ vpmaxub xmm0, xmm0, xmm2
+ vpcmpeqd xmm2, xmm2, xmm2
+ vpxor xmm0, xmm0, xmm2
+ vpsrlw xmm2, xmm0, 8
+ vpminub xmm0, xmm0, xmm2
+ vphminposuw xmm0, xmm0
+ vmovd eax, xmm0
+ not al
+ vextracti128 xmm0, ymm1, 1
+ vpminub xmm0, xmm1, xmm0
+ vpsrlw xmm1, xmm0, 8
+ vpminub xmm0, xmm0, xmm1
+ vphminposuw xmm0, xmm0
+ vmovd esi, xmm0
+# No remainder when len is a multiple of 64.
+ cmp r10, r9
+ je .LBB1_12
+ .p2align 4, 0x90
+# Scalar loop: handles elements r10 .. len-1 with unsigned compares.
+.LBB1_11: # =>This Inner Loop Header: Depth=1
+ movzx r8d, byte ptr [rdi + r10]
+ cmp sil, r8b
+ movzx esi, sil
+ cmovae esi, r8d
+ cmp al, r8b
+ movzx eax, al
+ cmovbe eax, r8d
+ add r10, 1
+ cmp r9, r10
+ jne .LBB1_11
+.LBB1_12:
+# *maxout = al, *minout = sil; vzeroupper is issued before returning.
+ mov byte ptr [rcx], al
+ mov byte ptr [rdx], sil
+ mov rsp, rbp
+ pop rbp
+ vzeroupper
+ ret
+.LBB1_5:
+# Exactly one 64-byte block: skip the unrolled loop, seed the accumulators and
+# fall into the single-block code (r8 == 1 here, so the jne is taken).
+ vpxor xmm0, xmm0, xmm0
+ vpcmpeqd ymm1, ymm1, ymm1
+ xor eax, eax
+ vpcmpeqd ymm2, ymm2, ymm2
+ vpxor xmm3, xmm3, xmm3
+ test r8b, 1
+ jne .LBB1_9
+ jmp .LBB1_10
+.Lfunc_end1:
+ .size uint8_max_min_avx2, .Lfunc_end1-uint8_max_min_avx2
+ # -- End function
+ .section .rodata.cst32,"aM",@progbits,32
+ .p2align 5 # -- Begin function int16_max_min_avx2
+.LCPI2_0:
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+.LCPI2_1:
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .section .rodata.cst16,"aM",@progbits,16
+ .p2align 4
+.LCPI2_2:
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+ .short 32767 # 0x7fff
+.LCPI2_3:
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .short 32768 # 0x8000
+ .text
+ .globl int16_max_min_avx2
+ .p2align 4, 0x90
+ .type int16_max_min_avx2,@function
+int16_max_min_avx2: # @int16_max_min_avx2
+# %bb.0:
+ push rbp
+ mov rbp, rsp
+ and rsp, -8
+ test esi, esi
+ jle .LBB2_1
+# %bb.2:
+ mov r9d, esi
+ cmp esi, 31
+ ja .LBB2_4
+# %bb.3:
+ mov r8w, -32768
+ mov si, 32767
+ xor r10d, r10d
+ jmp .LBB2_11
+.LBB2_1:
+ mov si, 32767
+ mov r8w, -32768
+ jmp .LBB2_12
+.LBB2_4:
+ mov r10d, r9d
+ and r10d, -32
+ lea rax, [r10 - 32]
+ mov r8, rax
+ shr r8, 5
+ add r8, 1
+ test rax, rax
+ je .LBB2_5
+# %bb.6:
+ mov rsi, r8
+ and rsi, -2
+ neg rsi
+ vmovdqa ymm1, ymmword ptr [rip + .LCPI2_0] # ymm1 = [32768,32768,32768,32768,32768,32768,32768,32768,32768,32768,32768,32768,32768,32768,32768,32768]
+ vmovdqa ymm0, ymmword ptr [rip + .LCPI2_1] # ymm0 = [32767,32767,32767,32767,32767,32767,32767,32767,32767,32767,32767,32767,32767,32767,32767,32767]
+ xor eax, eax
+ vmovdqa ymm2, ymm0
+ vmovdqa ymm3, ymm1
+ .p2align 4, 0x90
... 7249 lines suppressed ...