You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ze...@apache.org on 2022/11/18 16:05:43 UTC
[arrow] branch master updated: ARROW-18332: [Go] Cast Dictionary types to value type (#14650)
This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new e9222ae00b ARROW-18332: [Go] Cast Dictionary types to value type (#14650)
e9222ae00b is described below
commit e9222ae00b515a0f57d8c7d54bdb19457d68582f
Author: Matt Topol <zo...@gmail.com>
AuthorDate: Fri Nov 18 11:05:32 2022 -0500
ARROW-18332: [Go] Cast Dictionary types to value type (#14650)
Authored-by: Matt Topol <zo...@gmail.com>
Signed-off-by: Matt Topol <zo...@gmail.com>
---
go/arrow/compute/cast.go | 42 ++++++++++++++++++++++++++
go/arrow/compute/cast_test.go | 49 ++++++++++++++++++++++++++++---
go/arrow/compute/datum.go | 3 --
go/arrow/compute/internal/kernels/cast.go | 2 +-
go/arrow/scalar/scalar.go | 18 ++++++++----
5 files changed, 100 insertions(+), 14 deletions(-)
diff --git a/go/arrow/compute/cast.go b/go/arrow/compute/cast.go
index 1c00eedcb6..9530a06761 100644
--- a/go/arrow/compute/cast.go
+++ b/go/arrow/compute/cast.go
@@ -134,6 +134,38 @@ func (cf *castFunction) DispatchExact(vals ...arrow.DataType) (exec.Kernel, erro
return candidates[0], nil
}
+// unpackDictionary is the cast kernel for dictionary-encoded input. It
+// materializes the dictionary values selected by the indices (TakeArray) and,
+// when the dictionary's value type differs from the requested type, casts the
+// materialized array to opts.ToType before handing it to out.
+//
+// It returns an error wrapping arrow.ErrInvalid when the value type is neither
+// equal to nor castable to the target type; errors from TakeArray and
+// CastArray are returned unchanged.
+func unpackDictionary(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
+	var (
+		dictArr  = batch.Values[0].Array.MakeArray().(*array.Dictionary)
+		opts     = ctx.State.(kernels.CastState)
+		dictType = dictArr.DataType().(*arrow.DictionaryType)
+		// Compatibility is decided by the type of the dictionary *values*:
+		// the array unpacked below is built from dictArr.Dictionary(), not
+		// from the dictionary-encoded array itself.
+		valueType = dictType.ValueType
+		toType    = opts.ToType
+	)
+	defer dictArr.Release()
+
+	if !arrow.TypeEqual(toType, valueType) && !CanCast(valueType, toType) {
+		return fmt.Errorf("%w: cast type %s incompatible with dictionary type %s",
+			arrow.ErrInvalid, toType, dictType)
+	}
+
+	unpacked, err := TakeArray(ctx.Ctx, dictArr.Dictionary(), dictArr.Indices())
+	if err != nil {
+		return err
+	}
+	defer unpacked.Release()
+
+	// Only cast when the unpacked values are not already of the target type.
+	if !arrow.TypeEqual(valueType, toType) {
+		unpacked, err = CastArray(ctx.Ctx, unpacked, &opts)
+		if err != nil {
+			return err
+		}
+		// defer binds its receiver when the statement executes, so the
+		// earlier deferred Release still targets the TakeArray result and
+		// this one targets the cast result.
+		defer unpacked.Release()
+	}
+
+	out.TakeOwnership(unpacked.Data())
+	return nil
+}
+
func CastFromExtension(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
opts := ctx.State.(kernels.CastState)
@@ -402,6 +434,8 @@ func getTemporalCasts() []*castFunction {
panic(err)
}
}
+ fn.AddNewTypeCast(arrow.DICTIONARY, []exec.InputType{exec.NewIDInput(arrow.DICTIONARY)},
+ kernels[0].Signature.OutType, unpackDictionary, exec.NullComputedNoPrealloc, exec.MemNoPrealloc)
output = append(output, fn)
}
@@ -425,6 +459,10 @@ func getNumericCasts() []*castFunction {
panic(err)
}
}
+
+ fn.AddNewTypeCast(arrow.DICTIONARY, []exec.InputType{exec.NewIDInput(arrow.DICTIONARY)},
+ kns[0].Signature.OutType, unpackDictionary, exec.NullComputedNoPrealloc, exec.MemNoPrealloc)
+
return fn
}
@@ -486,6 +524,10 @@ func getBinaryLikeCasts() []*castFunction {
panic(err)
}
}
+
+ fn.AddNewTypeCast(arrow.DICTIONARY, []exec.InputType{exec.NewIDInput(arrow.DICTIONARY)},
+ kns[0].Signature.OutType, unpackDictionary, exec.NullComputedNoPrealloc, exec.MemNoPrealloc)
+
out = append(out, fn)
}
diff --git a/go/arrow/compute/cast_test.go b/go/arrow/compute/cast_test.go
index 6a0b77fce0..9e3b7c1ac1 100644
--- a/go/arrow/compute/cast_test.go
+++ b/go/arrow/compute/cast_test.go
@@ -240,6 +240,7 @@ var (
arrow.BinaryTypes.String,
arrow.BinaryTypes.LargeString,
}
+ dictIndexTypes = integerTypes
)
type CastSuite struct {
@@ -364,7 +365,7 @@ func (c *CastSuite) TestCanCast() {
canCast(from, []arrow.DataType{arrow.FixedWidthTypes.Boolean})
canCast(from, numericTypes)
canCast(from, []arrow.DataType{arrow.BinaryTypes.String, arrow.BinaryTypes.LargeString})
- cannotCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: from}, []arrow.DataType{from})
+ canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: from}, []arrow.DataType{from})
cannotCast(from, []arrow.DataType{arrow.Null})
}
@@ -373,11 +374,11 @@ func (c *CastSuite) TestCanCast() {
canCast(from, []arrow.DataType{arrow.FixedWidthTypes.Boolean})
canCast(from, numericTypes)
canCast(from, baseBinaryTypes)
- cannotCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int64, ValueType: from}, []arrow.DataType{from})
+ canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int64, ValueType: from}, []arrow.DataType{from})
// any cast which is valid for the dictionary is valid for the dictionary array
- // canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Uint32, ValueType: from}, baseBinaryTypes)
- // canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int16, ValueType: from}, baseBinaryTypes)
+ canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Uint32, ValueType: from}, baseBinaryTypes)
+ canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int16, ValueType: from}, baseBinaryTypes)
cannotCast(from, []arrow.DataType{arrow.Null})
}
@@ -2257,6 +2258,9 @@ func (c *CastSuite) TestIdentityCasts() {
c.checkCastSelfZeroCopy(arrow.FixedWidthTypes.Date32, `[1, 2, 3, 4]`)
c.checkCastSelfZeroCopy(arrow.FixedWidthTypes.Date64, `[86400000, 0]`)
c.checkCastSelfZeroCopy(arrow.FixedWidthTypes.Timestamp_s, `[1, 2, 3, 4]`)
+
+ c.checkCastSelfZeroCopy(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int8, ValueType: arrow.PrimitiveTypes.Int8},
+ `[1, 2, 3, 1, null, 3]`)
}
func (c *CastSuite) TestListToPrimitive() {
@@ -2727,6 +2731,43 @@ func (c *CastSuite) TestNoOutBitmapIfIsAllValid() {
c.Nil(result.Data().Buffers()[0])
}
+// TestFromDictionary checks that casting a dictionary array to its value type
+// produces the same array as taking the dictionary values by the indices. It
+// runs for every pairing of the dictionaries built below (one per entry of
+// numericTypes, plus String and LargeString) with each type in dictIndexTypes.
+func (c *CastSuite) TestFromDictionary() {
+	ctx := compute.WithAllocator(context.Background(), c.mem)
+
+	dictionaries := make([]arrow.Array, 0, len(numericTypes)+2)
+
+	// The deferred releases below intentionally run at function exit: the
+	// dictionaries must outlive the comparison loop at the bottom.
+	for _, ty := range numericTypes {
+		a, _, err := array.FromJSON(c.mem, ty, strings.NewReader(`[23, 12, 45, 12, null]`))
+		c.Require().NoError(err)
+		defer a.Release()
+		dictionaries = append(dictionaries, a)
+	}
+
+	for _, ty := range []arrow.DataType{arrow.BinaryTypes.String, arrow.BinaryTypes.LargeString} {
+		a, _, err := array.FromJSON(c.mem, ty, strings.NewReader(`["foo", "bar", "baz", "foo", null]`))
+		c.Require().NoError(err)
+		defer a.Release()
+		dictionaries = append(dictionaries, a)
+	}
+
+	for _, d := range dictionaries {
+		for _, ty := range dictIndexTypes {
+			indices, _, err := array.FromJSON(c.mem, ty, strings.NewReader(`[4, 0, 1, 2, 0, 4, null, 2]`))
+			c.Require().NoError(err)
+
+			expected, err := compute.Take(ctx, compute.TakeOptions{}, &compute.ArrayDatum{d.Data()}, &compute.ArrayDatum{indices.Data()})
+			c.Require().NoError(err)
+			exp := expected.(*compute.ArrayDatum).MakeArray()
+
+			dictArr := array.NewDictionaryArray(&arrow.DictionaryType{IndexType: ty, ValueType: d.DataType()}, indices, d)
+			checkCast(c.T(), dictArr, exp, *compute.SafeCastOptions(d.DataType()))
+
+			// No early return here: every dictionary/index-type pair must be
+			// exercised, not just the first one.
+			indices.Release()
+			expected.Release()
+			exp.Release()
+			dictArr.Release()
+		}
+	}
+}
+
func TestCasts(t *testing.T) {
suite.Run(t, new(CastSuite))
}
diff --git a/go/arrow/compute/datum.go b/go/arrow/compute/datum.go
index 83e6cbe10a..e02d50a98a 100644
--- a/go/arrow/compute/datum.go
+++ b/go/arrow/compute/datum.go
@@ -122,9 +122,6 @@ type releasable interface {
}
func (d *ScalarDatum) Release() {
- if !d.Value.IsValid() {
- return
- }
if v, ok := d.Value.(releasable); ok {
v.Release()
}
diff --git a/go/arrow/compute/internal/kernels/cast.go b/go/arrow/compute/internal/kernels/cast.go
index 455b0b54f4..a3d1e9b0e0 100644
--- a/go/arrow/compute/internal/kernels/cast.go
+++ b/go/arrow/compute/internal/kernels/cast.go
@@ -84,7 +84,7 @@ func OutputAllNull(_ *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult
return nil
}
-func canCastFromDict(id arrow.Type) bool {
+// CanCastFromDict reports whether id satisfies arrow.IsPrimitive,
+// arrow.IsBaseBinary, or arrow.IsFixedSizeBinary.
+// NOTE(review): presumably these are the dictionary value types a cast can
+// unpack — confirm against callers.
+func CanCastFromDict(id arrow.Type) bool {
return arrow.IsPrimitive(id) || arrow.IsBaseBinary(id) || arrow.IsFixedSizeBinary(id)
}
diff --git a/go/arrow/scalar/scalar.go b/go/arrow/scalar/scalar.go
index 4fd15f3d46..c03f380699 100644
--- a/go/arrow/scalar/scalar.go
+++ b/go/arrow/scalar/scalar.go
@@ -568,7 +568,7 @@ func init() {
// GetScalar creates a scalar object from the value at a given index in the
// passed in array, returns an error if unable to do so.
func GetScalar(arr arrow.Array, idx int) (Scalar, error) {
- if arr.IsNull(idx) {
+ if arr.DataType().ID() != arrow.DICTIONARY && arr.IsNull(idx) {
return MakeNullScalar(arr.DataType()), nil
}
@@ -675,13 +675,19 @@ func GetScalar(arr arrow.Array, idx int) (Scalar, error) {
return NewTimestampScalar(arr.Value(idx), arr.DataType()), nil
case *array.Dictionary:
ty := arr.DataType().(*arrow.DictionaryType)
- index, err := MakeScalarParam(arr.GetValueIndex(idx), ty.IndexType)
- if err != nil {
- return nil, err
+ valid := arr.IsValid(idx)
+ scalar := &Dictionary{scalar: scalar{ty, valid}}
+ if valid {
+ index, err := MakeScalarParam(arr.GetValueIndex(idx), ty.IndexType)
+ if err != nil {
+ return nil, err
+ }
+
+ scalar.Value.Index = index
+ } else {
+ scalar.Value.Index = MakeNullScalar(ty.IndexType)
}
- scalar := &Dictionary{scalar: scalar{ty, arr.IsValid(idx)}}
- scalar.Value.Index = index
scalar.Value.Dict = arr.Dictionary()
scalar.Value.Dict.Retain()
return scalar, nil