Posted to github@arrow.apache.org by "zeroshade (via GitHub)" <gi...@apache.org> on 2023/05/17 20:50:04 UTC

[GitHub] [arrow] zeroshade opened a new pull request, #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

zeroshade opened a new pull request, #35654:
URL: https://github.com/apache/arrow/pull/35654

   
   ### Rationale for this change
   This PR adds the ability to execute expressions more complex than a single operation by leveraging Substrait's expression objects, and deprecates the existing, separate Expression interfaces in Go Arrow compute. It provides a quick integration with Substrait Expressions and ExtendedExpressions on which further integrations can be built.
   
   
   ### What changes are included in this PR?
   This PR provides:
   
   * An extension registry for Go Arrow that provides mappings between Arrow and Substrait for functions and types, along with other custom mappings if necessary
   * Facilities to convert between Arrow data types and Substrait types
   * Functions to evaluate Substrait expression objects with Arrow data as the input
   * Functions to evaluate Substrait field references against Arrow data and Arrow schemas
   
   
   
   ### Are these changes tested?
   Yes, unit tests are included.
   
   
   ### Are there any user-facing changes?
   The existing `compute.Expression` interface and its related types are being marked as deprecated.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1221796056


##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(memory.DefaultAllocator, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
+	case *expr.StructLiteral:
+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()
+	}
+
+	out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext)
+	if err == compute.ErrEmpty {
+		out = compute.NewDatum(param)
+	} else if err != nil {
+		return nil, err
+	}
+	if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) {
+		return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s",
+			arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType)
+	}
+
+	return out, nil
+}
+
+func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
+	if !exp.IsScalar() {
+		return nil, fmt.Errorf("%w: ExecuteScalarExpression cannot execute non-scalar expressions",
+			arrow.ErrInvalid)
+	}
+
+	switch e := exp.(type) {
+	case expr.Literal:
+		return literalToDatum(compute.GetAllocator(ctx), e, ext)
+	case *expr.FieldReference:
+		return execFieldRef(ctx, e, input, ext)
+	case *expr.Cast:
+		if e.Input == nil {
+			return nil, fmt.Errorf("%w: cast without argument to cast", arrow.ErrInvalid)
+		}
+
+		arg, err := executeScalarBatch(ctx, input, e.Input, ext)
+		if err != nil {
+			return nil, err
+		}
+		defer arg.Release()
+
+		dt, _, err := FromSubstraitType(e.Type, ext)
+		if err != nil {
+			return nil, fmt.Errorf("%w: could not determine type for cast", err)
+		}
+
+		var opts *compute.CastOptions
+		switch e.FailureBehavior {
+		case types.BehaviorThrowException:
+			opts = compute.SafeCastOptions(dt)
+		case types.BehaviorUnspecified:
+			opts = compute.UnsafeCastOptions(dt)
+		case types.BehaviorReturnNil:
+			return nil, fmt.Errorf("%w: cast behavior return nil", arrow.ErrNotImplemented)
+		}
+		return compute.CastDatum(ctx, arg, opts)
+	case *expr.ScalarFunction:
+		var (
+			err       error
+			allScalar = true
+			args      = make([]compute.Datum, e.NArgs())
+			argTypes  = make([]arrow.DataType, e.NArgs())
+		)
+		for i := 0; i < e.NArgs(); i++ {
+			switch v := e.Arg(i).(type) {
+			case types.Enum:
+				args[i] = compute.NewDatum(scalar.NewStringScalar(string(v)))
+			case expr.Expression:
+				args[i], err = executeScalarBatch(ctx, input, v, ext)
+				if err != nil {
+					return nil, err
+				}
+				defer args[i].Release()
+
+				if args[i].Kind() != compute.KindScalar {
+					allScalar = false

Review Comment:
   The only case I can think of where this would be a valid function is a generative function, i.e. one that generates a result set itself (like `SELECT RAND()`) but takes some enum arguments.





[GitHub] [arrow] lidavidm commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "lidavidm (via GitHub)" <gi...@apache.org>.
lidavidm commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1197926044


##########
go/arrow/internal/flight_integration/cmd/arrow-flight-integration-server/__debug_bin:
##########


Review Comment:
   Did you mean to check this in? (Should we .gitignore it?)



##########
go/arrow/compute/exprs/types.go:
##########
@@ -0,0 +1,745 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"fmt"
+	"hash/maphash"
+	"strconv"
+	"strings"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+const (
+	// URI for official Arrow Substrait Extension Types
+	ArrowExtTypesUri          = "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"

Review Comment:
   Should this point to `main` instead of `master`?



##########
go/arrow/array/util_test.go:
##########
@@ -522,3 +522,29 @@ func TestRecordBuilderUnmarshalJSONExtraFields(t *testing.T) {
 
 	assert.Truef(t, array.RecordEqual(rec1, rec2), "expected: %s\nactual: %s", rec1, rec2)
 }
+
+func TestJSON(t *testing.T) {

Review Comment:
   What is this test checking? At first glance it doesn't seem related to the other changes here.





[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1200778534


##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(memory.DefaultAllocator, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
+	case *expr.StructLiteral:
+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()
+	}
+
+	out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext)
+	if err == compute.ErrEmpty {
+		out = compute.NewDatum(param)
+	} else if err != nil {
+		return nil, err
+	}
+	if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) {
+		return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s",
+			arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType)
+	}
+
+	return out, nil
+}
+
+func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
+	if !exp.IsScalar() {

Review Comment:
   Currently, that would consist of aggregate / window functions, I believe (or anything containing one as an argument). This is equivalent to Acero's distinction between `scalar` and `vector` functions, i.e. evaluated element-wise vs. array-wise.
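As an illustration of the element-wise vs. array-wise distinction, here is a minimal sketch that uses plain Go slices rather than the Arrow compute API. `addOne` and `sum` are hypothetical stand-ins for a scalar kernel and an aggregate; they are not functions from this PR.

```go
package main

import "fmt"

// addOne is a hypothetical scalar (element-wise) function: each output
// element depends only on the input element at the same position.
func addOne(in []int64) []int64 {
	out := make([]int64, len(in))
	for i, v := range in {
		out[i] = v + 1
	}
	return out
}

// sum is a hypothetical aggregate (array-wise) function: its single
// output depends on every element of the input column.
func sum(in []int64) int64 {
	var total int64
	for _, v := range in {
		total += v
	}
	return total
}

func main() {
	col := []int64{1, 2, 3, 4}

	// Element-wise: evaluating two halves separately and concatenating
	// the results matches evaluating the whole column at once.
	whole := addOne(col)
	split := append(addOne(col[:2]), addOne(col[2:])...)
	fmt.Println(whole, split)

	// Array-wise: the aggregate collapses the column to one value, so it
	// cannot be computed row by row the way addOne can.
	fmt.Println(sum(col))
}
```

Because a scalar function's output at row i depends only on row i, it can be evaluated on any slice of a batch. An aggregate like `sum` cannot, and that is the distinction the `IsScalar()` check in `executeScalarBatch` appears to rely on.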



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [arrow] zeroshade commented on pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on PR #35654:
URL: https://github.com/apache/arrow/pull/35654#issuecomment-1577001800

   @westonpace can you look at my responses to your comments and let me know if there's any other questions/concerns?




[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1200762505


##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil

Review Comment:
   It's a constant set to "UTC", declared at the top of `types.go`, matching the inline function `TimestampTzTimezoneString` in the C++ implementation (`cpp/src/arrow/engine/substrait/type_internal.h`).





[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1200787850


##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(memory.DefaultAllocator, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ItemType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ItemType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
+	case *expr.StructLiteral:
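+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))
+		for i, l := range v.Value {
+			lit, err := literalToDatum(mem, l, ext)
+			if err != nil {
+				return nil, err
+			}
+			fields[i] = lit.(*compute.ScalarDatum).Value
+		}
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err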
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()
+	}
+
+	out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext)
+	if err == compute.ErrEmpty {
+		out = compute.NewDatum(param)
+	} else if err != nil {
+		return nil, err
+	}
+	if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) {
+		return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s",
+			arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType)
+	}
+
+	return out, nil
+}
+
+func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
+	if !exp.IsScalar() {
+		return nil, fmt.Errorf("%w: ExecuteScalarExpression cannot execute non-scalar expressions",
+			arrow.ErrInvalid)
+	}
+
+	switch e := exp.(type) {
+	case expr.Literal:
+		return literalToDatum(compute.GetAllocator(ctx), e, ext)
+	case *expr.FieldReference:
+		return execFieldRef(ctx, e, input, ext)
+	case *expr.Cast:
+		if e.Input == nil {
+			return nil, fmt.Errorf("%w: cast without argument to cast", arrow.ErrInvalid)
+		}
+
+		arg, err := executeScalarBatch(ctx, input, e.Input, ext)
+		if err != nil {
+			return nil, err
+		}
+		defer arg.Release()
+
+		dt, _, err := FromSubstraitType(e.Type, ext)
+		if err != nil {
+			return nil, fmt.Errorf("%w: could not determine type for cast", err)
+		}
+
+		var opts *compute.CastOptions
+		switch e.FailureBehavior {
+		case types.BehaviorThrowException:
+			opts = compute.SafeCastOptions(dt)
+		case types.BehaviorUnspecified:
+			opts = compute.UnsafeCastOptions(dt)
+		case types.BehaviorReturnNil:
+			return nil, fmt.Errorf("%w: cast behavior return nil", arrow.ErrNotImplemented)
+		}
+		return compute.CastDatum(ctx, arg, opts)
+	case *expr.ScalarFunction:
+		var (
+			err       error
+			allScalar = true
+			args      = make([]compute.Datum, e.NArgs())
+			argTypes  = make([]arrow.DataType, e.NArgs())
+		)
+		for i := 0; i < e.NArgs(); i++ {
+			switch v := e.Arg(i).(type) {
+			case types.Enum:
+				args[i] = compute.NewDatum(scalar.NewStringScalar(string(v)))
+			case expr.Expression:
+				args[i], err = executeScalarBatch(ctx, input, v, ext)
+				if err != nil {
+					return nil, err
+				}
+				defer args[i].Release()
+
+				if args[i].Kind() != compute.KindScalar {
+					allScalar = false

Review Comment:
   When an expression is a scalar function whose arguments consist *solely* of enum values, we don't want the length of the exec batch passed to that function to reflect the length of the top-level input. In that special case the input length for the evaluation is 1, since there is only a single value (the enum).
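The following sketch illustrates the batch-length rule described above. It assumes that the `allScalar` flag in the hunk is what selects between length 1 and the top-level length. `argKind` and `callBatchLen` are hypothetical names used only for illustration; they are not Arrow or substrait-go APIs.

```go
package main

import "fmt"

// argKind is a hypothetical stand-in for what an evaluated argument turned
// out to be; it is not an Arrow or Substrait type.
type argKind int

const (
	enumArg   argKind = iota // Substrait enum option, materialized as a string scalar
	scalarArg                // sub-expression that evaluated to a scalar
	arrayArg                 // sub-expression that evaluated to an array
)

// callBatchLen returns the length of the exec batch handed to a scalar
// function: 1 when every argument is scalar (enum-only calls included),
// otherwise the length of the top-level input batch.
func callBatchLen(args []argKind, topLevelLen int64) int64 {
	for _, a := range args {
		if a == arrayArg {
			return topLevelLen
		}
	}
	return 1
}

func main() {
	fmt.Println(callBatchLen([]argKind{enumArg}, 1024))            // 1
	fmt.Println(callBatchLen([]argKind{enumArg, scalarArg}, 1024)) // 1
	fmt.Println(callBatchLen([]argKind{enumArg, arrayArg}, 1024))  // 1024
}
```

With this rule, an enum-only call produces one result instead of being broadcast to the length of the outer input.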



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1200776869


##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
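+// WithExtensionRegistry returns a child context carrying reg, which
+// ExecuteScalarSubstrait retrieves through GetExtensionRegistry.
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+// GetExtensionRegistry returns the registry stored by WithExtensionRegistry,
+// falling back to DefaultExtensionIDRegistry if none was set.
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+// WithExtensionIDSet returns a child context carrying ext, which
+// ExecuteScalarExpression retrieves through GetExtensionIDSet.
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+// GetExtensionIDSet returns the set stored by WithExtensionIDSet. If none
+// was set, it builds a new set from an empty Substrait extension registry
+// (backed by the default collection) and GetExtensionRegistry(ctx).
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}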
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(memory.DefaultAllocator, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ItemType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ItemType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
+	case *expr.StructLiteral:
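+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))
+		for i, l := range v.Value {
+			lit, err := literalToDatum(mem, l, ext)
+			if err != nil {
+				return nil, err
+			}
+			fields[i] = lit.(*compute.ScalarDatum).Value
+		}
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err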
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()

Review Comment:
   Because the top level is an `ExecBatch`, which holds a slice of `Datum` rather than a single datum with a slice of child arrays/scalars, we use the top-level reference to index into the exec batch's slice. We then call `GetReferencedValue` (in `field_refs.go`), whose loop walks the rest of the potentially nested struct reference.
   
   This way `GetReferencedValue` can stay trivially recursive without needing a special case for the `ExecBatch`.
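The split described above can be sketched with simplified stand-in types. The first reference segment indexes the batch's column slice, and a generic walker handles everything below it. `value`, `structValue`, `fieldRef`, and both functions are hypothetical illustrations; they are not the real `compute.Datum`, `expr.StructFieldRef`, or `GetReferencedValue` signatures.

```go
package main

import (
	"errors"
	"fmt"
)

// Hypothetical stand-ins for a Datum, a struct value, and a Substrait
// struct-field reference segment.
type value interface{}

type structValue struct{ children []value }

type fieldRef struct {
	field int
	child *fieldRef // nil when the reference ends at this segment
}

var errOutOfRange = errors.New("field reference out of range")

// getReferencedValue walks a (possibly nested) struct reference starting
// from a single value; it knows nothing about exec batches.
func getReferencedValue(ref *fieldRef, v value) (value, error) {
	for ref != nil {
		s, ok := v.(structValue)
		if !ok || ref.field < 0 || ref.field >= len(s.children) {
			return nil, errOutOfRange
		}
		v = s.children[ref.field]
		ref = ref.child
	}
	return v, nil
}

// execFieldRef handles the batch-level special case: the top segment picks a
// column from the batch, and the remaining segments are delegated. A
// reference with no child simply returns that column.
func execFieldRef(batch []value, ref *fieldRef) (value, error) {
	if ref == nil || ref.field < 0 || ref.field >= len(batch) {
		return nil, errOutOfRange
	}
	return getReferencedValue(ref.child, batch[ref.field])
}

func main() {
	batch := []value{
		"alpha column",
		structValue{children: []value{"gamma column"}},
	}
	v, err := execFieldRef(batch, &fieldRef{field: 1, child: &fieldRef{field: 0}})
	fmt.Println(v, err) // gamma column <nil>
}
```

Only `execFieldRef` is aware of the batch's slice of columns. That keeps the nested walker uniform for every level below the top one.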




[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1221802358


##########
go/arrow/compute/exprs/builders.go:
##########
@@ -0,0 +1,445 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+	"unicode"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+// NewDefaultExtensionSet constructs an empty extension set using the default
+// Arrow Extension registry and the default collection of substrait extensions
+// from the Substrait-go repo.
+func NewDefaultExtensionSet() ExtensionIDSet {
+	return NewExtensionSetDefault(expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection))
+}
+
+// NewScalarCall constructs a substrait ScalarFunction expression with the provided
+// options and arguments.
+//
+// The function name (fn) is looked up in the internal Arrow DefaultExtensionIDRegistry
+// to ensure it exists and to convert from the Arrow function name to the substrait
+// function name. It is then looked up using the DefaultCollection from the
+// substrait extensions module to find the declaration. If it cannot be found,
+// we try constructing the compound signature name by getting the types of the
+// arguments which were passed and appending them to the function name appropriately.
+//
+// An error is returned if the function cannot be resolved.
+func NewScalarCall(reg ExtensionIDSet, fn string, opts []*types.FunctionOption, args ...types.FuncArg) (*expr.ScalarFunction, error) {
+	conv, ok := reg.GetArrowRegistry().GetArrowToSubstrait(fn)
+	if !ok {
+		return nil, arrow.ErrNotFound
+	}
+
+	id, convOpts, err := conv(fn)
+	if err != nil {
+		return nil, err
+	}
+
+	opts = append(opts, convOpts...)
+	return expr.NewScalarFunc(reg.GetSubstraitRegistry(), id, opts, args...)
+}
+
+// NewFieldRefFromDotPath constructs a substrait reference segment from
+// a dot path and the base schema.
+//
+// dot_path = '.' name
+//
+//	| '[' digit+ ']'
+//	| dot_path+
+//
+// # Examples
+//
+// Assume root schema of {alpha: i32, beta: struct<gamma: list<i32>>, delta: map<string, i32>}
+//
+//	".alpha" => StructFieldRef(0)
+//	"[2]" => StructFieldRef(2)
+//	".beta[0]" => StructFieldRef(1, StructFieldRef(0))
+//	"[1].gamma[3]" => StructFieldRef(1, StructFieldRef(0, ListElementRef(3)))
+//	".delta.foobar" => StructFieldRef(2, MapKeyRef("foobar"))
+//
+// Note: when parsing a name, a '\' preceding any other character
+// will be dropped from the resulting name. Therefore if a name must
+// contain the characters '.', '\', '[', or ']' then they must be escaped
+// with a preceding '\'.
+func NewFieldRefFromDotPath(dotpath string, rootSchema *arrow.Schema) (expr.ReferenceSegment, error) {
+	if len(dotpath) == 0 {
+		return nil, fmt.Errorf("%w: dotpath was empty", arrow.ErrInvalid)
+	}
+
+	parseName := func() string {
+		var name string
+		for {
+			idx := strings.IndexAny(dotpath, `\[.`)
+			if idx == -1 {
+				name += dotpath
+				dotpath = ""
+				break
+			}
+
+			if dotpath[idx] != '\\' {
+				// subscript for a new field ref
+				name += dotpath[:idx]
+				dotpath = dotpath[idx:]
+				break
+			}
+
+			if len(dotpath) == idx+1 {
+				// dotpath ends with a backslash; consume it all
+				name += dotpath
+				dotpath = ""
+				break
+			}
+
+			// append all characters before backslash, then the character which follows it
+			name += dotpath[:idx] + string(dotpath[idx+1])
+			dotpath = dotpath[idx+2:]
+		}
+		return name
+	}
+
+	var curType arrow.DataType = arrow.StructOf(rootSchema.Fields()...)
+	children := make([]expr.ReferenceSegment, 0)
+
+	for len(dotpath) > 0 {
+		subscript := dotpath[0]
+		dotpath = dotpath[1:]
+		switch subscript {
+		case '.':
+			// next element is a name
+			n := parseName()
+			switch ct := curType.(type) {
+			case *arrow.StructType:
+				idx, found := ct.FieldIdx(n)
+				if !found {
+					return nil, fmt.Errorf("%w: dot path referenced invalid field '%s'", arrow.ErrInvalid, n)
+				}
+				children = append(children, &expr.StructFieldRef{Field: int32(idx)})
+				curType = ct.Field(idx).Type
+			case *arrow.MapType:
+				curType = ct.KeyType()
+				switch ct.KeyType().ID() {
+				case arrow.BINARY, arrow.LARGE_BINARY:
+					children = append(children, &expr.MapKeyRef{MapKey: expr.NewByteSliceLiteral([]byte(n), false)})
+				case arrow.STRING, arrow.LARGE_STRING:
+					children = append(children, &expr.MapKeyRef{MapKey: expr.NewPrimitiveLiteral(n, false)})
+				default:
+					return nil, fmt.Errorf("%w: MapKeyRef to non-binary/string map not supported", arrow.ErrNotImplemented)
+				}
+			default:
+				return nil, fmt.Errorf("%w: dot path names must refer to struct fields or map keys", arrow.ErrInvalid)
+			}
+		case '[':
+			subend := strings.IndexFunc(dotpath, func(r rune) bool { return !unicode.IsDigit(r) })
+			if subend == -1 || dotpath[subend] != ']' {
+				return nil, fmt.Errorf("%w: dot path '%s' contained an unterminated index", arrow.ErrInvalid, dotpath)
+			}
+			idx, _ := strconv.Atoi(dotpath[:subend])
+			switch ct := curType.(type) {
+			case *arrow.StructType:
+				if idx >= len(ct.Fields()) {
+					return nil, fmt.Errorf("%w: field out of bounds in dotpath", arrow.ErrIndex)
+				}
+				curType = ct.Field(idx).Type
+				children = append(children, &expr.StructFieldRef{Field: int32(idx)})
+			case *arrow.MapType:
+				curType = ct.KeyType()
+				var keyLiteral expr.Literal
+				// TODO: implement user defined types and variations
+				switch ct.KeyType().ID() {
+				case arrow.INT8:
+					keyLiteral = expr.NewPrimitiveLiteral(int8(idx), false)
+				case arrow.INT16:
+					keyLiteral = expr.NewPrimitiveLiteral(int16(idx), false)
+				case arrow.INT32:
+					keyLiteral = expr.NewPrimitiveLiteral(int32(idx), false)
+				case arrow.INT64:
+					keyLiteral = expr.NewPrimitiveLiteral(int64(idx), false)
+				case arrow.FLOAT32:
+					keyLiteral = expr.NewPrimitiveLiteral(float32(idx), false)
+				case arrow.FLOAT64:
+					keyLiteral = expr.NewPrimitiveLiteral(float64(idx), false)

Review Comment:
   so, there's nothing in the spec that says a float *can't* be a map key which is why I didn't mark it as unsupported. But you make a good point about the call to `atoi`. I'll change the float cases to unsupported for now and if anyone has an issue I can address it then.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [arrow] zeroshade merged pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade merged PR #35654:
URL: https://github.com/apache/arrow/pull/35654




[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1221798676


##########
go/arrow/compute/exprs/builders.go:
##########
@@ -0,0 +1,445 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+	"unicode"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+// NewDefaultExtensionSet constructs an empty extension set using the default
+// Arrow Extension registry and the default collection of substrait extensions
+// from the Substrait-go repo.
+func NewDefaultExtensionSet() ExtensionIDSet {
+	return NewExtensionSetDefault(expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection))
+}
+
+// NewScalarCall constructs a substrait ScalarFunction expression with the provided
+// options and arguments.
+//
+// The function name (fn) is looked up in the internal Arrow DefaultExtensionIDRegistry
+// to ensure it exists and to convert from the Arrow function name to the substrait
+// function name. It is then looked up using the DefaultCollection from the
+// substrait extensions module to find the declaration. If it cannot be found,
+// we try constructing the compound signature name by getting the types of the
+// arguments which were passed and appending them to the function name appropriately.
+//
+// An error is returned if the function cannot be resolved.
+func NewScalarCall(reg ExtensionIDSet, fn string, opts []*types.FunctionOption, args ...types.FuncArg) (*expr.ScalarFunction, error) {
+	conv, ok := reg.GetArrowRegistry().GetArrowToSubstrait(fn)
+	if !ok {
+		return nil, arrow.ErrNotFound
+	}
+
+	id, convOpts, err := conv(fn)
+	if err != nil {
+		return nil, err
+	}
+
+	opts = append(opts, convOpts...)
+	return expr.NewScalarFunc(reg.GetSubstraitRegistry(), id, opts, args...)
+}
+
+// NewFieldRefFromDotPath constructs a substrait reference segment from
+// a dot path and the base schema.
+//
+// dot_path = '.' name
+//
+//	| '[' digit+ ']'
+//	| dot_path+
+//
+// # Examples
+//
+// Assume root schema of {alpha: i32, beta: struct<gamma: list<i32>>, delta: map<string, i32>}
+//
+//	".alpha" => StructFieldRef(0)
+//	"[2]" => StructFieldRef(2)
+//	".beta[0]" => StructFieldRef(1, StructFieldRef(0))
+//	"[1].gamma[3]" => StructFieldRef(1, StructFieldRef(0, ListElementRef(3)))
+//	".delta.foobar" => StructFieldRef(2, MapKeyRef("foobar"))
+//
+// Note: when parsing a name, a '\' preceding any other character
+// will be dropped from the resulting name. Therefore if a name must
+// contain the characters '.', '\', '[', or ']' then they must be escaped
+// with a preceding '\'.
+func NewFieldRefFromDotPath(dotpath string, rootSchema *arrow.Schema) (expr.ReferenceSegment, error) {
+	if len(dotpath) == 0 {
+		return nil, fmt.Errorf("%w: dotpath was empty", arrow.ErrInvalid)
+	}
+
+	parseName := func() string {
+		var name string
+		for {
+			idx := strings.IndexAny(dotpath, `\[.`)
+			if idx == -1 {
+				name += dotpath
+				dotpath = ""
+				break
+			}
+
+			if dotpath[idx] != '\\' {
+				// subscript for a new field ref
+				name += dotpath[:idx]
+				dotpath = dotpath[idx:]
+				break
+			}
+
+			if len(dotpath) == idx+1 {
+				// dotpath ends with a backslash; consume it all
+				name += dotpath
+				dotpath = ""
+				break
+			}
+
+			// append all characters before backslash, then the character which follows it
+			name += dotpath[:idx] + dotpath[idx+1:idx+2]
+			dotpath = dotpath[idx+2:]
+		}
+		return name
+	}
+
+	var curType arrow.DataType = arrow.StructOf(rootSchema.Fields()...)
+	children := make([]expr.ReferenceSegment, 0)
+
+	for len(dotpath) > 0 {
+		subscript := dotpath[0]
+		dotpath = dotpath[1:]
+		switch subscript {
+		case '.':
+			// next element is a name
+			n := parseName()
+			switch ct := curType.(type) {
+			case *arrow.StructType:
+				idx, found := ct.FieldIdx(n)

Review Comment:
   It's a map. The `arrow.StructType` in Go currently doesn't support multiple fields with the same name as it maintains a `map[string]int` to map field names to indices. `FieldIdx` just does a lookup in the map.
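The two steps around `FieldIdx(n)` can be modeled without Arrow. First, read a `.name` segment in which a backslash keeps the following character literally, as the doc comment describes. Then resolve the name through a `map[string]int`. The `fieldIndex` type below is a hypothetical, simplified stand-in for `arrow.StructType`, not its real implementation. It shows why a single-valued name map cannot address two fields that share a name; in this sketch the later duplicate overwrites the earlier one.

```go
package main

import (
	"fmt"
	"strings"
)

// parseName reads a field name from the front of path, stopping at an
// unescaped '.' or '['. A backslash is dropped and the next byte is kept
// literally; a trailing lone backslash is kept as-is. Writing single
// bytes keeps multi-byte UTF-8 characters intact.
func parseName(path string) (name, rest string) {
	var sb strings.Builder
	for {
		idx := strings.IndexAny(path, `\[.`)
		if idx == -1 {
			sb.WriteString(path)
			return sb.String(), ""
		}
		if path[idx] != '\\' {
			sb.WriteString(path[:idx])
			return sb.String(), path[idx:]
		}
		if idx+1 == len(path) {
			sb.WriteString(path)
			return sb.String(), ""
		}
		sb.WriteString(path[:idx])
		sb.WriteByte(path[idx+1])
		path = path[idx+2:]
	}
}

// fieldIndex is a hypothetical name -> position map. Building it from a
// slice overwrites earlier positions, so only the last duplicate is reachable.
type fieldIndex map[string]int

func newFieldIndex(names ...string) fieldIndex {
	f := make(fieldIndex, len(names))
	for i, n := range names {
		f[n] = i
	}
	return f
}

// FieldIdx is a plain map lookup.
func (f fieldIndex) FieldIdx(name string) (int, bool) {
	i, ok := f[name]
	return i, ok
}

func main() {
	name, rest := parseName(`a\.b[0]`)
	fmt.Printf("%q %q\n", name, rest) // "a.b" "[0]"

	f := newFieldIndex("x", "a.b", "x")
	fmt.Println(f.FieldIdx("a.b")) // 1 true
	fmt.Println(f.FieldIdx("x"))   // 2 true (position 0 is unreachable)
}
```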





[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1197952196


##########
go/arrow/compute/exprs/types.go:
##########
@@ -0,0 +1,745 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"fmt"
+	"hash/maphash"
+	"strconv"
+	"strings"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+const (
+	// URI for official Arrow Substrait Extension Types
+	ArrowExtTypesUri          = "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml"

Review Comment:
   good catch





[GitHub] [arrow] felipecrv commented on pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "felipecrv (via GitHub)" <gi...@apache.org>.
felipecrv commented on PR #35654:
URL: https://github.com/apache/arrow/pull/35654#issuecomment-1555033474

   > CC @benibus @felipecrv if you two feel like taking a look / reviewing on this despite your lack of familiarity with Go
   
   Yes, but this is a big one, so next week.




[GitHub] [arrow] westonpace commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "westonpace (via GitHub)" <gi...@apache.org>.
westonpace commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1220389702


##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type; for now we'll cast it
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(mem, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
+	case *expr.StructLiteral:
+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()
+	}
+
+	out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext)
+	if err == compute.ErrEmpty {
+		out = compute.NewDatum(param)
+	} else if err != nil {
+		return nil, err
+	}
+	if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) {
+		return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s",
+			arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType)
+	}
+
+	return out, nil
+}
+
+func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
+	if !exp.IsScalar() {
+		return nil, fmt.Errorf("%w: ExecuteScalarExpression cannot execute non-scalar expressions",
+			arrow.ErrInvalid)
+	}
+
+	switch e := exp.(type) {
+	case expr.Literal:
+		return literalToDatum(compute.GetAllocator(ctx), e, ext)
+	case *expr.FieldReference:
+		return execFieldRef(ctx, e, input, ext)
+	case *expr.Cast:
+		if e.Input == nil {
+			return nil, fmt.Errorf("%w: cast without argument to cast", arrow.ErrInvalid)
+		}
+
+		arg, err := executeScalarBatch(ctx, input, e.Input, ext)
+		if err != nil {
+			return nil, err
+		}
+		defer arg.Release()
+
+		dt, _, err := FromSubstraitType(e.Type, ext)
+		if err != nil {
+			return nil, fmt.Errorf("%w: could not determine type for cast", err)
+		}
+
+		var opts *compute.CastOptions
+		switch e.FailureBehavior {
+		case types.BehaviorThrowException:

Review Comment:
   Correct.
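The name-based alignment that `makeExecBatch` performs, and that the `ExecuteScalarExpression` doc comment describes, can be sketched with plain slices. For each field of the expected schema, find the input columns with the same name. No match means the field is filled with a null scalar, one match is moved into schema position, and more than one match is an ambiguity error. `alignColumns` is a hypothetical helper, `-1` stands in for "fill with nulls", and the cast applied on a type mismatch is left out.

```go
package main

import "fmt"

// alignColumns maps each expected field name to the position of the
// same-named input column: -1 if missing (caller substitutes a null
// scalar), an error if the name occurs more than once in the input.
func alignColumns(expected, input []string) ([]int, error) {
	positions := make(map[string][]int, len(input))
	for i, name := range input {
		positions[name] = append(positions[name], i)
	}

	out := make([]int, len(expected))
	for i, name := range expected {
		switch idxes := positions[name]; len(idxes) {
		case 0:
			out[i] = -1 // missing: fill with nulls
		case 1:
			out[i] = idxes[0] // present: reorder into schema position
		default:
			return nil, fmt.Errorf("field %q ambiguous, more than one match", name)
		}
	}
	return out, nil
}

func main() {
	got, err := alignColumns([]string{"a", "b", "c"}, []string{"c", "a"})
	fmt.Println(got, err) // [1 -1 0] <nil>

	_, err = alignColumns([]string{"a"}, []string{"a", "a"})
	fmt.Println(err) // field "a" ambiguous, more than one match
}
```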

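The `*types.Decimal` branch quoted above copies 16 little-endian bytes into a `decimal128.Num` through `unsafe` and reverses them on big-endian hosts. An alternative that needs neither `unsafe` nor an endianness check is to decode the two 64-bit halves with `encoding/binary`, whose `LittleEndian` reader behaves the same on every host. This sketch assumes, like the quoted code, that the literal is a 16-byte little-endian two's-complement integer. The result is the unscaled value; precision and scale are carried separately. `decodeDecimal128LE` and `toBigInt` are hypothetical helpers.

```go
package main

import (
	"encoding/binary"
	"errors"
	"fmt"
	"math/big"
)

// decodeDecimal128LE splits a 16-byte little-endian two's-complement value
// into its high (signed) and low (unsigned) 64-bit halves, independent of
// the host byte order.
func decodeDecimal128LE(b []byte) (hi int64, lo uint64, err error) {
	if len(b) != 16 {
		return 0, 0, errors.New("decimal literal must be 16 bytes")
	}
	lo = binary.LittleEndian.Uint64(b[:8])
	hi = int64(binary.LittleEndian.Uint64(b[8:]))
	return hi, lo, nil
}

// toBigInt recombines the halves as hi*2^64 + lo, for display and checking.
func toBigInt(hi int64, lo uint64) *big.Int {
	v := new(big.Int).Lsh(big.NewInt(hi), 64)
	return v.Add(v, new(big.Int).SetUint64(lo))
}

func main() {
	// 12345 (0x3039) as a little-endian 128-bit integer.
	b := make([]byte, 16)
	b[0], b[1] = 0x39, 0x30
	hi, lo, _ := decodeDecimal128LE(b)
	fmt.Println(hi, lo, toBigInt(hi, lo)) // 0 12345 12345

	// -1 is all 0xFF bytes in two's complement.
	for i := range b {
		b[i] = 0xFF
	}
	hi, lo, _ = decodeDecimal128LE(b)
	fmt.Println(toBigInt(hi, lo)) // -1
}
```

In Arrow Go the halves could then be passed to `decimal128.New(hi, lo)`. The `math/big` conversion is only there to make the result easy to verify.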




[GitHub] [arrow] westonpace commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "westonpace (via GitHub)" <gi...@apache.org>.
westonpace commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1197968472


##########
go/arrow/compute/arithmetic.go:
##########
@@ -627,6 +627,8 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
 	}{
 		{"sub_unchecked", kernels.OpSub, decPromoteAdd, subUncheckedDoc},
 		{"sub", kernels.OpSubChecked, decPromoteAdd, subDoc},
+		{"subtract_unchecked", kernels.OpSub, decPromoteAdd, subUncheckedDoc},
+		{"subtract", kernels.OpSubChecked, decPromoteAdd, subDoc},

Review Comment:
   Why is this alias needed?



##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(memory.DefaultAllocator, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
+	case *expr.StructLiteral:
+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented

Review Comment:
   Return an error here?



##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()
+	}
+
+	out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext)
+	if err == compute.ErrEmpty {
+		out = compute.NewDatum(param)
+	} else if err != nil {
+		return nil, err
+	}
+	if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) {
+		return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s",
+			arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType)
+	}
+
+	return out, nil
+}
+
+func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
+	if !exp.IsScalar() {
+		return nil, fmt.Errorf("%w: ExecuteScalarExpression cannot execute non-scalar expressions",
+			arrow.ErrInvalid)
+	}
+
+	switch e := exp.(type) {
+	case expr.Literal:
+		return literalToDatum(compute.GetAllocator(ctx), e, ext)
+	case *expr.FieldReference:
+		return execFieldRef(ctx, e, input, ext)
+	case *expr.Cast:
+		if e.Input == nil {
+			return nil, fmt.Errorf("%w: cast without argument to cast", arrow.ErrInvalid)
+		}
+
+		arg, err := executeScalarBatch(ctx, input, e.Input, ext)
+		if err != nil {
+			return nil, err
+		}
+		defer arg.Release()
+
+		dt, _, err := FromSubstraitType(e.Type, ext)
+		if err != nil {
+			return nil, fmt.Errorf("%w: could not determine type for cast", err)
+		}
+
+		var opts *compute.CastOptions
+		switch e.FailureBehavior {
+		case types.BehaviorThrowException:

Review Comment:
   Substrait desperately needs clarification here (interested in a PR?). However, you will probably be better off interpreting Substrait's cast behavior the same way as SQL's cast behavior, which is "always unsafe with controls to handle situations where a cast is impossible".
   
   In other words, casting 3.7 to int will always give you 3, regardless of the failure behavior. Casting NaN or Infinity to int will either give you null or throw an exception, depending on the failure behavior.



##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(mem, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
+	case *expr.StructLiteral:
+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))

Review Comment:
   I didn't think the struct literal had names?  Or maybe you're making empty strings for names here?
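   Because a Substrait struct literal is purely positional, converting one means walking its child values in order and choosing names on the Arrow side. Note that in the hunk as quoted, `fields` is allocated but never filled from `v.Value`. The sketch below uses hypothetical stand-in types (`literal`, `structLiteral`, `structScalar`) rather than the substrait-go or Arrow scalar APIs, and the positional `f0`, `f1`, ... naming is an assumption; empty strings would be the other option.

```go
package main

import "fmt"

// literal is a hypothetical stand-in for a Substrait literal.
type literal interface{}

// structLiteral models a Substrait struct literal: child values only, no names.
type structLiteral struct {
	Fields []literal
}

// structScalar models an Arrow struct scalar, which needs one name per child.
type structScalar struct {
	Names  []string
	Values []any
}

// convertLiteral recursively converts a literal. For struct literals it
// converts every child in order and synthesizes a positional name for it,
// since the literal itself carries none.
func convertLiteral(lit literal) (any, error) {
	switch v := lit.(type) {
	case structLiteral:
		out := structScalar{
			Names:  make([]string, len(v.Fields)),
			Values: make([]any, len(v.Fields)),
		}
		for i, child := range v.Fields {
			val, err := convertLiteral(child)
			if err != nil {
				return nil, err
			}
			out.Names[i] = fmt.Sprintf("f%d", i) // assumption: positional names
			out.Values[i] = val
		}
		return out, nil
	case int32, string, bool:
		return v, nil
	default:
		return nil, fmt.Errorf("unsupported literal %T", lit)
	}
}

func main() {
	out, err := convertLiteral(structLiteral{Fields: []literal{int32(7), "x"}})
	if err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", out)
}
```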



##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(mem)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(mem)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()
+	}
+
+	out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext)
+	if err == compute.ErrEmpty {
+		out = compute.NewDatum(param)
+	} else if err != nil {
+		return nil, err
+	}
+	if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) {
+		return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s",
+			arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType)
+	}
+
+	return out, nil
+}
+
+func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
+	if !exp.IsScalar() {
+		return nil, fmt.Errorf("%w: ExecuteScalarExpression cannot execute non-scalar expressions",
+			arrow.ErrInvalid)
+	}
+
+	switch e := exp.(type) {
+	case expr.Literal:
+		return literalToDatum(compute.GetAllocator(ctx), e, ext)
+	case *expr.FieldReference:
+		return execFieldRef(ctx, e, input, ext)
+	case *expr.Cast:
+		if e.Input == nil {
+			return nil, fmt.Errorf("%w: cast without argument to cast", arrow.ErrInvalid)
+		}
+
+		arg, err := executeScalarBatch(ctx, input, e.Input, ext)
+		if err != nil {
+			return nil, err
+		}
+		defer arg.Release()
+
+		dt, _, err := FromSubstraitType(e.Type, ext)
+		if err != nil {
+			return nil, fmt.Errorf("%w: could not determine type for cast", err)
+		}
+
+		var opts *compute.CastOptions
+		switch e.FailureBehavior {
+		case types.BehaviorThrowException:
+			opts = compute.SafeCastOptions(dt)
+		case types.BehaviorUnspecified:
+			opts = compute.UnsafeCastOptions(dt)
+		case types.BehaviorReturnNil:
+			return nil, fmt.Errorf("%w: cast behavior return nil", arrow.ErrNotImplemented)
+		}
+		return compute.CastDatum(ctx, arg, opts)
+	case *expr.ScalarFunction:
+		var (
+			err       error
+			allScalar = true
+			args      = make([]compute.Datum, e.NArgs())
+			argTypes  = make([]arrow.DataType, e.NArgs())
+		)
+		for i := 0; i < e.NArgs(); i++ {
+			switch v := e.Arg(i).(type) {
+			case types.Enum:
+				args[i] = compute.NewDatum(scalar.NewStringScalar(string(v)))
+			case expr.Expression:
+				args[i], err = executeScalarBatch(ctx, input, v, ext)
+				if err != nil {
+					return nil, err
+				}
+				defer args[i].Release()
+
+				if args[i].Kind() != compute.KindScalar {
+					allScalar = false
+				}
+			default:
+				return nil, arrow.ErrNotImplemented
+			}
+
+			argTypes[i] = args[i].(compute.ArrayLikeDatum).Type()
+		}
+
+		_, conv, ok := ext.DecodeFunction(e.FuncRef())
+		if !ok {
+			return nil, arrow.ErrNotImplemented
+		}
+
+		fname, opts, err := conv(e)
+		if err != nil {
+			return nil, err
+		}
+
+		ectx := compute.GetExecCtx(ctx)
+		fn, ok := ectx.Registry.GetFunction(fname)
+		if !ok {
+			return nil, arrow.ErrInvalid
+		}
+
+		if fn.Kind() != compute.FuncScalar {
+			return nil, arrow.ErrInvalid
+		}
+
+		k, err := fn.DispatchBest(argTypes...)
+		if err != nil {
+			return nil, err
+		}
+
+		var newArgs []compute.Datum
+		// cast arguments if necessary
+		for i, arg := range args {
+			if !arrow.TypeEqual(argTypes[i], arg.(compute.ArrayLikeDatum).Type()) {
+				if newArgs == nil {
+					newArgs = make([]compute.Datum, len(args))
+					copy(newArgs, args)
+				}
+				newArgs[i], err = compute.CastDatum(ctx, arg, compute.SafeCastOptions(argTypes[i]))
+				if err != nil {
+					return nil, err
+				}
+				defer newArgs[i].Release()
+			}
+		}
+		if newArgs != nil {
+			args = newArgs
+		}

Review Comment:
   Technically this kind of implicit casting shouldn't be needed. Substrait plans should always represent type conversions as explicit casts.



##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil

Review Comment:
   What is TimestampTzTimezone?



##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(memory.DefaultAllocator, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
+	case *expr.StructLiteral:
+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()
+	}
+
+	out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext)
+	if err == compute.ErrEmpty {
+		out = compute.NewDatum(param)
+	} else if err != nil {
+		return nil, err
+	}
+	if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) {
+		return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s",
+			arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType)
+	}
+
+	return out, nil
+}
+
+func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
+	if !exp.IsScalar() {
+		return nil, fmt.Errorf("%w: ExecuteScalarExpression cannot execute non-scalar expressions",
+			arrow.ErrInvalid)
+	}
+
+	switch e := exp.(type) {
+	case expr.Literal:
+		return literalToDatum(compute.GetAllocator(ctx), e, ext)
+	case *expr.FieldReference:
+		return execFieldRef(ctx, e, input, ext)
+	case *expr.Cast:
+		if e.Input == nil {
+			return nil, fmt.Errorf("%w: cast without argument to cast", arrow.ErrInvalid)
+		}
+
+		arg, err := executeScalarBatch(ctx, input, e.Input, ext)
+		if err != nil {
+			return nil, err
+		}
+		defer arg.Release()
+
+		dt, _, err := FromSubstraitType(e.Type, ext)
+		if err != nil {
+			return nil, fmt.Errorf("%w: could not determine type for cast", err)
+		}
+
+		var opts *compute.CastOptions
+		switch e.FailureBehavior {
+		case types.BehaviorThrowException:
+			opts = compute.SafeCastOptions(dt)
+		case types.BehaviorUnspecified:

Review Comment:
   "Unspecified" means "there is no reasonable default and the producer did not pick a behavior". Any time an _UNSPECIFIED enum variant appears in a plan, that plan is invalid. We include it so that producers that forget to set a field with no default don't accidentally create valid plans.



##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(memory.DefaultAllocator, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
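+	case *expr.StructLiteral:
+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))
+		for i, l := range v.Value {
+			// convert each member literal; the struct scalar takes ownership of it
+			d, err := literalToDatum(mem, l, ext)
+			if err != nil {
+				return nil, err
+			}
+			fields[i] = d.(*compute.ScalarDatum).Value
+		}
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err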
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()
+	}
+
+	out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext)
+	if err == compute.ErrEmpty {
+		out = compute.NewDatum(param)
+	} else if err != nil {
+		return nil, err
+	}
+	if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) {
+		return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s",
+			arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType)
+	}
+
+	return out, nil
+}
+
+func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
+	if !exp.IsScalar() {
+		return nil, fmt.Errorf("%w: ExecuteScalarExpression cannot execute non-scalar expressions",
+			arrow.ErrInvalid)
+	}
+
+	switch e := exp.(type) {
+	case expr.Literal:
+		return literalToDatum(compute.GetAllocator(ctx), e, ext)
+	case *expr.FieldReference:
+		return execFieldRef(ctx, e, input, ext)
+	case *expr.Cast:
+		if e.Input == nil {
+			return nil, fmt.Errorf("%w: cast without argument to cast", arrow.ErrInvalid)
+		}
+
+		arg, err := executeScalarBatch(ctx, input, e.Input, ext)
+		if err != nil {
+			return nil, err
+		}
+		defer arg.Release()
+
+		dt, _, err := FromSubstraitType(e.Type, ext)
+		if err != nil {
+			return nil, fmt.Errorf("%w: could not determine type for cast", err)
+		}
+
+		var opts *compute.CastOptions
+		switch e.FailureBehavior {
+		case types.BehaviorThrowException:
+			opts = compute.SafeCastOptions(dt)
+		case types.BehaviorUnspecified:
+			opts = compute.UnsafeCastOptions(dt)
+		case types.BehaviorReturnNil:
+			return nil, fmt.Errorf("%w: cast behavior return nil", arrow.ErrNotImplemented)
+		}
+		return compute.CastDatum(ctx, arg, opts)
+	case *expr.ScalarFunction:
+		var (
+			err       error
+			allScalar = true
+			args      = make([]compute.Datum, e.NArgs())
+			argTypes  = make([]arrow.DataType, e.NArgs())
+		)
+		for i := 0; i < e.NArgs(); i++ {
+			switch v := e.Arg(i).(type) {
+			case types.Enum:
+				args[i] = compute.NewDatum(scalar.NewStringScalar(string(v)))
+			case expr.Expression:
+				args[i], err = executeScalarBatch(ctx, input, v, ext)
+				if err != nil {
+					return nil, err
+				}
+				defer args[i].Release()
+
+				if args[i].Kind() != compute.KindScalar {
+					allScalar = false

Review Comment:
   Why do you care about `allScalar`?
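Aside on the `*types.Decimal` case of `literalToDatum` quoted above: it copies the 16 literal bytes into a `decimal128.Num` through `unsafe` and, on big-endian hosts, reverses all 16 bytes. A single 16-byte reversal also swaps the low and high 64-bit words, so it only matches a struct that stores the high word first; whether that holds for `decimal128.Num` is worth checking. A portable alternative is to decode the two words explicitly with `encoding/binary`, which does not depend on host byte order or on the struct's memory layout. The sketch below assumes, as the length check in that case implies, that the Substrait literal is a 16-byte little-endian two's-complement integer. `decodeDecimal128LE` and `toBig` are hypothetical helpers, not part of Arrow or substrait-go.

```go
package main

import (
	"encoding/binary"
	"fmt"
	"math/big"
)

// decodeDecimal128LE (hypothetical helper) splits a 16-byte little-endian
// two's-complement integer into its high (signed) and low (unsigned) 64-bit
// words. binary.LittleEndian makes this independent of host byte order.
func decodeDecimal128LE(b []byte) (hi int64, lo uint64, err error) {
	if len(b) != 16 {
		return 0, 0, fmt.Errorf("decimal literal had %d bytes (expected 16)", len(b))
	}
	lo = binary.LittleEndian.Uint64(b[:8])
	hi = int64(binary.LittleEndian.Uint64(b[8:]))
	return hi, lo, nil
}

// toBig reassembles hi*2^64 + lo as a big.Int, which is handy for checking the split.
func toBig(hi int64, lo uint64) *big.Int {
	v := big.NewInt(hi)
	v.Lsh(v, 64)
	return v.Add(v, new(big.Int).SetUint64(lo))
}

func main() {
	// -2 as a 16-byte little-endian two's-complement value: 0xFE followed by 15 bytes of 0xFF.
	neg := make([]byte, 16)
	for i := range neg {
		neg[i] = 0xFF
	}
	neg[0] = 0xFE

	hi, lo, err := decodeDecimal128LE(neg)
	if err != nil {
		panic(err)
	}
	fmt.Println(hi, lo, toBig(hi, lo)) // -1 18446744073709551614 -2
}
```

In Arrow Go the two words could then presumably be passed to `decimal128.New(hi, lo)`; that call is an assumption here, not something taken from the PR.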



##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return out, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return out, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return out, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(memory.DefaultAllocator, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
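+	case *expr.StructLiteral:
+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))
+		for i, l := range v.Value {
+			// convert each member literal; the struct scalar takes ownership of it
+			d, err := literalToDatum(mem, l, ext)
+			if err != nil {
+				return nil, err
+			}
+			fields[i] = d.(*compute.ScalarDatum).Value
+		}
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err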
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()

Review Comment:
   Is this a while loop? A `StructField` reference is a recursively nested structure (e.g. 0.3.1) meaning "grab the 0th field, then grab the 3rd field of that, then grab the 1st field of that".



##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return out, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return out, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return out, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(memory.DefaultAllocator, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
+	case *expr.StructLiteral:
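+		fields := make([]scalar.Scalar, len(v.Value))
+		// Substrait struct literals carry no field names, so names stay empty
+		names := make([]string, len(v.Value))
+		for i, l := range v.Value {
+			lit, err := literalToDatum(mem, l, ext)
+			if err != nil {
+				return nil, err
+			}
+			fields[i] = lit.(*compute.ScalarDatum).Value
+		}
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)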
+		return compute.NewDatum(s), err
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()
+	}
+
+	out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext)
+	if err == compute.ErrEmpty {
+		out = compute.NewDatum(param)
+	} else if err != nil {
+		return nil, err
+	}
+	if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) {
+		return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s",
+			arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType)
+	}
+
+	return out, nil
+}
+
+func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
+	if !exp.IsScalar() {

Review Comment:
   What's a non-scalar expression?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1200770354


##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(memory.DefaultAllocator, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
+	case *expr.StructLiteral:
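+		fields := make([]scalar.Scalar, len(v.Value))
+		// Substrait struct literals carry no field names, so names stay empty
+		names := make([]string, len(v.Value))
+		for i, l := range v.Value {
+			lit, err := literalToDatum(mem, l, ext)
+			if err != nil {
+				return nil, err
+			}
+			fields[i] = lit.(*compute.ScalarDatum).Value
+		}
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)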
+		return compute.NewDatum(s), err
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented

Review Comment:
   It'll exit the switch and hit the `return nil, arrow.ErrNotImplemented` at the bottom of the function, which is where anything that doesn't return inside the switch ends up.





[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1197946352


##########
go/arrow/internal/flight_integration/cmd/arrow-flight-integration-server/__debug_bin:
##########


Review Comment:
   -_- I don't even know how that got checked in. I'll remove it and add it to the .gitignore.





[GitHub] [arrow] zeroshade commented on pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on PR #35654:
URL: https://github.com/apache/arrow/pull/35654#issuecomment-1553222636

   CC @benibus @felipecrv, if you two feel like taking a look at or reviewing this despite your lack of familiarity with Go




[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1200789175


##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(memory.DefaultAllocator, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
+	case *expr.StructLiteral:
+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()
+	}
+
+	out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext)
+	if err == compute.ErrEmpty {
+		out = compute.NewDatum(param)
+	} else if err != nil {
+		return nil, err
+	}
+	if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) {
+		return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s",
+			arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType)
+	}
+
+	return out, nil
+}
+
+func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
+	if !exp.IsScalar() {
+		return nil, fmt.Errorf("%w: ExecuteScalarExpression cannot execute non-scalar expressions",
+			arrow.ErrInvalid)
+	}
+
+	switch e := exp.(type) {
+	case expr.Literal:
+		return literalToDatum(compute.GetAllocator(ctx), e, ext)
+	case *expr.FieldReference:
+		return execFieldRef(ctx, e, input, ext)
+	case *expr.Cast:
+		if e.Input == nil {
+			return nil, fmt.Errorf("%w: cast without argument to cast", arrow.ErrInvalid)
+		}
+
+		arg, err := executeScalarBatch(ctx, input, e.Input, ext)
+		if err != nil {
+			return nil, err
+		}
+		defer arg.Release()
+
+		dt, _, err := FromSubstraitType(e.Type, ext)
+		if err != nil {
+			return nil, fmt.Errorf("%w: could not determine type for cast", err)
+		}
+
+		var opts *compute.CastOptions
+		switch e.FailureBehavior {
+		case types.BehaviorThrowException:
+			opts = compute.SafeCastOptions(dt)
+		case types.BehaviorUnspecified:
+			opts = compute.UnsafeCastOptions(dt)
+		case types.BehaviorReturnNil:
+			return nil, fmt.Errorf("%w: cast behavior return nil", arrow.ErrNotImplemented)
+		}
+		return compute.CastDatum(ctx, arg, opts)
+	case *expr.ScalarFunction:
+		var (
+			err       error
+			allScalar = true
+			args      = make([]compute.Datum, e.NArgs())
+			argTypes  = make([]arrow.DataType, e.NArgs())
+		)
+		for i := 0; i < e.NArgs(); i++ {
+			switch v := e.Arg(i).(type) {
+			case types.Enum:
+				args[i] = compute.NewDatum(scalar.NewStringScalar(string(v)))
+			case expr.Expression:
+				args[i], err = executeScalarBatch(ctx, input, v, ext)
+				if err != nil {
+					return nil, err
+				}
+				defer args[i].Release()
+
+				if args[i].Kind() != compute.KindScalar {
+					allScalar = false
+				}
+			default:
+				return nil, arrow.ErrNotImplemented
+			}
+
+			argTypes[i] = args[i].(compute.ArrayLikeDatum).Type()
+		}
+
+		_, conv, ok := ext.DecodeFunction(e.FuncRef())
+		if !ok {
+			return nil, arrow.ErrNotImplemented
+		}
+
+		fname, opts, err := conv(e)
+		if err != nil {
+			return nil, err
+		}
+
+		ectx := compute.GetExecCtx(ctx)
+		fn, ok := ectx.Registry.GetFunction(fname)
+		if !ok {
+			return nil, arrow.ErrInvalid
+		}
+
+		if fn.Kind() != compute.FuncScalar {
+			return nil, arrow.ErrInvalid
+		}
+
+		k, err := fn.DispatchBest(argTypes...)
+		if err != nil {
+			return nil, err
+		}
+
+		var newArgs []compute.Datum
+		// cast arguments if necessary
+		for i, arg := range args {
+			if !arrow.TypeEqual(argTypes[i], arg.(compute.ArrayLikeDatum).Type()) {
+				if newArgs == nil {
+					newArgs = make([]compute.Datum, len(args))
+					copy(newArgs, args)
+				}
+				newArgs[i], err = compute.CastDatum(ctx, arg, compute.SafeCastOptions(argTypes[i]))
+				if err != nil {
+					return nil, err
+				}
+				defer newArgs[i].Release()
+			}
+		}
+		if newArgs != nil {
+			args = newArgs
+		}

Review Comment:
   Good point! This was carried over from my original work, where I was evaluating general expressions before I converted this to evaluating Substrait. I'll remove it to simplify things.

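The decimal branch of `literalToDatum` copies the 16-byte Substrait decimal payload into a `decimal128.Num` through `unsafe` and reverses the bytes on big-endian hosts. Substrait specifies this payload as a little-endian two's-complement 128-bit integer. One way to avoid depending on host byte order altogether is to decode each 8-byte half with `encoding/binary`. The following stdlib-only sketch is written under that assumption. `decodeDecimal128LE` is a hypothetical helper that returns the low and high 64-bit words. If we read the Arrow Go API correctly, those words could then be passed to `decimal128.New(hi, lo)`.

```go
package main

import (
	"encoding/binary"
	"fmt"
	"math/big"
)

// decodeDecimal128LE (hypothetical) splits a 16-byte little-endian
// two's-complement value into its low (unsigned) and high (signed) 64-bit
// words. It reads bytes explicitly, so the result is the same on any host.
func decodeDecimal128LE(b []byte) (lo uint64, hi int64, err error) {
	if len(b) != 16 {
		return 0, 0, fmt.Errorf("decimal literal had %d bytes (expected 16)", len(b))
	}
	lo = binary.LittleEndian.Uint64(b[:8])
	hi = int64(binary.LittleEndian.Uint64(b[8:]))
	return lo, hi, nil
}

// toBigInt reassembles the words as hi*2^64 + lo, to check the decoding.
func toBigInt(lo uint64, hi int64) *big.Int {
	v := big.NewInt(hi)
	v.Lsh(v, 64)
	return v.Add(v, new(big.Int).SetUint64(lo))
}

func main() {
	// -2 as a 128-bit little-endian two's-complement integer:
	// 0xFE followed by fifteen 0xFF bytes.
	raw := []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}
	lo, hi, err := decodeDecimal128LE(raw)
	if err != nil {
		panic(err)
	}
	fmt.Println(toBigInt(lo, hi)) // -2
}
```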


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1221876814


##########
go/arrow/compute/exprs/builders.go:
##########
@@ -0,0 +1,445 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+	"unicode"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+// NewDefaultExtensionSet constructs an empty extension set using the default
+// Arrow Extension registry and the default collection of substrait extensions
+// from the Substrait-go repo.
+func NewDefaultExtensionSet() ExtensionIDSet {
+	return NewExtensionSetDefault(expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection))
+}
+
+// NewScalarCall constructs a substrait ScalarFunction expression with the provided
+// options and arguments.
+//
+// The function name (fn) is looked up in the internal Arrow DefaultExtensionIDRegistry
+// to ensure it exists and to convert from the Arrow function name to the substrait
+// function name. It is then looked up using the DefaultCollection from the
+// substrait extensions module to find the declaration. If it cannot be found,
+// we try constructing the compound signature name by getting the types of the
+// arguments which were passed and appending them to the function name appropriately.
+//
+// An error is returned if the function cannot be resolved.
+func NewScalarCall(reg ExtensionIDSet, fn string, opts []*types.FunctionOption, args ...types.FuncArg) (*expr.ScalarFunction, error) {
+	conv, ok := reg.GetArrowRegistry().GetArrowToSubstrait(fn)
+	if !ok {
+		return nil, arrow.ErrNotFound
+	}
+
+	id, convOpts, err := conv(fn)
+	if err != nil {
+		return nil, err
+	}
+
+	opts = append(opts, convOpts...)
+	return expr.NewScalarFunc(reg.GetSubstraitRegistry(), id, opts, args...)
+}
+
+// NewFieldRefFromDotPath constructs a substrait reference segment from
+// a dot path and the base schema.
+//
+// dot_path = '.' name
+//
+//	| '[' digit+ ']'
+//	| dot_path+
+//
+// # Examples
+//
+// Assume root schema of {alpha: i32, beta: struct<gamma: list<i32>>, delta: map<string, i32>}
+//
+//	".alpha" => StructFieldRef(0)
+//	"[2]" => StructFieldRef(2)
+//	".beta[0]" => StructFieldRef(1, StructFieldRef(0))
+//	"[1].gamma[3]" => StructFieldRef(1, StructFieldRef(0, ListElementRef(3)))
+//	".delta.foobar" => StructFieldRef(2, MapKeyRef("foobar"))
+//
+// Note: when parsing a name, a '\' preceding any other character
+// will be dropped from the resulting name. Therefore if a name must
+// contain the characters '.', '\', '[', or ']' then they must be escaped
+// with a preceding '\'.
+func NewFieldRefFromDotPath(dotpath string, rootSchema *arrow.Schema) (expr.ReferenceSegment, error) {
+	if len(dotpath) == 0 {
+		return nil, fmt.Errorf("%w: dotpath was empty", arrow.ErrInvalid)
+	}
+
+	parseName := func() string {
+		var name string
+		for {
+			idx := strings.IndexAny(dotpath, `\[.`)
+			if idx == -1 {
+				name += dotpath
+				dotpath = ""
+				break
+			}
+
+			if dotpath[idx] != '\\' {
+				// subscript for a new field ref

Review Comment:
   On line 120, `name += dotpath[:idx] + string(dotpath[idx+1])` adds the escaped `]` to `name`, and the next line consumes it with `dotpath = dotpath[idx+2:]`. Inside `parseName` there's no requirement for a matching `]`, since you only end up with a `[` there if it was escaped with `\\`.
   
   For an index such as `[2]`, the `]` is consumed on line 199 by `dotpath = dotpath[subend+1:]`, which skips past it (`dotpath[subend]` is the closing `]`).





[GitHub] [arrow] ursabot commented on pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "ursabot (via GitHub)" <gi...@apache.org>.
ursabot commented on PR #35654:
URL: https://github.com/apache/arrow/pull/35654#issuecomment-1583799463

   Benchmark runs are scheduled for baseline = acf3cbac6bdddb8e2b60e1095f591ee84b17f00c and contender = 9be7074f85d6057a92bc008766e8f30d58b30cf7. 9be7074f85d6057a92bc008766e8f30d58b30cf7 is a master commit associated with this PR. Results will be available as each benchmark for each run completes.
   Conbench compare runs links:
   [Finished :arrow_down:0.0% :arrow_up:0.0%] [ec2-t3-xlarge-us-east-2](https://conbench.ursa.dev/compare/runs/dcd400871ec04a608f4a8093cfcfe1f7...809fd6b678e443428430112e3ac54e37/)
   [Finished :arrow_down:0.62% :arrow_up:0.03%] [test-mac-arm](https://conbench.ursa.dev/compare/runs/b9c92d51fe4141b4b04def34de8c4c37...d80881575eee41aa9634bdf49960bc96/)
   [Finished :arrow_down:0.65% :arrow_up:0.65%] [ursa-i9-9960x](https://conbench.ursa.dev/compare/runs/adc7905634f7416889f932ab6fed5c15...a4a568b61f6a4cb2a3737b893e567f6b/)
   [Finished :arrow_down:0.3% :arrow_up:0.3%] [ursa-thinkcentre-m75q](https://conbench.ursa.dev/compare/runs/af28c0a7dfde46d3a1218704a465463d...ab7d75173a944cbba3a44a093755bfca/)
   Buildkite builds:
   [Finished] [`9be7074f` ec2-t3-xlarge-us-east-2](https://buildkite.com/apache-arrow/arrow-bci-benchmark-on-ec2-t3-xlarge-us-east-2/builds/3002)
   [Finished] [`9be7074f` test-mac-arm](https://buildkite.com/apache-arrow/arrow-bci-benchmark-on-test-mac-arm/builds/3038)
   [Finished] [`9be7074f` ursa-i9-9960x](https://buildkite.com/apache-arrow/arrow-bci-benchmark-on-ursa-i9-9960x/builds/3003)
   [Finished] [`9be7074f` ursa-thinkcentre-m75q](https://buildkite.com/apache-arrow/arrow-bci-benchmark-on-ursa-thinkcentre-m75q/builds/3028)
   [Finished] [`acf3cbac` ec2-t3-xlarge-us-east-2](https://buildkite.com/apache-arrow/arrow-bci-benchmark-on-ec2-t3-xlarge-us-east-2/builds/3001)
   [Finished] [`acf3cbac` test-mac-arm](https://buildkite.com/apache-arrow/arrow-bci-benchmark-on-test-mac-arm/builds/3037)
   [Finished] [`acf3cbac` ursa-i9-9960x](https://buildkite.com/apache-arrow/arrow-bci-benchmark-on-ursa-i9-9960x/builds/3002)
   [Finished] [`acf3cbac` ursa-thinkcentre-m75q](https://buildkite.com/apache-arrow/arrow-bci-benchmark-on-ursa-thinkcentre-m75q/builds/3027)
   Supported benchmarks:
   ec2-t3-xlarge-us-east-2: Supported benchmark langs: Python, R. Runs only benchmarks with cloud = True
   test-mac-arm: Supported benchmark langs: C++, Python, R
   ursa-i9-9960x: Supported benchmark langs: Python, R, JavaScript
   ursa-thinkcentre-m75q: Supported benchmark langs: C++, Java
   




[GitHub] [arrow] github-actions[bot] commented on pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "github-actions[bot] (via GitHub)" <gi...@apache.org>.
github-actions[bot] commented on PR #35654:
URL: https://github.com/apache/arrow/pull/35654#issuecomment-1552067310

   * Closes: #35652




[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1197947503


##########
go/arrow/array/util_test.go:
##########
@@ -522,3 +522,29 @@ func TestRecordBuilderUnmarshalJSONExtraFields(t *testing.T) {
 
 	assert.Truef(t, array.RecordEqual(rec1, rec2), "expected: %s\nactual: %s", rec1, rec2)
 }
+
+func TestJSON(t *testing.T) {

Review Comment:
   Honestly? I don't know. It must have gotten caught in a `git stash pop` I did at some point. I'll take this out.





[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1200760262


##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(memory.DefaultAllocator, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
+	case *expr.StructLiteral:
+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(memory.DefaultAllocator)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()
+	}
+
+	out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext)
+	if err == compute.ErrEmpty {
+		out = compute.NewDatum(param)
+	} else if err != nil {
+		return nil, err
+	}
+	if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) {
+		return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s",
+			arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType)
+	}
+
+	return out, nil
+}
+
+func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
+	if !exp.IsScalar() {
+		return nil, fmt.Errorf("%w: ExecuteScalarExpression cannot execute non-scalar expressions",
+			arrow.ErrInvalid)
+	}
+
+	switch e := exp.(type) {
+	case expr.Literal:
+		return literalToDatum(compute.GetAllocator(ctx), e, ext)
+	case *expr.FieldReference:
+		return execFieldRef(ctx, e, input, ext)
+	case *expr.Cast:
+		if e.Input == nil {
+			return nil, fmt.Errorf("%w: cast without argument to cast", arrow.ErrInvalid)
+		}
+
+		arg, err := executeScalarBatch(ctx, input, e.Input, ext)
+		if err != nil {
+			return nil, err
+		}
+		defer arg.Release()
+
+		dt, _, err := FromSubstraitType(e.Type, ext)
+		if err != nil {
+			return nil, fmt.Errorf("%w: could not determine type for cast", err)
+		}
+
+		var opts *compute.CastOptions
+		switch e.FailureBehavior {
+		case types.BehaviorThrowException:

Review Comment:
   Ah, so as things currently stand, there's no option for checking for overflows in Substrait casting?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1200766866


##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(memory.DefaultAllocator, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
+	case *expr.StructLiteral:
+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))

Review Comment:
   Correct, it just uses empty strings for the names.
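A minimal sketch of that behavior, assuming hypothetical `literal` and `field` types in place of the substrait-go and Arrow scalar types: each child literal is converted (the `convert` helper stands in for the recursive `literalToDatum` call) and paired with an empty name, since a Substrait struct literal carries only values.

```go
package main

import "fmt"

// literal and field are hypothetical placeholders for a Substrait child
// literal and a converted scalar; they are not library types.
type literal struct{ value any }

type field struct {
	name  string
	value any
}

// convert stands in for a recursive literalToDatum call.
func convert(l literal) (any, error) {
	if l.value == nil {
		return nil, fmt.Errorf("unsupported literal")
	}
	return l.value, nil
}

// structFromLiteral converts every child and leaves each name as "",
// because the literal provides values but no field names.
func structFromLiteral(children []literal) ([]field, error) {
	out := make([]field, len(children))
	for i, c := range children {
		v, err := convert(c)
		if err != nil {
			return nil, fmt.Errorf("struct field %d: %w", i, err)
		}
		out[i] = field{name: "", value: v}
	}
	return out, nil
}

func main() {
	fields, err := structFromLiteral([]literal{{value: int32(1)}, {value: "x"}})
	fmt.Println(len(fields), err)
	for _, f := range fields {
		fmt.Printf("%q=%v\n", f.name, f.value)
	}
}
```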





[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1200758149


##########
go/arrow/compute/arithmetic.go:
##########
@@ -627,6 +627,8 @@ func RegisterScalarArithmetic(reg FunctionRegistry) {
 	}{
 		{"sub_unchecked", kernels.OpSub, decPromoteAdd, subUncheckedDoc},
 		{"sub", kernels.OpSubChecked, decPromoteAdd, subDoc},
+		{"subtract_unchecked", kernels.OpSub, decPromoteAdd, subUncheckedDoc},
+		{"subtract", kernels.OpSubChecked, decPromoteAdd, subDoc},

Review Comment:
   It just made things easier: for all of the *other* functions the names matched, so I could use a simple loop, and I didn't want a separate case only for subtract. It's not strictly needed; it just made some code cleaner.
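A sketch of why the alias helps, using a hypothetical map-based registry and `kernel` type rather than Arrow's `FunctionRegistry`: once `subtract` is registered alongside `sub`, every Substrait arithmetic name used here (`add`, `subtract`, `multiply`) resolves through the same loop with no per-function special case.

```go
package main

import "fmt"

// kernel is a hypothetical stand-in for an arithmetic kernel.
type kernel func(a, b int64) int64

// registerArithmetic registers each kernel under its name; "subtract" is an
// alias of "sub" so the Substrait spelling matches a registered name directly.
func registerArithmetic() map[string]kernel {
	sub := func(a, b int64) int64 { return a - b }
	return map[string]kernel{
		"add":      func(a, b int64) int64 { return a + b },
		"sub":      sub,
		"subtract": sub,
		"multiply": func(a, b int64) int64 { return a * b },
	}
}

// resolve maps Substrait function names to kernels in a single loop.
func resolve(reg map[string]kernel, substraitNames []string) (map[string]kernel, error) {
	out := make(map[string]kernel, len(substraitNames))
	for _, n := range substraitNames {
		k, ok := reg[n]
		if !ok {
			return nil, fmt.Errorf("no kernel registered for %q", n)
		}
		out[n] = k
	}
	return out, nil
}

func main() {
	kernels, err := resolve(registerArithmetic(), []string{"add", "subtract", "multiply"})
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(kernels["subtract"](7, 2))
}
```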





[GitHub] [arrow] westonpace commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "westonpace (via GitHub)" <gi...@apache.org>.
westonpace commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1220404234


##########
go/arrow/compute/exprs/builders.go:
##########
@@ -0,0 +1,445 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+	"unicode"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+// NewDefaultExtensionSet constructs an empty extension set using the default
+// Arrow Extension registry and the default collection of substrait extensions
+// from the Substrait-go repo.
+func NewDefaultExtensionSet() ExtensionIDSet {
+	return NewExtensionSetDefault(expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection))
+}
+
+// NewScalarCall constructs a substrait ScalarFunction expression with the provided
+// options and arguments.
+//
+// The function name (fn) is looked up in the internal Arrow DefaultExtensionIDRegistry
+// to ensure it exists and to convert from the Arrow function name to the substrait
+// function name. It is then looked up using the DefaultCollection from the
+// substrait extensions module to find the declaration. If it cannot be found,
+// we try constructing the compound signature name by getting the types of the
+// arguments which were passed and appending them to the function name appropriately.
+//
+// An error is returned if the function cannot be resolved.
+func NewScalarCall(reg ExtensionIDSet, fn string, opts []*types.FunctionOption, args ...types.FuncArg) (*expr.ScalarFunction, error) {
+	conv, ok := reg.GetArrowRegistry().GetArrowToSubstrait(fn)
+	if !ok {
+		return nil, arrow.ErrNotFound
+	}
+
+	id, convOpts, err := conv(fn)
+	if err != nil {
+		return nil, err
+	}
+
+	opts = append(opts, convOpts...)
+	return expr.NewScalarFunc(reg.GetSubstraitRegistry(), id, opts, args...)
+}
+
+// NewFieldRefFromDotPath constructs a substrait reference segment from
+// a dot path and the base schema.
+//
+// dot_path = '.' name
+//
+//	| '[' digit+ ']'
+//	| dot_path+
+//
+// # Examples
+//
+// Assume root schema of {alpha: i32, beta: struct<gamma: list<i32>>, delta: map<string, i32>}
+//
+//	".alpha" => StructFieldRef(0)
+//	"[2]" => StructFieldRef(2)
+//	".beta[0]" => StructFieldRef(1, StructFieldRef(0))
+//	"[1].gamma[3]" => StructFieldRef(1, StructFieldRef(0, ListElementRef(3)))
+//	".delta.foobar" => StructFieldRef(2, MapKeyRef("foobar"))
+//
+// Note: when parsing a name, a '\' preceding any other character
+// will be dropped from the resulting name. Therefore if a name must
+// contain the characters '.', '\', '[', or ']' then they must be escaped
+// with a preceding '\'.
+func NewFieldRefFromDotPath(dotpath string, rootSchema *arrow.Schema) (expr.ReferenceSegment, error) {
+	if len(dotpath) == 0 {
+		return nil, fmt.Errorf("%w: dotpath was empty", arrow.ErrInvalid)
+	}
+
+	parseName := func() string {
+		var name string
+		for {
+			idx := strings.IndexAny(dotpath, `\[.`)
+			if idx == -1 {
+				name += dotpath
+				dotpath = ""
+				break
+			}
+
+			if dotpath[idx] != '\\' {
+				// subscript for a new field ref
+				name += dotpath[:idx]
+				dotpath = dotpath[idx:]
+				break
+			}
+
+			if len(dotpath) == idx+1 {
+				// dotpath ends with a backslash; consume it all
+				name += dotpath
+				dotpath = ""
+				break
+			}
+
+			// append all characters before backslash, then the character which follows it
+			name += dotpath[:idx] + string(dotpath[idx+1])
+			dotpath = dotpath[idx+2:]
+		}
+		return name
+	}
+
+	var curType arrow.DataType = arrow.StructOf(rootSchema.Fields()...)
+	children := make([]expr.ReferenceSegment, 0)
+
+	for len(dotpath) > 0 {
+		subscript := dotpath[0]
+		dotpath = dotpath[1:]
+		switch subscript {
+		case '.':
+			// next element is a name
+			n := parseName()
+			switch ct := curType.(type) {
+			case *arrow.StructType:
+				idx, found := ct.FieldIdx(n)

Review Comment:
   Is `FieldIdx` a linear lookup in the type's children? It's probably not something to worry about yet, but we've had cases in datasets where users had thousands of columns and this sort of thing turned into an O(N^2) operation on thousands of values.
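To make the concern concrete: if each name is resolved by scanning the field list, resolving all N fields costs O(N^2) comparisons, whereas an index built once makes each lookup O(1) on average. Whether Arrow's `FieldIdx` scans linearly is not established by this diff; the sketch below only contrasts the two strategies on plain string slices, keeping every position for duplicate names so ambiguity is still detectable.

```go
package main

import "fmt"

// indexLinear scans the field list for name. Resolving every one of N
// names this way costs O(N^2) string comparisons in total.
func indexLinear(fields []string, name string) (int, bool) {
	for i, f := range fields {
		if f == name {
			return i, true
		}
	}
	return -1, false
}

// fieldIndex maps each name to all of its positions.
type fieldIndex map[string][]int

// newFieldIndex builds the index in one O(N) pass.
func newFieldIndex(fields []string) fieldIndex {
	idx := make(fieldIndex, len(fields))
	for i, f := range fields {
		idx[f] = append(idx[f], i)
	}
	return idx
}

// lookup succeeds only for a name that occurs exactly once.
func (fi fieldIndex) lookup(name string) (int, bool) {
	if pos := fi[name]; len(pos) == 1 {
		return pos[0], true
	}
	return -1, false
}

func main() {
	fields := make([]string, 5000)
	for i := range fields {
		fields[i] = fmt.Sprintf("col%d", i)
	}
	idx := newFieldIndex(fields)
	for _, name := range fields {
		a, _ := indexLinear(fields, name)
		b, _ := idx.lookup(name)
		if a != b {
			panic(fmt.Sprintf("mismatch for %s: %d vs %d", name, a, b))
		}
	}
	fmt.Println("both strategies agree on", len(fields), "fields")
}
```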

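The dot-path grammar documented on `NewFieldRefFromDotPath` can be tokenized without a schema. This stdlib sketch (the `segment` type and `parseDotPath` are hypothetical, not part of the PR) splits a path into name and index segments, drops a backslash while keeping the character after it, and keeps a trailing lone backslash as the original `parseName` closure does; resolving names against the schema is left out.

```go
package main

import (
	"errors"
	"fmt"
	"strconv"
	"strings"
)

// segment is one step of a parsed dot path: either a name or an index.
type segment struct {
	name    string
	index   int
	isIndex bool
}

// parseDotPath tokenizes a path per the grammar
// dot_path = '.' name | '[' digit+ ']' | dot_path+
func parseDotPath(path string) ([]segment, error) {
	if path == "" {
		return nil, errors.New("dot path was empty")
	}
	var segs []segment
	for len(path) > 0 {
		c := path[0]
		path = path[1:]
		switch c {
		case '.':
			var b strings.Builder
			for len(path) > 0 && path[0] != '.' && path[0] != '[' {
				if path[0] == '\\' && len(path) > 1 {
					path = path[1:] // drop the backslash, keep the next byte
				}
				b.WriteByte(path[0])
				path = path[1:]
			}
			segs = append(segs, segment{name: b.String()})
		case '[':
			end := strings.IndexByte(path, ']')
			if end == -1 {
				return nil, errors.New("unterminated subscript")
			}
			digits := path[:end]
			// require digit+ only (strconv.Atoi alone would accept "+1" or "-1")
			if digits == "" || strings.TrimLeft(digits, "0123456789") != "" {
				return nil, fmt.Errorf("bad subscript %q", digits)
			}
			n, err := strconv.Atoi(digits)
			if err != nil {
				return nil, err
			}
			segs = append(segs, segment{index: n, isIndex: true})
			path = path[end+1:]
		default:
			return nil, fmt.Errorf("unexpected %q in dot path", c)
		}
	}
	return segs, nil
}

func main() {
	for _, p := range []string{".alpha", "[1].gamma[3]", `.a\.b`} {
		segs, err := parseDotPath(p)
		fmt.Printf("%s => %+v %v\n", p, segs, err)
	}
}
```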

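The `NewScalarCall` doc comment mentions falling back to a compound signature name built from the argument types. Substrait compound names take the form `name:arg1_arg2`, for example `add:i32_i32`. A sketch of that construction, assuming a small hand-written subset of Substrait's type short names keyed by hypothetical type labels:

```go
package main

import (
	"fmt"
	"strings"
)

// shortNames is a small subset of Substrait type short names used in
// compound function signatures; the keys are hypothetical type labels.
var shortNames = map[string]string{
	"int8": "i8", "int16": "i16", "int32": "i32", "int64": "i64",
	"float32": "fp32", "float64": "fp64", "bool": "bool", "string": "str",
}

// compoundName builds "fn:arg1_arg2_..." from a base function name and the
// type labels of its arguments.
func compoundName(fn string, argTypes ...string) (string, error) {
	parts := make([]string, len(argTypes))
	for i, t := range argTypes {
		s, ok := shortNames[t]
		if !ok {
			return "", fmt.Errorf("no short name for type %q", t)
		}
		parts[i] = s
	}
	return fn + ":" + strings.Join(parts, "_"), nil
}

func main() {
	name, err := compoundName("add", "int32", "int32")
	fmt.Println(name, err)
}
```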

##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected type;
+					// we'll cast it for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. For now we'll cast it.
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(mem, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
+	case *expr.StructLiteral:
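+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))
+		for i, fieldLit := range v.Value {
+			// convert each member literal; names are left empty since
+			// substrait struct literals carry no field names
+			d, err := literalToDatum(mem, fieldLit, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			fields[i] = d.(*compute.ScalarDatum).Value
+		}
+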
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(mem)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(mem)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It first creates an exec batch from the input schema and the datum.
+// The datum may have missing or out-of-order columns, while the input schema
+// describes the input the expression expects. Missing fields are replaced
+// with null scalars, and out-of-order columns are re-ordered to match the
+// schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait executes the single expression contained in the
+// provided Substrait extended expression. The extended expression's base
+// schema determines the expected input schema (missing fields in the partial
+// input datum are replaced with null scalars, and columns are re-ordered if
+// necessary), and its extensions determine the ExtensionIDSet to use. You can
+// provide the extension registry to use through the context via
+// WithExtensionRegistry; otherwise the default Arrow registry is used. You can
+// provide a memory.Allocator the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()
+	}
+
+	out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext)
+	if err == compute.ErrEmpty {
+		out = compute.NewDatum(param)
+	} else if err != nil {
+		return nil, err
+	}
+	if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) {
+		return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s",
+			arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType)
+	}
+
+	return out, nil
+}
+
+func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
+	if !exp.IsScalar() {

Review Comment:
   So, in Substrait, I've only ever seen the terms "scalar", "aggregate", and "window" applied to functions, and not expressions.  I think @ianmcook once told me that all expressions are "scalar" (e.g. sum(x) takes in one thing `x` and returns one thing).
   
   That being said, I know we do have some of this terminology in Arrow-C++ (e.g. execute scalar expression) and so maybe it's completely valid.  Maybe just something to watch out for.  E.g. it's perfectly valid, in purely logical Substrait, to have something like `1 + SUM(x + 2)` as an expression but, to evaluate it in Acero, we need to break it up into three expressions (scalar `x+2`, aggregate `SUM(y)`, scalar `1 + z`).
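   A small sketch of the three-stage split described above, using plain `int64` slices instead of Arrow arrays. The function names are hypothetical, and this is not how Acero or Arrow Go compute actually schedules the work. It only shows where the row count collapses: the two scalar stages preserve their input cardinality, while the aggregate stage reduces many rows to one. SQL null semantics (for example `SUM` over zero rows yielding null) are deliberately ignored.

```go
package main

// The stages below are a hypothetical illustration, not Acero's or Arrow
// Go's API: each stage is evaluated on its own, and only the middle one
// collapses many rows into a single value.

// stageScalarAddTwo is the scalar expression y = x + 2, applied row by row.
func stageScalarAddTwo(x []int64) []int64 {
	y := make([]int64, len(x))
	for i, v := range x {
		y[i] = v + 2
	}
	return y
}

// stageAggregateSum is the aggregate z = SUM(y): many rows in, one value out.
func stageAggregateSum(y []int64) int64 {
	var z int64
	for _, v := range y {
		z += v
	}
	return z
}

// stageScalarAddOne is the scalar expression 1 + z on the aggregated value.
func stageScalarAddOne(z int64) int64 {
	return 1 + z
}

// evalOnePlusSumXPlusTwo chains the stages to evaluate 1 + SUM(x + 2).
func evalOnePlusSumXPlusTwo(x []int64) int64 {
	return stageScalarAddOne(stageAggregateSum(stageScalarAddTwo(x)))
}
```

   For `x = [1, 2, 3]` the stages produce `[3, 4, 5]`, then `12`, then `13`.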



##########
go/arrow/compute/exprs/builders.go:
##########
@@ -0,0 +1,445 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+	"unicode"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+// NewDefaultExtensionSet constructs an empty extension set using the default
+// Arrow Extension registry and the default collection of substrait extensions
+// from the Substrait-go repo.
+func NewDefaultExtensionSet() ExtensionIDSet {
+	return NewExtensionSetDefault(expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection))
+}
+
+// NewScalarCall constructs a substrait ScalarFunction expression with the provided
+// options and arguments.
+//
+// The function name (fn) is looked up in the internal Arrow DefaultExtensionIDRegistry
+// to ensure it exists and to convert from the Arrow function name to the substrait
+// function name. It is then looked up using the DefaultCollection from the
+// substrait extensions module to find the declaration. If it cannot be found,
+// we try constructing the compound signature name by getting the types of the
+// arguments which were passed and appending them to the function name appropriately.
+//
+// An error is returned if the function cannot be resolved.
+func NewScalarCall(reg ExtensionIDSet, fn string, opts []*types.FunctionOption, args ...types.FuncArg) (*expr.ScalarFunction, error) {
+	conv, ok := reg.GetArrowRegistry().GetArrowToSubstrait(fn)
+	if !ok {
+		return nil, arrow.ErrNotFound
+	}
+
+	id, convOpts, err := conv(fn)
+	if err != nil {
+		return nil, err
+	}
+
+	opts = append(opts, convOpts...)
+	return expr.NewScalarFunc(reg.GetSubstraitRegistry(), id, opts, args...)
+}
+
+// NewFieldRefFromDotPath constructs a substrait reference segment from
+// a dot path and the base schema.
+//
+// dot_path = '.' name
+//
+//	| '[' digit+ ']'
+//	| dot_path+
+//
+// # Examples
+//
+// Assume root schema of {alpha: i32, beta: struct<gamma: list<i32>>, delta: map<string, i32>}
+//
+//	".alpha" => StructFieldRef(0)
+//	"[2]" => StructFieldRef(2)
+//	".beta[0]" => StructFieldRef(1, StructFieldRef(0))
+//	"[1].gamma[3]" => StructFieldRef(1, StructFieldRef(0, ListElementRef(3)))
+//	".delta.foobar" => StructFieldRef(2, MapKeyRef("foobar"))
+//
+// Note: when parsing a name, a '\' preceding any other character
+// will be dropped from the resulting name. Therefore if a name must
+// contain the characters '.', '\', '[', or ']' then they must be escaped
+// with a preceding '\'.
+func NewFieldRefFromDotPath(dotpath string, rootSchema *arrow.Schema) (expr.ReferenceSegment, error) {
+	if len(dotpath) == 0 {
+		return nil, fmt.Errorf("%w: dotpath was empty", arrow.ErrInvalid)
+	}
+
+	parseName := func() string {
+		var name string
+		for {
+			idx := strings.IndexAny(dotpath, `\[.`)
+			if idx == -1 {
+				name += dotpath
+				dotpath = ""
+				break
+			}
+
+			if dotpath[idx] != '\\' {
+				// subscript for a new field ref

Review Comment:
   What consumes the `]`?
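   For reference, a standalone sketch of a tokenizer for the same dot-path grammar, assuming the intent is for the `[` branch to advance past the closing `]` itself (the relevant lines of the diff are not shown here). The `segment` type and `parseDotPath` function are hypothetical and use only the standard library. They mirror the escaping rule from the doc comment (a `\` keeps the next character literally), but they are not the Arrow Go implementation.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// segment is a hypothetical parsed dot-path element: a name from ".name"
// or a non-negative index from "[digits]".
type segment struct {
	name    string
	index   int
	isIndex bool
}

// parseDotPath tokenizes paths such as ".beta[0]" or `.a\.b`.
func parseDotPath(path string) ([]segment, error) {
	if path == "" {
		return nil, fmt.Errorf("dot path was empty")
	}
	var segs []segment
	for len(path) > 0 {
		sub := path[0]
		path = path[1:]
		switch sub {
		case '.':
			// read a name up to the next unescaped '.' or '['
			var b strings.Builder
			for len(path) > 0 && path[0] != '.' && path[0] != '[' {
				if path[0] == '\\' && len(path) > 1 {
					// '\' keeps the following character literally
					b.WriteByte(path[1])
					path = path[2:]
					continue
				}
				b.WriteByte(path[0])
				path = path[1:]
			}
			segs = append(segs, segment{name: b.String()})
		case '[':
			end := strings.IndexByte(path, ']')
			if end <= 0 {
				return nil, fmt.Errorf("empty or unterminated index in dot path")
			}
			digits := path[:end]
			for _, r := range digits {
				if r < '0' || r > '9' {
					return nil, fmt.Errorf("non-digit index %q in dot path", digits)
				}
			}
			idx, err := strconv.Atoi(digits)
			if err != nil {
				return nil, err // e.g. the index overflows int
			}
			segs = append(segs, segment{index: idx, isIndex: true})
			// consume the digits and the closing ']' together
			path = path[end+1:]
		default:
			return nil, fmt.Errorf("unexpected %q in dot path", sub)
		}
	}
	return segs, nil
}
```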



##########
go/arrow/compute/exprs/builders.go:
##########
@@ -0,0 +1,445 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+	"unicode"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+// NewDefaultExtensionSet constructs an empty extension set using the default
+// Arrow Extension registry and the default collection of substrait extensions
+// from the Substrait-go repo.
+func NewDefaultExtensionSet() ExtensionIDSet {
+	return NewExtensionSetDefault(expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection))
+}
+
+// NewScalarCall constructs a substrait ScalarFunction expression with the provided
+// options and arguments.
+//
+// The function name (fn) is looked up in the internal Arrow DefaultExtensionIDRegistry
+// to ensure it exists and to convert from the Arrow function name to the substrait
+// function name. It is then looked up using the DefaultCollection from the
+// substrait extensions module to find the declaration. If it cannot be found,
+// we try constructing the compound signature name by getting the types of the
+// arguments which were passed and appending them to the function name appropriately.
+//
+// An error is returned if the function cannot be resolved.
+func NewScalarCall(reg ExtensionIDSet, fn string, opts []*types.FunctionOption, args ...types.FuncArg) (*expr.ScalarFunction, error) {
+	conv, ok := reg.GetArrowRegistry().GetArrowToSubstrait(fn)
+	if !ok {
+		return nil, arrow.ErrNotFound
+	}
+
+	id, convOpts, err := conv(fn)
+	if err != nil {
+		return nil, err
+	}
+
+	opts = append(opts, convOpts...)
+	return expr.NewScalarFunc(reg.GetSubstraitRegistry(), id, opts, args...)
+}
+
+// NewFieldRefFromDotPath constructs a substrait reference segment from
+// a dot path and the base schema.
+//
+// dot_path = '.' name
+//
+//	| '[' digit+ ']'
+//	| dot_path+
+//
+// # Examples
+//
+// Assume root schema of {alpha: i32, beta: struct<gamma: list<i32>>, delta: map<string, i32>}
+//
+//	".alpha" => StructFieldRef(0)
+//	"[2]" => StructFieldRef(2)
+//	".beta[0]" => StructFieldRef(1, StructFieldRef(0))
+//	"[1].gamma[3]" => StructFieldRef(1, StructFieldRef(0, ListElementRef(3)))
+//	".delta.foobar" => StructFieldRef(2, MapKeyRef("foobar"))
+//
+// Note: when parsing a name, a '\' preceding any other character
+// will be dropped from the resulting name. Therefore if a name must
+// contain the characters '.', '\', '[', or ']' then they must be escaped
+// with a preceding '\'.
+func NewFieldRefFromDotPath(dotpath string, rootSchema *arrow.Schema) (expr.ReferenceSegment, error) {
+	if len(dotpath) == 0 {
+		return nil, fmt.Errorf("%w: dotpath was empty", arrow.ErrInvalid)
+	}
+
+	parseName := func() string {
+		var name string
+		for {
+			idx := strings.IndexAny(dotpath, `\[.`)
+			if idx == -1 {
+				name += dotpath
+				dotpath = ""
+				break
+			}
+
+			if dotpath[idx] != '\\' {
+				// subscript for a new field ref
+				name += dotpath[:idx]
+				dotpath = dotpath[idx:]
+				break
+			}
+
+			if len(dotpath) == idx+1 {
+				// dotpath ends with a backslash; consume it all
+				name += dotpath
+				dotpath = ""
+				break
+			}
+
+			// append all characters before backslash, then the character which follows it
+			name += dotpath[:idx] + string(dotpath[idx+1])
+			dotpath = dotpath[idx+2:]
+		}
+		return name
+	}
+
+	var curType arrow.DataType = arrow.StructOf(rootSchema.Fields()...)
+	children := make([]expr.ReferenceSegment, 0)
+
+	for len(dotpath) > 0 {
+		subscript := dotpath[0]
+		dotpath = dotpath[1:]
+		switch subscript {
+		case '.':
+			// next element is a name
+			n := parseName()
+			switch ct := curType.(type) {
+			case *arrow.StructType:
+				idx, found := ct.FieldIdx(n)
+				if !found {
+					return nil, fmt.Errorf("%w: dot path referenced invalid field '%s'", arrow.ErrInvalid, n)
+				}
+				children = append(children, &expr.StructFieldRef{Field: int32(idx)})
+				curType = ct.Field(idx).Type
+			case *arrow.MapType:
+				// a map key reference yields a value of the map's item type
+				curType = ct.ItemType()
+				switch ct.KeyType().ID() {
+				case arrow.BINARY, arrow.LARGE_BINARY:
+					children = append(children, &expr.MapKeyRef{MapKey: expr.NewByteSliceLiteral([]byte(n), false)})
+				case arrow.STRING, arrow.LARGE_STRING:
+					children = append(children, &expr.MapKeyRef{MapKey: expr.NewPrimitiveLiteral(n, false)})
+				default:
+					return nil, fmt.Errorf("%w: MapKeyRef to non-binary/string map not supported", arrow.ErrNotImplemented)
+				}
+			default:
+				return nil, fmt.Errorf("%w: dot path names must refer to struct fields or map keys", arrow.ErrInvalid)
+			}
+		case '[':
+			subend := strings.IndexFunc(dotpath, func(r rune) bool { return !unicode.IsDigit(r) })
+			if subend <= 0 || dotpath[subend] != ']' {
+				return nil, fmt.Errorf("%w: dot path '%s' contained an empty or unterminated index", arrow.ErrInvalid, dotpath)
+			}
+			idx, _ := strconv.Atoi(dotpath[:subend])
+			switch ct := curType.(type) {
+			case *arrow.StructType:
+				if idx >= len(ct.Fields()) {
+					return nil, fmt.Errorf("%w: field out of bounds in dotpath", arrow.ErrIndex)
+				}
+				curType = ct.Field(idx).Type
+				children = append(children, &expr.StructFieldRef{Field: int32(idx)})
+			case *arrow.MapType:
+				// a map key reference yields a value of the map's item type
+				curType = ct.ItemType()
+				var keyLiteral expr.Literal
+				// TODO: implement user defined types and variations
+				switch ct.KeyType().ID() {
+				case arrow.INT8:
+					keyLiteral = expr.NewPrimitiveLiteral(int8(idx), false)
+				case arrow.INT16:
+					keyLiteral = expr.NewPrimitiveLiteral(int16(idx), false)
+				case arrow.INT32:
+					keyLiteral = expr.NewPrimitiveLiteral(int32(idx), false)
+				case arrow.INT64:
+					keyLiteral = expr.NewPrimitiveLiteral(int64(idx), false)
+				case arrow.FLOAT32:
+					keyLiteral = expr.NewPrimitiveLiteral(float32(idx), false)
+				case arrow.FLOAT64:
+					keyLiteral = expr.NewPrimitiveLiteral(float64(idx), false)

Review Comment:
   Floats as map keys seem pretty unlikely. Also, since you are using `atoi` above, I think you're requiring the keys to be integral. Wouldn't it be better to just call this case unsupported?
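   A sketch of the suggested behavior, assuming the subscript has already been parsed as an integer: integer key types receive a range-checked key, and float key types are rejected as unsupported rather than given a converted integer. `keyKind` is a hypothetical stand-in for `arrow.Type` IDs so that the example stays self-contained, and `mapKeyFromIndex` is likewise illustrative only.

```go
package main

import (
	"errors"
	"fmt"
	"math"
)

// keyKind is a hypothetical stand-in for the arrow.Type ID of a map key type.
type keyKind int

const (
	kindInt8 keyKind = iota
	kindInt16
	kindInt32
	kindInt64
	kindFloat32
	kindFloat64
)

// errUnsupportedKey marks map key types that "[n]" subscripts cannot address.
var errUnsupportedKey = errors.New("map key type not supported for [n] subscripts")

// mapKeyFromIndex converts a parsed "[n]" subscript into a key of the map's
// key type. Integer keys are range-checked; float keys are rejected.
func mapKeyFromIndex(kind keyKind, idx int64) (any, error) {
	// inRange reports an error if idx does not fit in [lo, hi].
	inRange := func(lo, hi int64) error {
		if idx < lo || idx > hi {
			return fmt.Errorf("index %d overflows the map key type", idx)
		}
		return nil
	}

	switch kind {
	case kindInt8:
		if err := inRange(math.MinInt8, math.MaxInt8); err != nil {
			return nil, err
		}
		return int8(idx), nil
	case kindInt16:
		if err := inRange(math.MinInt16, math.MaxInt16); err != nil {
			return nil, err
		}
		return int16(idx), nil
	case kindInt32:
		if err := inRange(math.MinInt32, math.MaxInt32); err != nil {
			return nil, err
		}
		return int32(idx), nil
	case kindInt64:
		return idx, nil
	default:
		// float (and any other) key types are reported as unsupported
		return nil, fmt.Errorf("%w (key kind %d)", errUnsupportedKey, kind)
	}
}
```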



##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected type;
+					// we'll cast it for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. For now we'll cast it.
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(mem, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
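+	case *expr.StructLiteral:
+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))
+		// convert each member literal; substrait struct literals carry no
+		// field names, so names are left empty. The struct scalar takes
+		// ownership of the member scalars, so they are not released here.
+		for i, l := range v.Value {
+			d, err := literalToDatum(mem, l, ext)
+			if err != nil {
+				return nil, err
+			}
+			fields[i] = d.(*compute.ScalarDatum).Value
+		}
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err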
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(mem)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(mem)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()
+	}
+
+	out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext)
+	if err == compute.ErrEmpty {
+		out = compute.NewDatum(param)
+	} else if err != nil {
+		return nil, err
+	}
+	if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) {
+		return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s",
+			arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType)
+	}
+
+	return out, nil
+}
+
+func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
+	if !exp.IsScalar() {
+		return nil, fmt.Errorf("%w: ExecuteScalarExpression cannot execute non-scalar expressions",
+			arrow.ErrInvalid)
+	}
+
+	switch e := exp.(type) {
+	case expr.Literal:
+		return literalToDatum(compute.GetAllocator(ctx), e, ext)
+	case *expr.FieldReference:
+		return execFieldRef(ctx, e, input, ext)
+	case *expr.Cast:
+		if e.Input == nil {
+			return nil, fmt.Errorf("%w: cast without argument to cast", arrow.ErrInvalid)
+		}
+
+		arg, err := executeScalarBatch(ctx, input, e.Input, ext)
+		if err != nil {
+			return nil, err
+		}
+		defer arg.Release()
+
+		dt, _, err := FromSubstraitType(e.Type, ext)
+		if err != nil {
+			return nil, fmt.Errorf("%w: could not determine type for cast", err)
+		}
+
+		var opts *compute.CastOptions
+		switch e.FailureBehavior {
+		case types.BehaviorThrowException:
+			opts = compute.SafeCastOptions(dt)
+		case types.BehaviorUnspecified:
+			opts = compute.UnsafeCastOptions(dt)
+		case types.BehaviorReturnNil:
+			return nil, fmt.Errorf("%w: cast behavior return nil", arrow.ErrNotImplemented)
+		}
+		return compute.CastDatum(ctx, arg, opts)
+	case *expr.ScalarFunction:
+		var (
+			err       error
+			allScalar = true
+			args      = make([]compute.Datum, e.NArgs())
+			argTypes  = make([]arrow.DataType, e.NArgs())
+		)
+		for i := 0; i < e.NArgs(); i++ {
+			switch v := e.Arg(i).(type) {
+			case types.Enum:
+				args[i] = compute.NewDatum(scalar.NewStringScalar(string(v)))
+			case expr.Expression:
+				args[i], err = executeScalarBatch(ctx, input, v, ext)
+				if err != nil {
+					return nil, err
+				}
+				defer args[i].Release()
+
+				if args[i].Kind() != compute.KindScalar {
+					allScalar = false

Review Comment:
   Hmm, I'm not certain this would be a valid function. I'm not sure how you would apply it in a scalar context. For example, what does `SELECT all_enum_func(FOO, BAR), l_quantity FROM line_item;` return?
   
   However, given it isn't a valid function, I'm not sure it's worth quibbling over too much.
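   The question above turns on what happens when every argument is a scalar, such as two enum options, while the same query also projects a per-row column like `l_quantity`. Below is a minimal sketch of one plausible answer: evaluate the function once, then broadcast the scalar result to the batch length. It uses hypothetical `Datum`, `allEnumFunc`, and `broadcast` names, not the Arrow Go compute API, and is not what this PR implements.

```go
package main

import (
	"fmt"
	"strings"
)

// Datum is a hypothetical stand-in for compute.Datum: either a single
// scalar value or a column (array) of values.
type Datum struct {
	Scalar *string  // set when the datum is a scalar
	Array  []string // set when the datum is an array
}

func (d Datum) IsScalar() bool { return d.Scalar != nil }

// allEnumFunc is a hypothetical function whose arguments are all enum
// options; with no per-row input it can only produce one scalar result.
func allEnumFunc(opts ...string) Datum {
	s := strings.Join(opts, "|")
	return Datum{Scalar: &s}
}

// broadcast expands a scalar datum to n rows so it can sit next to
// per-row columns in the same output batch. Arrays pass through after a
// length check.
func broadcast(d Datum, n int) ([]string, error) {
	if !d.IsScalar() {
		if len(d.Array) != n {
			return nil, fmt.Errorf("array length %d does not match batch length %d", len(d.Array), n)
		}
		return d.Array, nil
	}
	out := make([]string, n)
	for i := range out {
		out[i] = *d.Scalar
	}
	return out, nil
}
```

   Under this model the query returns the same value on every row next to each `l_quantity`. That is also why tracking `allScalar` in the quoted `ScalarFunction` branch matters.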



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [arrow] zeroshade commented on a diff in pull request #35654: GH-35652: [Go][Compute] Allow executing Substrait Expressions using Go Compute

Posted by "zeroshade (via GitHub)" <gi...@apache.org>.
zeroshade commented on code in PR #35654:
URL: https://github.com/apache/arrow/pull/35654#discussion_r1221793810


##########
go/arrow/compute/exprs/exec.go:
##########
@@ -0,0 +1,631 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+
+package exprs
+
+import (
+	"context"
+	"fmt"
+	"unsafe"
+
+	"github.com/apache/arrow/go/v13/arrow"
+	"github.com/apache/arrow/go/v13/arrow/array"
+	"github.com/apache/arrow/go/v13/arrow/compute"
+	"github.com/apache/arrow/go/v13/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v13/arrow/decimal128"
+	"github.com/apache/arrow/go/v13/arrow/endian"
+	"github.com/apache/arrow/go/v13/arrow/internal/debug"
+	"github.com/apache/arrow/go/v13/arrow/memory"
+	"github.com/apache/arrow/go/v13/arrow/scalar"
+	"github.com/substrait-io/substrait-go/expr"
+	"github.com/substrait-io/substrait-go/extensions"
+	"github.com/substrait-io/substrait-go/types"
+)
+
+func makeExecBatch(ctx context.Context, schema *arrow.Schema, partial compute.Datum) (out compute.ExecBatch, err error) {
+	// cleanup if we get an error
+	defer func() {
+		if err != nil {
+			for _, v := range out.Values {
+				if v != nil {
+					v.Release()
+				}
+			}
+		}
+	}()
+
+	if partial.Kind() == compute.KindRecord {
+		partialBatch := partial.(*compute.RecordDatum).Value
+		batchSchema := partialBatch.Schema()
+
+		out.Values = make([]compute.Datum, len(schema.Fields()))
+		out.Len = partialBatch.NumRows()
+
+		for i, field := range schema.Fields() {
+			idxes := batchSchema.FieldIndices(field.Name)
+			switch len(idxes) {
+			case 0:
+				out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+			case 1:
+				col := partialBatch.Column(idxes[0])
+				if !arrow.TypeEqual(col.DataType(), field.Type) {
+					// referenced field was present but didn't have expected type
+					// we'll cast this case for now
+					col, err = compute.CastArray(ctx, col, compute.SafeCastOptions(field.Type))
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+					defer col.Release()
+				}
+				out.Values[i] = compute.NewDatum(col)
+			default:
+				err = fmt.Errorf("%w: exec batch field '%s' ambiguous, more than one match",
+					arrow.ErrInvalid, field.Name)
+				return compute.ExecBatch{}, err
+			}
+		}
+		return
+	}
+
+	part, ok := partial.(compute.ArrayLikeDatum)
+	if !ok {
+		return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+	}
+
+	// wasteful but useful for testing
+	if part.Type().ID() == arrow.STRUCT {
+		switch part := part.(type) {
+		case *compute.ArrayDatum:
+			arr := part.MakeArray().(*array.Struct)
+			defer arr.Release()
+
+			batch := array.RecordFromStructArray(arr, nil)
+			defer batch.Release()
+			return makeExecBatch(ctx, schema, compute.NewDatumWithoutOwning(batch))
+		case *compute.ScalarDatum:
+			out.Len = 1
+			out.Values = make([]compute.Datum, len(schema.Fields()))
+
+			s := part.Value.(*scalar.Struct)
+			dt := s.Type.(*arrow.StructType)
+
+			for i, field := range schema.Fields() {
+				idx, found := dt.FieldIdx(field.Name)
+				if !found {
+					out.Values[i] = compute.NewDatum(scalar.MakeNullScalar(field.Type))
+					continue
+				}
+
+				val := s.Value[idx]
+				if !arrow.TypeEqual(val.DataType(), field.Type) {
+					// referenced field was present but didn't have the expected
+					// type. for now we'll cast this
+					val, err = val.CastTo(field.Type)
+					if err != nil {
+						return compute.ExecBatch{}, err
+					}
+				}
+				out.Values[i] = compute.NewDatum(val)
+			}
+			return
+		}
+	}
+
+	return out, fmt.Errorf("%w: MakeExecBatch from %s", arrow.ErrNotImplemented, partial)
+}
+
+// ToArrowSchema takes a substrait NamedStruct and an extension set (for
+// type resolution mapping) and creates the equivalent Arrow Schema.
+func ToArrowSchema(base types.NamedStruct, ext ExtensionIDSet) (*arrow.Schema, error) {
+	fields := make([]arrow.Field, len(base.Names))
+	for i, typ := range base.Struct.Types {
+		dt, nullable, err := FromSubstraitType(typ, ext)
+		if err != nil {
+			return nil, err
+		}
+		fields[i] = arrow.Field{
+			Name:     base.Names[i],
+			Type:     dt,
+			Nullable: nullable,
+		}
+	}
+
+	return arrow.NewSchema(fields, nil), nil
+}
+
+type (
+	regCtxKey struct{}
+	extCtxKey struct{}
+)
+
+func WithExtensionRegistry(ctx context.Context, reg *ExtensionIDRegistry) context.Context {
+	return context.WithValue(ctx, regCtxKey{}, reg)
+}
+
+func GetExtensionRegistry(ctx context.Context) *ExtensionIDRegistry {
+	v, ok := ctx.Value(regCtxKey{}).(*ExtensionIDRegistry)
+	if !ok {
+		v = DefaultExtensionIDRegistry
+	}
+	return v
+}
+
+func WithExtensionIDSet(ctx context.Context, ext ExtensionIDSet) context.Context {
+	return context.WithValue(ctx, extCtxKey{}, ext)
+}
+
+func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
+	v, ok := ctx.Value(extCtxKey{}).(ExtensionIDSet)
+	if !ok {
+		return NewExtensionSet(
+			expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection),
+			GetExtensionRegistry(ctx))
+	}
+	return v
+}
+
+func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
+	switch v := lit.(type) {
+	case *expr.PrimitiveLiteral[bool]:
+		return compute.NewDatum(scalar.NewBooleanScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int8]:
+		return compute.NewDatum(scalar.NewInt8Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int16]:
+		return compute.NewDatum(scalar.NewInt16Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int32]:
+		return compute.NewDatum(scalar.NewInt32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[int64]:
+		return compute.NewDatum(scalar.NewInt64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float32]:
+		return compute.NewDatum(scalar.NewFloat32Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[float64]:
+		return compute.NewDatum(scalar.NewFloat64Scalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[string]:
+		return compute.NewDatum(scalar.NewStringScalar(v.Value)), nil
+	case *expr.PrimitiveLiteral[types.Timestamp]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value), &arrow.TimestampType{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.TimestampTz]:
+		return compute.NewDatum(scalar.NewTimestampScalar(arrow.Timestamp(v.Value),
+			&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: TimestampTzTimezone})), nil
+	case *expr.PrimitiveLiteral[types.Date]:
+		return compute.NewDatum(scalar.NewDate32Scalar(arrow.Date32(v.Value))), nil
+	case *expr.PrimitiveLiteral[types.Time]:
+		return compute.NewDatum(scalar.NewTime64Scalar(arrow.Time64(v.Value), &arrow.Time64Type{Unit: arrow.Microsecond})), nil
+	case *expr.PrimitiveLiteral[types.FixedChar]:
+		length := int(v.Type.(*types.FixedCharType).Length)
+		return compute.NewDatum(scalar.NewExtensionScalar(
+			scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes([]byte(v.Value)),
+				&arrow.FixedSizeBinaryType{ByteWidth: length}), fixedChar(int32(length)))), nil
+	case *expr.ByteSliceLiteral[[]byte]:
+		return compute.NewDatum(scalar.NewBinaryScalar(memory.NewBufferBytes(v.Value), arrow.BinaryTypes.Binary)), nil
+	case *expr.ByteSliceLiteral[types.UUID]:
+		return compute.NewDatum(scalar.NewExtensionScalar(scalar.NewFixedSizeBinaryScalar(
+			memory.NewBufferBytes(v.Value), uuid().(arrow.ExtensionType).StorageType()), uuid())), nil
+	case *expr.ByteSliceLiteral[types.FixedBinary]:
+		return compute.NewDatum(scalar.NewFixedSizeBinaryScalar(memory.NewBufferBytes(v.Value),
+			&arrow.FixedSizeBinaryType{ByteWidth: int(v.Type.(*types.FixedBinaryType).Length)})), nil
+	case *expr.NullLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+		return compute.NewDatum(scalar.MakeNullScalar(dt)), nil
+	case *expr.ListLiteral:
+		var elemType arrow.DataType
+
+		values := make([]scalar.Scalar, len(v.Value))
+		for i, val := range v.Value {
+			d, err := literalToDatum(mem, val, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer d.Release()
+			values[i] = d.(*compute.ScalarDatum).Value
+			if elemType != nil {
+				if !arrow.TypeEqual(values[i].DataType(), elemType) {
+					return nil, fmt.Errorf("%w: %s has a value whose type doesn't match the other list values",
+						arrow.ErrInvalid, v)
+				}
+			} else {
+				elemType = values[i].DataType()
+			}
+		}
+
+		bldr := array.NewBuilder(mem, elemType)
+		defer bldr.Release()
+		if err := scalar.AppendSlice(bldr, values); err != nil {
+			return nil, err
+		}
+		arr := bldr.NewArray()
+		defer arr.Release()
+		return compute.NewDatum(scalar.NewListScalar(arr)), nil
+	case *expr.MapLiteral:
+		dt, _, err := FromSubstraitType(v.Type, ext)
+		if err != nil {
+			return nil, err
+		}
+
+		mapType, ok := dt.(*arrow.MapType)
+		if !ok {
+			return nil, fmt.Errorf("%w: map literal with non-map type", arrow.ErrInvalid)
+		}
+
+		keys, values := make([]scalar.Scalar, len(v.Value)), make([]scalar.Scalar, len(v.Value))
+		for i, kv := range v.Value {
+			k, err := literalToDatum(mem, kv.Key, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer k.Release()
+			scalarKey := k.(*compute.ScalarDatum).Value
+
+			v, err := literalToDatum(mem, kv.Value, ext)
+			if err != nil {
+				return nil, err
+			}
+			defer v.Release()
+			scalarValue := v.(*compute.ScalarDatum).Value
+
+			if !arrow.TypeEqual(mapType.KeyType(), scalarKey.DataType()) {
+				return nil, fmt.Errorf("%w: key type mismatch for %s, got key with type %s",
+					arrow.ErrInvalid, mapType, scalarKey.DataType())
+			}
+			if !arrow.TypeEqual(mapType.ValueType(), scalarValue.DataType()) {
+				return nil, fmt.Errorf("%w: value type mismatch for %s, got value with type %s",
+					arrow.ErrInvalid, mapType, scalarValue.DataType())
+			}
+
+			keys[i], values[i] = scalarKey, scalarValue
+		}
+
+		keyBldr, valBldr := array.NewBuilder(mem, mapType.KeyType()), array.NewBuilder(mem, mapType.ValueType())
+		defer keyBldr.Release()
+		defer valBldr.Release()
+
+		if err := scalar.AppendSlice(keyBldr, keys); err != nil {
+			return nil, err
+		}
+		if err := scalar.AppendSlice(valBldr, values); err != nil {
+			return nil, err
+		}
+
+		keyArr, valArr := keyBldr.NewArray(), valBldr.NewArray()
+		defer keyArr.Release()
+		defer valArr.Release()
+
+		kvArr, err := array.NewStructArray([]arrow.Array{keyArr, valArr}, []string{"key", "value"})
+		if err != nil {
+			return nil, err
+		}
+		defer kvArr.Release()
+
+		return compute.NewDatumWithoutOwning(scalar.NewMapScalar(kvArr)), nil
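+	case *expr.StructLiteral:
+		fields := make([]scalar.Scalar, len(v.Value))
+		names := make([]string, len(v.Value))
+		// convert each member literal; substrait struct literals carry no
+		// field names, so names are left empty. The struct scalar takes
+		// ownership of the member scalars, so they are not released here.
+		for i, l := range v.Value {
+			d, err := literalToDatum(mem, l, ext)
+			if err != nil {
+				return nil, err
+			}
+			fields[i] = d.(*compute.ScalarDatum).Value
+		}
+
+		s, err := scalar.NewStructScalarWithNames(fields, names)
+		return compute.NewDatum(s), err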
+	case *expr.ProtoLiteral:
+		switch v := v.Value.(type) {
+		case *types.Decimal:
+			if len(v.Value) != arrow.Decimal128SizeBytes {
+				return nil, fmt.Errorf("%w: decimal literal had %d bytes (expected %d)",
+					arrow.ErrInvalid, len(v.Value), arrow.Decimal128SizeBytes)
+			}
+
+			var val decimal128.Num
+			data := (*(*[arrow.Decimal128SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+			copy(data, v.Value)
+			if endian.IsBigEndian {
+				// reverse the bytes
+				for i := len(data)/2 - 1; i >= 0; i-- {
+					opp := len(data) - 1 - i
+					data[i], data[opp] = data[opp], data[i]
+				}
+			}
+
+			return compute.NewDatum(scalar.NewDecimal128Scalar(val,
+				&arrow.Decimal128Type{Precision: v.Precision, Scale: v.Scale})), nil
+		case *types.UserDefinedLiteral: // not yet implemented
+		case *types.IntervalYearToMonth:
+			bldr := array.NewInt32Builder(mem)
+			defer bldr.Release()
+			typ := intervalYear()
+			bldr.Append(v.Years)
+			bldr.Append(v.Months)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.IntervalDayToSecond:
+			bldr := array.NewInt32Builder(mem)
+			defer bldr.Release()
+			typ := intervalDay()
+			bldr.Append(v.Days)
+			bldr.Append(v.Seconds)
+			arr := bldr.NewArray()
+			defer arr.Release()
+			return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
+				scalar.NewFixedSizeListScalar(arr), typ)}, nil
+		case *types.VarChar:
+			return compute.NewDatum(scalar.NewExtensionScalar(
+				scalar.NewStringScalar(v.Value), varChar(int32(v.Length)))), nil
+		}
+	}
+
+	return nil, arrow.ErrNotImplemented
+}
+
+// ExecuteScalarExpression executes the given substrait expression using the provided datum as input.
+// It will first create an exec batch using the input schema and the datum.
+// The datum may have missing or incorrectly ordered columns while the input schema
+// should describe the expected input schema for the expression. Missing fields will
+// be replaced with null scalars and incorrectly ordered columns will be re-ordered
+// according to the schema.
+//
+// You can provide an allocator to use through the context via compute.WithAllocator.
+//
+// You can provide the ExtensionIDSet to use through the context via WithExtensionIDSet.
+func ExecuteScalarExpression(ctx context.Context, inputSchema *arrow.Schema, expression expr.Expression, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	batch, err := makeExecBatch(ctx, inputSchema, partialInput)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		for _, v := range batch.Values {
+			v.Release()
+		}
+	}()
+
+	return executeScalarBatch(ctx, batch, expression, GetExtensionIDSet(ctx))
+}
+
+// ExecuteScalarSubstrait uses the provided Substrait extended expression to
+// determine the expected input schema (replacing missing fields in the partial
+// input datum with null scalars and re-ordering columns if necessary) and
+// ExtensionIDSet to use. You can provide the extension registry to use
+// through the context via WithExtensionRegistry, otherwise the default
+// Arrow registry will be used. You can provide a memory.Allocator to use
+// the same way via compute.WithAllocator.
+func ExecuteScalarSubstrait(ctx context.Context, expression *expr.Extended, partialInput compute.Datum) (compute.Datum, error) {
+	if expression == nil {
+		return nil, arrow.ErrInvalid
+	}
+
+	var toExecute expr.Expression
+
+	switch len(expression.ReferredExpr) {
+	case 0:
+		return nil, fmt.Errorf("%w: no referred expression to execute", arrow.ErrInvalid)
+	case 1:
+		if toExecute = expression.ReferredExpr[0].GetExpr(); toExecute == nil {
+			return nil, fmt.Errorf("%w: measures not implemented", arrow.ErrNotImplemented)
+		}
+	default:
+		return nil, fmt.Errorf("%w: only single referred expression implemented", arrow.ErrNotImplemented)
+	}
+
+	reg := GetExtensionRegistry(ctx)
+	set := NewExtensionSet(expr.NewExtensionRegistry(expression.Extensions, &extensions.DefaultCollection), reg)
+	sc, err := ToArrowSchema(expression.BaseSchema, set)
+	if err != nil {
+		return nil, err
+	}
+
+	return ExecuteScalarExpression(WithExtensionIDSet(ctx, set), sc, toExecute, partialInput)
+}
+
+func execFieldRef(ctx context.Context, e *expr.FieldReference, input compute.ExecBatch, ext ExtensionIDSet) (compute.Datum, error) {
+	if e.Root != expr.RootReference {
+		return nil, fmt.Errorf("%w: only RootReference is implemented", arrow.ErrNotImplemented)
+	}
+
+	ref, ok := e.Reference.(expr.ReferenceSegment)
+	if !ok {
+		return nil, fmt.Errorf("%w: only direct references are implemented", arrow.ErrNotImplemented)
+	}
+
+	expectedType, _, err := FromSubstraitType(e.GetType(), ext)
+	if err != nil {
+		return nil, err
+	}
+
+	var param compute.Datum
+	if sref, ok := ref.(*expr.StructFieldRef); ok {
+		if sref.Field < 0 || sref.Field >= int32(len(input.Values)) {
+			return nil, arrow.ErrInvalid
+		}
+		param = input.Values[sref.Field]
+		ref = ref.GetChild()
+	}
+
+	out, err := GetReferencedValue(compute.GetAllocator(ctx), ref, param, ext)
+	if err == compute.ErrEmpty {
+		out = compute.NewDatum(param)
+	} else if err != nil {
+		return nil, err
+	}
+	if !arrow.TypeEqual(out.(compute.ArrayLikeDatum).Type(), expectedType) {
+		return nil, fmt.Errorf("%w: referenced field %s was %s, but should have been %s",
+			arrow.ErrInvalid, ref, out.(compute.ArrayLikeDatum).Type(), expectedType)
+	}
+
+	return out, nil
+}
+
+func executeScalarBatch(ctx context.Context, input compute.ExecBatch, exp expr.Expression, ext ExtensionIDSet) (compute.Datum, error) {
+	if !exp.IsScalar() {

Review Comment:
   Yeah, that's how I understood things. So `SUM(y)` should return false for the `IsScalar` check given the current behavior of substrait-go: WindowFunctions currently return false for `IsScalar`.


