You are viewing a plain-text version of this content. (The link to the canonical version, labelled "here" in the original page, is not preserved in plain text.)
Posted to commits@arrow.apache.org by ze...@apache.org on 2022/11/18 16:05:43 UTC

[arrow] branch master updated: ARROW-18332: [Go] Cast Dictionary types to value type (#14650)

This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new e9222ae00b ARROW-18332: [Go] Cast Dictionary types to value type (#14650)
e9222ae00b is described below

commit e9222ae00b515a0f57d8c7d54bdb19457d68582f
Author: Matt Topol <zo...@gmail.com>
AuthorDate: Fri Nov 18 11:05:32 2022 -0500

    ARROW-18332: [Go] Cast Dictionary types to value type (#14650)
    
    Authored-by: Matt Topol <zo...@gmail.com>
    Signed-off-by: Matt Topol <zo...@gmail.com>
---
 go/arrow/compute/cast.go                  | 42 ++++++++++++++++++++++++++
 go/arrow/compute/cast_test.go             | 49 ++++++++++++++++++++++++++++---
 go/arrow/compute/datum.go                 |  3 --
 go/arrow/compute/internal/kernels/cast.go |  2 +-
 go/arrow/scalar/scalar.go                 | 18 ++++++++----
 5 files changed, 100 insertions(+), 14 deletions(-)

diff --git a/go/arrow/compute/cast.go b/go/arrow/compute/cast.go
index 1c00eedcb6..9530a06761 100644
--- a/go/arrow/compute/cast.go
+++ b/go/arrow/compute/cast.go
@@ -134,6 +134,38 @@ func (cf *castFunction) DispatchExact(vals ...arrow.DataType) (exec.Kernel, erro
 	return candidates[0], nil
 }
 
+// unpackDictionary decodes a dictionary array via Take and, if needed, casts it to opts.ToType.
+func unpackDictionary(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
+	var (
+		dictArr   = batch.Values[0].Array.MakeArray().(*array.Dictionary)
+		opts      = ctx.State.(kernels.CastState)
+		dictType  = dictArr.DataType().(*arrow.DictionaryType)
+		valueType = dictType.ValueType
+		toType    = opts.ToType
+	)
+	defer dictArr.Release()
+	// toType is this cast function's (non-dictionary) output, so test it against the value type
+	if !arrow.TypeEqual(toType, valueType) && !CanCast(valueType, toType) {
+		return fmt.Errorf("%w: cast type %s incompatible with dictionary type %s",
+			arrow.ErrInvalid, toType, dictType)
+	}
+
+	unpacked, err := TakeArray(ctx.Ctx, dictArr.Dictionary(), dictArr.Indices())
+	if err != nil {
+		return err
+	}
+	defer unpacked.Release()
+	// only run a second pass when the values are not already the target type
+	if !arrow.TypeEqual(valueType, toType) {
+		if unpacked, err = CastArray(ctx.Ctx, unpacked, &opts); err != nil {
+			return err
+		}
+		defer unpacked.Release()
+	}
+	out.TakeOwnership(unpacked.Data())
+	return nil
+}
+
 func CastFromExtension(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
 	opts := ctx.State.(kernels.CastState)
 
@@ -402,6 +434,8 @@ func getTemporalCasts() []*castFunction {
 				panic(err)
 			}
 		}
+		fn.AddNewTypeCast(arrow.DICTIONARY, []exec.InputType{exec.NewIDInput(arrow.DICTIONARY)},
+			kernels[0].Signature.OutType, unpackDictionary, exec.NullComputedNoPrealloc, exec.MemNoPrealloc)
 		output = append(output, fn)
 	}
 
@@ -425,6 +459,10 @@ func getNumericCasts() []*castFunction {
 				panic(err)
 			}
 		}
+
+		fn.AddNewTypeCast(arrow.DICTIONARY, []exec.InputType{exec.NewIDInput(arrow.DICTIONARY)},
+			kns[0].Signature.OutType, unpackDictionary, exec.NullComputedNoPrealloc, exec.MemNoPrealloc)
+
 		return fn
 	}
 
@@ -486,6 +524,10 @@ func getBinaryLikeCasts() []*castFunction {
 				panic(err)
 			}
 		}
+
+		fn.AddNewTypeCast(arrow.DICTIONARY, []exec.InputType{exec.NewIDInput(arrow.DICTIONARY)},
+			kns[0].Signature.OutType, unpackDictionary, exec.NullComputedNoPrealloc, exec.MemNoPrealloc)
+
 		out = append(out, fn)
 	}
 
diff --git a/go/arrow/compute/cast_test.go b/go/arrow/compute/cast_test.go
index 6a0b77fce0..9e3b7c1ac1 100644
--- a/go/arrow/compute/cast_test.go
+++ b/go/arrow/compute/cast_test.go
@@ -240,6 +240,7 @@ var (
 		arrow.BinaryTypes.String,
 		arrow.BinaryTypes.LargeString,
 	}
+	dictIndexTypes = integerTypes
 )
 
 type CastSuite struct {
@@ -364,7 +365,7 @@ func (c *CastSuite) TestCanCast() {
 		canCast(from, []arrow.DataType{arrow.FixedWidthTypes.Boolean})
 		canCast(from, numericTypes)
 		canCast(from, []arrow.DataType{arrow.BinaryTypes.String, arrow.BinaryTypes.LargeString})
-		cannotCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: from}, []arrow.DataType{from})
+		canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: from}, []arrow.DataType{from})
 
 		cannotCast(from, []arrow.DataType{arrow.Null})
 	}
@@ -373,11 +374,11 @@ func (c *CastSuite) TestCanCast() {
 		canCast(from, []arrow.DataType{arrow.FixedWidthTypes.Boolean})
 		canCast(from, numericTypes)
 		canCast(from, baseBinaryTypes)
-		cannotCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int64, ValueType: from}, []arrow.DataType{from})
+		canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int64, ValueType: from}, []arrow.DataType{from})
 
 		// any cast which is valid for the dictionary is valid for the dictionary array
-		// canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Uint32, ValueType: from}, baseBinaryTypes)
-		// canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int16, ValueType: from}, baseBinaryTypes)
+		canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Uint32, ValueType: from}, baseBinaryTypes)
+		canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int16, ValueType: from}, baseBinaryTypes)
 
 		cannotCast(from, []arrow.DataType{arrow.Null})
 	}
@@ -2257,6 +2258,9 @@ func (c *CastSuite) TestIdentityCasts() {
 	c.checkCastSelfZeroCopy(arrow.FixedWidthTypes.Date32, `[1, 2, 3, 4]`)
 	c.checkCastSelfZeroCopy(arrow.FixedWidthTypes.Date64, `[86400000, 0]`)
 	c.checkCastSelfZeroCopy(arrow.FixedWidthTypes.Timestamp_s, `[1, 2, 3, 4]`)
+
+	c.checkCastSelfZeroCopy(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int8, ValueType: arrow.PrimitiveTypes.Int8},
+		`[1, 2, 3, 1, null, 3]`)
 }
 
 func (c *CastSuite) TestListToPrimitive() {
@@ -2727,6 +2731,43 @@ func (c *CastSuite) TestNoOutBitmapIfIsAllValid() {
 	c.Nil(result.Data().Buffers()[0])
 }
 
+// TestFromDictionary checks every dictionary value/index type pairing against Take(dict, indices).
+func (c *CastSuite) TestFromDictionary() {
+	ctx := compute.WithAllocator(context.Background(), c.mem)
+	var dictionaries []arrow.Array
+	defer func() {
+		for _, d := range dictionaries {
+			d.Release()
+		}
+	}()
+	addDict := func(ty arrow.DataType, values string) {
+		a, _, err := array.FromJSON(c.mem, ty, strings.NewReader(values))
+		c.Require().NoError(err)
+		dictionaries = append(dictionaries, a)
+	}
+	for _, ty := range numericTypes {
+		addDict(ty, `[23, 12, 45, 12, null]`)
+	}
+	for _, ty := range []arrow.DataType{arrow.BinaryTypes.String, arrow.BinaryTypes.LargeString} {
+		addDict(ty, `["foo", "bar", "baz", "foo", null]`)
+	}
+	// no early exit: every dictionary must be checked with every index type
+	for _, d := range dictionaries {
+		for _, ty := range dictIndexTypes {
+			indices, _, err := array.FromJSON(c.mem, ty, strings.NewReader(`[4, 0, 1, 2, 0, 4, null, 2]`))
+			c.Require().NoError(err)
+			expected, err := compute.Take(ctx, compute.TakeOptions{}, &compute.ArrayDatum{Value: d.Data()}, &compute.ArrayDatum{Value: indices.Data()})
+			c.Require().NoError(err)
+			exp := expected.(*compute.ArrayDatum).MakeArray()
+			dictArr := array.NewDictionaryArray(&arrow.DictionaryType{IndexType: ty, ValueType: d.DataType()}, indices, d)
+			checkCast(c.T(), dictArr, exp, *compute.SafeCastOptions(d.DataType()))
+			for _, r := range []interface{ Release() }{indices, expected, exp, dictArr} {
+				r.Release()
+			}
+		}
+	}
+}
+
 func TestCasts(t *testing.T) {
 	suite.Run(t, new(CastSuite))
 }
diff --git a/go/arrow/compute/datum.go b/go/arrow/compute/datum.go
index 83e6cbe10a..e02d50a98a 100644
--- a/go/arrow/compute/datum.go
+++ b/go/arrow/compute/datum.go
@@ -122,9 +122,6 @@ type releasable interface {
 }
 
 func (d *ScalarDatum) Release() {
-	if !d.Value.IsValid() {
-		return
-	}
 	if v, ok := d.Value.(releasable); ok {
 		v.Release()
 	}
diff --git a/go/arrow/compute/internal/kernels/cast.go b/go/arrow/compute/internal/kernels/cast.go
index 455b0b54f4..a3d1e9b0e0 100644
--- a/go/arrow/compute/internal/kernels/cast.go
+++ b/go/arrow/compute/internal/kernels/cast.go
@@ -84,7 +84,7 @@ func OutputAllNull(_ *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult
 	return nil
 }
 
-func canCastFromDict(id arrow.Type) bool {
+func CanCastFromDict(id arrow.Type) bool {
 	return arrow.IsPrimitive(id) || arrow.IsBaseBinary(id) || arrow.IsFixedSizeBinary(id)
 }
 
diff --git a/go/arrow/scalar/scalar.go b/go/arrow/scalar/scalar.go
index 4fd15f3d46..c03f380699 100644
--- a/go/arrow/scalar/scalar.go
+++ b/go/arrow/scalar/scalar.go
@@ -568,7 +568,7 @@ func init() {
 // GetScalar creates a scalar object from the value at a given index in the
 // passed in array, returns an error if unable to do so.
 func GetScalar(arr arrow.Array, idx int) (Scalar, error) {
-	if arr.IsNull(idx) {
+	if arr.DataType().ID() != arrow.DICTIONARY && arr.IsNull(idx) {
 		return MakeNullScalar(arr.DataType()), nil
 	}
 
@@ -675,13 +675,19 @@ func GetScalar(arr arrow.Array, idx int) (Scalar, error) {
 		return NewTimestampScalar(arr.Value(idx), arr.DataType()), nil
 	case *array.Dictionary:
 		ty := arr.DataType().(*arrow.DictionaryType)
-		index, err := MakeScalarParam(arr.GetValueIndex(idx), ty.IndexType)
-		if err != nil {
-			return nil, err
+		valid := arr.IsValid(idx)
+		scalar := &Dictionary{scalar: scalar{ty, valid}}
+		if valid {
+			index, err := MakeScalarParam(arr.GetValueIndex(idx), ty.IndexType)
+			if err != nil {
+				return nil, err
+			}
+
+			scalar.Value.Index = index
+		} else {
+			scalar.Value.Index = MakeNullScalar(ty.IndexType)
 		}
 
-		scalar := &Dictionary{scalar: scalar{ty, arr.IsValid(idx)}}
-		scalar.Value.Index = index
 		scalar.Value.Dict = arr.Dictionary()
 		scalar.Value.Dict.Retain()
 		return scalar, nil