You are viewing a plain text version of this content. The link to the canonical version ("here") was not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by ze...@apache.org on 2023/01/23 22:01:44 UTC

[arrow] branch master updated: GH-33840: [Go] Improve SQLite Flight SQL Example and provide mainprog (#33841)

This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 8449c55371 GH-33840: [Go] Improve SQLite Flight SQL Example and provide mainprog (#33841)
8449c55371 is described below

commit 8449c553710a6c03d2f21a657c91c87742fa8be2
Author: Matt Topol <zo...@gmail.com>
AuthorDate: Mon Jan 23 17:01:37 2023 -0500

    GH-33840: [Go] Improve SQLite Flight SQL Example and provide mainprog (#33841)
    
    
    
    ### Rationale for this change
    
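    Make Flight SQL testing easier by providing a simple main program that can be installed with `go install`, both for local testing and for running CI against. (As of Go 1.18, `go get` no longer builds or installs executables, and this program requires Go 1.18 or later.)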
    
    ### What changes are included in this PR?
    
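    Improvements to the Flight SQL SQLite Go example, and the creation of a main program at github.com/apache/arrow/go/v11/arrow/flight/flightsql/example/cmd/sqlite_flightsql_server. The PR also adds C Data Interface export support for Decimal256 and for sparse/dense union types (dropping the always-null validity buffer from the union type layouts). It makes prepared statements execute once per bound parameter row, concatenating the results, and adds `ToTime` methods to the date/time scalars.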
    
    ### Are these changes tested?
    
    Covered by existing tests in sqlite_server_test.go; sparse and dense union cases were also added to `TestNestedArrays` in cdata_test.go.
    
    * Closes: #33840
    
    Authored-by: Matt Topol <zo...@gmail.com>
    Signed-off-by: Matt Topol <zo...@gmail.com>
---
 go/arrow/cdata/cdata_exports.go                    |  58 +++++---
 go/arrow/cdata/cdata_fulltest.c                    |  21 +++
 go/arrow/cdata/cdata_test.go                       |  46 +++++++
 go/arrow/cdata/cdata_test_framework.go             |  26 ++++
 go/arrow/datatype_nested.go                        |   4 +-
 .../example/cmd/sqlite_flightsql_server/main.go    |  58 ++++++++
 .../flight/flightsql/example/sql_batch_reader.go   |  74 +++++++---
 go/arrow/flight/flightsql/example/sqlite_server.go | 151 ++++++++++++++++-----
 go/arrow/flight/record_batch_reader.go             |  27 ++++
 go/arrow/scalar/temporal.go                        |  20 +++
 10 files changed, 416 insertions(+), 69 deletions(-)

diff --git a/go/arrow/cdata/cdata_exports.go b/go/arrow/cdata/cdata_exports.go
index b998decf0b..722cac71b0 100644
--- a/go/arrow/cdata/cdata_exports.go
+++ b/go/arrow/cdata/cdata_exports.go
@@ -38,6 +38,7 @@ import (
 	"fmt"
 	"reflect"
 	"runtime/cgo"
+	"strconv"
 	"strings"
 	"unsafe"
 
@@ -152,6 +153,8 @@ func (exp *schemaExporter) exportFormat(dt arrow.DataType) string {
 		return fmt.Sprintf("w:%d", dt.ByteWidth)
 	case *arrow.Decimal128Type:
 		return fmt.Sprintf("d:%d,%d", dt.Precision, dt.Scale)
+	case *arrow.Decimal256Type:
+		return fmt.Sprintf("d:%d,%d,256", dt.Precision, dt.Scale)
 	case *arrow.BinaryType:
 		return "z"
 	case *arrow.LargeBinaryType:
@@ -235,6 +238,20 @@ func (exp *schemaExporter) exportFormat(dt arrow.DataType) string {
 			exp.flags |= C.ARROW_FLAG_DICTIONARY_ORDERED
 		}
 		return exp.exportFormat(dt.IndexType)
+	case arrow.UnionType:
+		var b strings.Builder
+		if dt.Mode() == arrow.SparseMode {
+			b.WriteString("+us:")
+		} else {
+			b.WriteString("+ud:")
+		}
+		for i, c := range dt.TypeCodes() {
+			if i != 0 {
+				b.WriteByte(',')
+			}
+			b.WriteString(strconv.Itoa(int(c)))
+		}
+		return b.String()
 	}
 	panic("unsupported data type for export")
 }
@@ -250,23 +267,11 @@ func (exp *schemaExporter) export(field arrow.Field) {
 	case *arrow.DictionaryType:
 		exp.dict = new(schemaExporter)
 		exp.dict.export(arrow.Field{Type: dt.ValueType})
-	case *arrow.ListType:
-		exp.children = make([]schemaExporter, 1)
-		exp.children[0].export(dt.ElemField())
-	case *arrow.LargeListType:
-		exp.children = make([]schemaExporter, 1)
-		exp.children[0].export(dt.ElemField())
-	case *arrow.StructType:
+	case arrow.NestedType:
 		exp.children = make([]schemaExporter, len(dt.Fields()))
 		for i, f := range dt.Fields() {
 			exp.children[i].export(f)
 		}
-	case *arrow.MapType:
-		exp.children = make([]schemaExporter, 1)
-		exp.children[0].export(dt.ValueField())
-	case *arrow.FixedSizeListType:
-		exp.children = make([]schemaExporter, 1)
-		exp.children[0].export(dt.ElemField())
 	}
 
 	exp.exportMeta(&field.Metadata)
@@ -364,9 +369,21 @@ func exportArray(arr arrow.Array, out *CArrowArray, outSchema *CArrowSchema) {
 	out.n_buffers = C.int64_t(len(arr.Data().Buffers()))
 
 	if out.n_buffers > 0 {
-		buffers := allocateBufferPtrArr(len(arr.Data().Buffers()))
-		for i := range arr.Data().Buffers() {
-			buf := arr.Data().Buffers()[i]
+		var (
+			nbuffers = len(arr.Data().Buffers())
+			bufs     = arr.Data().Buffers()
+		)
+		// unions don't have validity bitmaps, but we keep them shifted
+		// to make processing easier in other contexts. This means that
+		// we have to adjust for union arrays
+		if arr.DataType().ID() == arrow.DENSE_UNION || arr.DataType().ID() == arrow.SPARSE_UNION {
+			out.n_buffers--
+			nbuffers--
+			bufs = bufs[1:]
+		}
+		buffers := allocateBufferPtrArr(nbuffers)
+		for i := range bufs {
+			buf := bufs[i]
 			if buf == nil || buf.Len() == 0 {
 				buffers[i] = nil
 				continue
@@ -408,6 +425,15 @@ func exportArray(arr arrow.Array, out *CArrowArray, outSchema *CArrowSchema) {
 	case *array.Dictionary:
 		out.dictionary = (*CArrowArray)(C.malloc(C.sizeof_struct_ArrowArray))
 		exportArray(arr.Dictionary(), out.dictionary, nil)
+	case array.Union:
+		out.n_children = C.int64_t(arr.NumFields())
+		childPtrs := allocateArrowArrayPtrArr(arr.NumFields())
+		children := allocateArrowArrayArr(arr.NumFields())
+		for i := 0; i < arr.NumFields(); i++ {
+			exportArray(arr.Field(i), &children[i], nil)
+			childPtrs[i] = &children[i]
+		}
+		out.children = (**CArrowArray)(unsafe.Pointer(&childPtrs[0]))
 	default:
 		out.n_children = 0
 		out.children = nil
diff --git a/go/arrow/cdata/cdata_fulltest.c b/go/arrow/cdata/cdata_fulltest.c
index 837d347d53..4731c0ef39 100644
--- a/go/arrow/cdata/cdata_fulltest.c
+++ b/go/arrow/cdata/cdata_fulltest.c
@@ -244,6 +244,27 @@ struct ArrowSchema** test_map(const char** fmts, const char** names, int64_t* fl
     return schemas;
 }
 
+struct ArrowSchema** test_union(const char** fmts, const char** names, int64_t* flags, const int n) {
+    struct ArrowSchema** schemas = malloc(sizeof(struct ArrowSchema*)*n);
+     for (int i = 0; i < n; ++i) {
+        schemas[i] = malloc(sizeof(struct ArrowSchema));
+        *schemas[i] = (struct ArrowSchema) {
+            .format = fmts[i],
+            .name = names[i],
+            .metadata = NULL,
+            .flags = flags[i],
+            .children = NULL,
+            .n_children = 0,
+            .dictionary = NULL,
+            .release = &release_nested_dynamic,
+        };
+    }
+
+    schemas[0]->n_children = n-1;
+    schemas[0]->children = &schemas[1];
+    return schemas;
+}
+
 struct streamcounter {
     int n;
     int max;
diff --git a/go/arrow/cdata/cdata_test.go b/go/arrow/cdata/cdata_test.go
index 8976f377f2..32d41aa087 100644
--- a/go/arrow/cdata/cdata_test.go
+++ b/go/arrow/cdata/cdata_test.go
@@ -579,6 +579,50 @@ func createTestMapArr() arrow.Array {
 	return bld.NewArray()
 }
 
+func createTestSparseUnion() arrow.Array {
+	return createTestUnionArr(arrow.SparseMode)
+}
+
+func createTestDenseUnion() arrow.Array {
+	return createTestUnionArr(arrow.DenseMode)
+}
+
+func createTestUnionArr(mode arrow.UnionMode) arrow.Array {
+	fields := []arrow.Field{
+		arrow.Field{Name: "u0", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
+		arrow.Field{Name: "u1", Type: arrow.PrimitiveTypes.Uint8, Nullable: true},
+	}
+	typeCodes := []arrow.UnionTypeCode{5, 10}
+	bld := array.NewBuilder(memory.DefaultAllocator, arrow.UnionOf(mode, fields, typeCodes)).(array.UnionBuilder)
+	defer bld.Release()
+
+	u0Bld := bld.Child(0).(*array.Int32Builder)
+	u1Bld := bld.Child(1).(*array.Uint8Builder)
+
+	bld.Append(5)
+	if mode == arrow.SparseMode {
+		u1Bld.AppendNull()
+	}
+	u0Bld.Append(128)
+	bld.Append(5)
+	if mode == arrow.SparseMode {
+		u1Bld.AppendNull()
+	}
+	u0Bld.Append(256)
+	bld.Append(10)
+	if mode == arrow.SparseMode {
+		u0Bld.AppendNull()
+	}
+	u1Bld.Append(127)
+	bld.Append(10)
+	if mode == arrow.SparseMode {
+		u0Bld.AppendNull()
+	}
+	u1Bld.Append(25)
+
+	return bld.NewArray()
+}
+
 func TestNestedArrays(t *testing.T) {
 	tests := []struct {
 		name string
@@ -589,6 +633,8 @@ func TestNestedArrays(t *testing.T) {
 		{"fixed size list", createTestFixedSizeList},
 		{"struct", createTestStructArr},
 		{"map", createTestMapArr},
+		{"sparse union", createTestSparseUnion},
+		{"dense union", createTestDenseUnion},
 	}
 
 	for _, tt := range tests {
diff --git a/go/arrow/cdata/cdata_test_framework.go b/go/arrow/cdata/cdata_test_framework.go
index 0ddda26938..3f0b81f90a 100644
--- a/go/arrow/cdata/cdata_test_framework.go
+++ b/go/arrow/cdata/cdata_test_framework.go
@@ -53,6 +53,7 @@ package cdata
 // struct ArrowSchema** test_struct(const char** fmts, const char** names, int64_t* flags, const int n);
 // struct ArrowSchema** test_map(const char** fmts, const char** names, int64_t* flags, const int n);
 // struct ArrowSchema** test_schema(const char** fmts, const char** names, int64_t* flags, const int n);
+// struct ArrowSchema** test_union(const char** fmts, const char** names, int64_t* flags, const int n);
 // int test_exported_stream(struct ArrowArrayStream* stream);
 import "C"
 import (
@@ -182,6 +183,24 @@ func testMap(fmts, names []string, flags []int64) **CArrowSchema {
 	return C.test_map((**C.char)(unsafe.Pointer(&cfmts[0])), (**C.char)(unsafe.Pointer(&cnames[0])), (*C.int64_t)(unsafe.Pointer(&cflags[0])), C.int(len(fmts)))
 }
 
+func testUnion(fmts, names []string, flags []int64) **CArrowSchema {
+	if len(fmts) != len(names) || len(names) != len(flags) {
+		panic("testing unions must all have the same size slices in args")
+	}
+
+	cfmts := make([]*C.char, len(fmts))
+	cnames := make([]*C.char, len(names))
+	cflags := make([]C.int64_t, len(flags))
+
+	for i := range fmts {
+		cfmts[i] = C.CString(fmts[i])
+		cnames[i] = C.CString(names[i])
+		cflags[i] = C.int64_t(flags[i])
+	}
+
+	return C.test_union((**C.char)(unsafe.Pointer(&cfmts[0])), (**C.char)(unsafe.Pointer(&cnames[0])), (*C.int64_t)(unsafe.Pointer(&cflags[0])), C.int(len(fmts)))
+}
+
 func testSchema(fmts, names []string, flags []int64) **CArrowSchema {
 	if len(fmts) != len(names) || len(names) != len(flags) {
 		panic("testing structs must all have the same size slices in args")
@@ -235,6 +254,13 @@ func createCArr(arr arrow.Array) *CArrowArray {
 		clist := []*CArrowArray{createCArr(arr.ListValues())}
 		children = (**CArrowArray)(unsafe.Pointer(&clist[0]))
 		nchildren += 1
+	case array.Union:
+		clist := []*CArrowArray{}
+		for i := 0; i < arr.NumFields(); i++ {
+			clist = append(clist, createCArr(arr.Field(i)))
+			nchildren += 1
+		}
+		children = (**CArrowArray)(unsafe.Pointer(&clist[0]))
 	}
 
 	carr.children = children
diff --git a/go/arrow/datatype_nested.go b/go/arrow/datatype_nested.go
index 2fd9779cf5..8966df90bb 100644
--- a/go/arrow/datatype_nested.go
+++ b/go/arrow/datatype_nested.go
@@ -596,7 +596,7 @@ func (t *SparseUnionType) Fingerprint() string {
 	return typeFingerprint(t) + "[s" + t.fingerprint()
 }
 func (SparseUnionType) Layout() DataTypeLayout {
-	return DataTypeLayout{Buffers: []BufferSpec{SpecAlwaysNull(), SpecFixedWidth(Uint8SizeBytes)}}
+	return DataTypeLayout{Buffers: []BufferSpec{SpecFixedWidth(Uint8SizeBytes)}}
 }
 func (t *SparseUnionType) String() string {
 	return t.Name() + t.unionType.String()
@@ -659,7 +659,7 @@ func (t *DenseUnionType) Fingerprint() string {
 }
 
 func (DenseUnionType) Layout() DataTypeLayout {
-	return DataTypeLayout{Buffers: []BufferSpec{SpecAlwaysNull(), SpecFixedWidth(Uint8SizeBytes), SpecFixedWidth(Int32SizeBytes)}}
+	return DataTypeLayout{Buffers: []BufferSpec{SpecFixedWidth(Uint8SizeBytes), SpecFixedWidth(Int32SizeBytes)}}
 }
 
 func (DenseUnionType) OffsetTypeTraits() OffsetTraits { return Int32Traits }
diff --git a/go/arrow/flight/flightsql/example/cmd/sqlite_flightsql_server/main.go b/go/arrow/flight/flightsql/example/cmd/sqlite_flightsql_server/main.go
new file mode 100644
index 0000000000..eb01050262
--- /dev/null
+++ b/go/arrow/flight/flightsql/example/cmd/sqlite_flightsql_server/main.go
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+// +build go1.18
+
+package main
+
+import (
+	"flag"
+	"fmt"
+	"log"
+	"net"
+	"os"
+	"strconv"
+
+	"github.com/apache/arrow/go/v11/arrow/flight"
+	"github.com/apache/arrow/go/v11/arrow/flight/flightsql"
+	"github.com/apache/arrow/go/v11/arrow/flight/flightsql/example"
+)
+
+func main() {
+	var (
+		host = flag.String("host", "localhost", "hostname to bind to")
+		port = flag.Int("port", 0, "port to bind to")
+	)
+
+	flag.Parse()
+
+	srv, err := example.NewSQLiteFlightSQLServer()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	server := flight.NewServerWithMiddleware(nil)
+	server.RegisterFlightService(flightsql.NewFlightServer(srv))
+	server.Init(net.JoinHostPort(*host, strconv.Itoa(*port)))
+	server.SetShutdownOnSignals(os.Interrupt, os.Kill)
+
+	fmt.Println("Starting SQLite Flight SQL Server on", server.Addr(), "...")
+
+	if err := server.Serve(); err != nil {
+		log.Fatal(err)
+	}
+}
diff --git a/go/arrow/flight/flightsql/example/sql_batch_reader.go b/go/arrow/flight/flightsql/example/sql_batch_reader.go
index cc4249a83b..d7b05822db 100644
--- a/go/arrow/flight/flightsql/example/sql_batch_reader.go
+++ b/go/arrow/flight/flightsql/example/sql_batch_reader.go
@@ -22,6 +22,7 @@ package example
 import (
 	"database/sql"
 	"reflect"
+	"strconv"
 	"strings"
 	"sync/atomic"
 
@@ -39,9 +40,15 @@ func getArrowTypeFromString(dbtype string) arrow.DataType {
 	}
 
 	switch dbtype {
+	case "tinyint":
+		return arrow.PrimitiveTypes.Int8
+	case "mediumint":
+		return arrow.PrimitiveTypes.Int32
 	case "int", "integer":
 		return arrow.PrimitiveTypes.Int64
-	case "real":
+	case "float":
+		return arrow.PrimitiveTypes.Float32
+	case "real", "double":
 		return arrow.PrimitiveTypes.Float64
 	case "blob":
 		return arrow.BinaryTypes.Binary
@@ -52,14 +59,31 @@ func getArrowTypeFromString(dbtype string) arrow.DataType {
 	}
 }
 
+var sqliteDenseUnion = arrow.DenseUnionOf([]arrow.Field{
+	{Name: "int", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
+	{Name: "float", Type: arrow.PrimitiveTypes.Float64, Nullable: true},
+	{Name: "string", Type: arrow.BinaryTypes.String, Nullable: true},
+}, []arrow.UnionTypeCode{0, 1, 2})
+
 func getArrowType(c *sql.ColumnType) arrow.DataType {
 	dbtype := strings.ToLower(c.DatabaseTypeName())
 	if dbtype == "" {
+		if c.ScanType() == nil {
+			return sqliteDenseUnion
+		}
 		switch c.ScanType().Kind() {
+		case reflect.Int8, reflect.Uint8:
+			return arrow.PrimitiveTypes.Int8
+		case reflect.Int32, reflect.Uint32:
+			return arrow.PrimitiveTypes.Int32
 		case reflect.Int, reflect.Int64, reflect.Uint64:
 			return arrow.PrimitiveTypes.Int64
-		case reflect.Float32, reflect.Float64:
+		case reflect.Float32:
+			return arrow.PrimitiveTypes.Float32
+		case reflect.Float64:
 			return arrow.PrimitiveTypes.Float64
+		case reflect.String:
+			return arrow.BinaryTypes.String
 		}
 	}
 	return getArrowTypeFromString(dbtype)
@@ -83,9 +107,11 @@ func NewSqlBatchReaderWithSchema(mem memory.Allocator, schema *arrow.Schema, row
 	rowdest := make([]interface{}, len(schema.Fields()))
 	for i, f := range schema.Fields() {
 		switch f.Type.ID() {
-		case arrow.UINT8:
+		case arrow.DENSE_UNION, arrow.SPARSE_UNION:
+			rowdest[i] = new(interface{})
+		case arrow.UINT8, arrow.INT8:
 			if f.Nullable {
-				rowdest[i] = &sql.NullInt32{}
+				rowdest[i] = &sql.NullByte{}
 			} else {
 				rowdest[i] = new(uint8)
 			}
@@ -101,7 +127,7 @@ func NewSqlBatchReaderWithSchema(mem memory.Allocator, schema *arrow.Schema, row
 			} else {
 				rowdest[i] = new(int64)
 			}
-		case arrow.FLOAT64:
+		case arrow.FLOAT32, arrow.FLOAT64:
 			if f.Nullable {
 				rowdest[i] = &sql.NullFloat64{}
 			} else {
@@ -140,13 +166,18 @@ func NewSqlBatchReader(mem memory.Allocator, rows *sql.Rows) (*SqlBatchReader, e
 	fields := make([]arrow.Field, len(cols))
 	for i, c := range cols {
 		fields[i].Name = c.Name()
+		if c.Name() == "?" {
+			fields[i].Name += ":" + strconv.Itoa(i)
+		}
 		fields[i].Nullable, _ = c.Nullable()
 		fields[i].Type = getArrowType(c)
 		fields[i].Metadata = getColumnMetadata(bldr, getSqlTypeFromTypeName(c.DatabaseTypeName()), "")
 		switch fields[i].Type.ID() {
-		case arrow.UINT8:
+		case arrow.DENSE_UNION, arrow.SPARSE_UNION:
+			rowdest[i] = new(interface{})
+		case arrow.UINT8, arrow.INT8:
 			if fields[i].Nullable {
-				rowdest[i] = &sql.NullInt32{}
+				rowdest[i] = &sql.NullByte{}
 			} else {
 				rowdest[i] = new(uint8)
 			}
@@ -162,7 +193,7 @@ func NewSqlBatchReader(mem memory.Allocator, rows *sql.Rows) (*SqlBatchReader, e
 			} else {
 				rowdest[i] = new(int64)
 			}
-		case arrow.FLOAT64:
+		case arrow.FLOAT64, arrow.FLOAT32:
 			if fields[i].Nullable {
 				rowdest[i] = &sql.NullFloat64{}
 			} else {
@@ -231,6 +262,12 @@ func (r *SqlBatchReader) Next() bool {
 			switch v := v.(type) {
 			case *uint8:
 				fb.(*array.Uint8Builder).Append(*v)
+			case *sql.NullByte:
+				if !v.Valid {
+					fb.AppendNull()
+				} else {
+					fb.(*array.Uint8Builder).Append(v.Byte)
+				}
 			case *int64:
 				fb.(*array.Int64Builder).Append(*v)
 			case *sql.NullInt64:
@@ -245,20 +282,25 @@ func (r *SqlBatchReader) Next() bool {
 				if !v.Valid {
 					fb.AppendNull()
 				} else {
-					switch b := fb.(type) {
-					case *array.Int32Builder:
-						b.Append(v.Int32)
-					case *array.Uint8Builder:
-						b.Append(uint8(v.Int32))
-					}
+					fb.(*array.Int32Builder).Append(v.Int32)
 				}
 			case *float64:
-				fb.(*array.Float64Builder).Append(*v)
+				switch b := fb.(type) {
+				case *array.Float64Builder:
+					b.Append(*v)
+				case *array.Float32Builder:
+					b.Append(float32(*v))
+				}
 			case *sql.NullFloat64:
 				if !v.Valid {
 					fb.AppendNull()
 				} else {
-					fb.(*array.Float64Builder).Append(v.Float64)
+					switch b := fb.(type) {
+					case *array.Float64Builder:
+						b.Append(v.Float64)
+					case *array.Float32Builder:
+						b.Append(float32(v.Float64))
+					}
 				}
 			case *[]byte:
 				if v == nil {
diff --git a/go/arrow/flight/flightsql/example/sqlite_server.go b/go/arrow/flight/flightsql/example/sqlite_server.go
index 1b1707aa79..c6b0990312 100644
--- a/go/arrow/flight/flightsql/example/sqlite_server.go
+++ b/go/arrow/flight/flightsql/example/sqlite_server.go
@@ -142,7 +142,7 @@ func prepareQueryForGetKeys(filter string) string {
 
 type Statement struct {
 	stmt   *sql.Stmt
-	params []interface{}
+	params [][]interface{}
 }
 
 type SQLiteFlightSQLServer struct {
@@ -183,6 +183,7 @@ func NewSQLiteFlightSQLServer() (*SQLiteFlightSQLServer, error) {
 		return nil, err
 	}
 	ret := &SQLiteFlightSQLServer{db: db}
+	ret.Alloc = memory.DefaultAllocator
 	for k, v := range SqlInfoResultMap() {
 		ret.RegisterSqlInfo(flightsql.SqlInfo(k), v)
 	}
@@ -380,41 +381,113 @@ func doGetQuery(ctx context.Context, mem memory.Allocator, db *sql.DB, query str
 	return schema, ch, nil
 }
 
-func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (*arrow.Schema, <-chan flight.StreamChunk, error) {
+func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) {
 	val, ok := s.prepared.Load(string(cmd.GetPreparedStatementHandle()))
 	if !ok {
 		return nil, nil, status.Error(codes.InvalidArgument, "prepared statement not found")
 	}
 
 	stmt := val.(Statement)
-	rows, err := stmt.stmt.QueryContext(ctx, stmt.params...)
-	if err != nil {
-		return nil, nil, err
-	}
+	readers := make([]array.RecordReader, 0, len(stmt.params))
+	if len(stmt.params) == 0 {
+		rows, err := stmt.stmt.QueryContext(ctx)
+		if err != nil {
+			return nil, nil, err
+		}
 
-	rdr, err := NewSqlBatchReader(s.Alloc, rows)
-	if err != nil {
-		return nil, nil, err
+		rdr, err := NewSqlBatchReader(s.Alloc, rows)
+		if err != nil {
+			return nil, nil, err
+		}
+
+		schema = rdr.schema
+		readers = append(readers, rdr)
+	} else {
+		defer func() {
+			if err != nil {
+				for _, r := range readers {
+					r.Release()
+				}
+			}
+		}()
+		var (
+			rows *sql.Rows
+			rdr  *SqlBatchReader
+		)
+		// if we have multiple rows of bound params, execute the query
+		// multiple times and concatenate the result sets.
+		for _, p := range stmt.params {
+			rows, err = stmt.stmt.QueryContext(ctx, p...)
+			if err != nil {
+				return nil, nil, err
+			}
+
+			if schema == nil {
+				rdr, err = NewSqlBatchReader(s.Alloc, rows)
+				if err != nil {
+					return nil, nil, err
+				}
+				schema = rdr.schema
+			} else {
+				rdr, err = NewSqlBatchReaderWithSchema(s.Alloc, schema, rows)
+				if err != nil {
+					return nil, nil, err
+				}
+			}
+
+			readers = append(readers, rdr)
+		}
 	}
 
-	schema := rdr.schema
 	ch := make(chan flight.StreamChunk)
-	go flight.StreamChunksFromReader(rdr, ch)
-	return schema, ch, nil
+	go flight.ConcatenateReaders(readers, ch)
+	out = ch
+	return
 }
 
-func getParamsForStatement(rdr flight.MessageReader) (params []interface{}, err error) {
+func scalarToIFace(s scalar.Scalar) (interface{}, error) {
+	if !s.IsValid() {
+		return nil, nil
+	}
+
+	switch val := s.(type) {
+	case *scalar.Int8:
+		return val.Value, nil
+	case *scalar.Uint8:
+		return val.Value, nil
+	case *scalar.Int32:
+		return val.Value, nil
+	case *scalar.Int64:
+		return val.Value, nil
+	case *scalar.Float32:
+		return val.Value, nil
+	case *scalar.Float64:
+		return val.Value, nil
+	case *scalar.String:
+		return string(val.Value.Bytes()), nil
+	case *scalar.Binary:
+		return val.Value.Bytes(), nil
+	case scalar.DateScalar:
+		return val.ToTime(), nil
+	case scalar.TimeScalar:
+		return val.ToTime(), nil
+	case *scalar.DenseUnion:
+		return scalarToIFace(val.Value)
+	default:
+		return nil, fmt.Errorf("unsupported type: %s", val)
+	}
+}
+
+func getParamsForStatement(rdr flight.MessageReader) (params [][]interface{}, err error) {
+	params = make([][]interface{}, 0)
 	for rdr.Next() {
 		rec := rdr.Record()
 
 		nrows := int(rec.NumRows())
 		ncols := int(rec.NumCols())
 
-		if len(params) < int(ncols) {
-			params = make([]interface{}, ncols)
-		}
-
 		for i := 0; i < nrows; i++ {
+			invokeParams := make([]interface{}, ncols)
 			for c := 0; c < ncols; c++ {
 				col := rec.Column(c)
 				sc, err := scalar.GetScalar(col, i)
@@ -425,21 +498,12 @@ func getParamsForStatement(rdr flight.MessageReader) (params []interface{}, err
 					r.Release()
 				}
 
-				switch v := sc.(*scalar.DenseUnion).Value.(type) {
-				case *scalar.Int64:
-					params[c] = v.Value
-				case *scalar.Float32:
-					params[c] = v.Value
-				case *scalar.Float64:
-					params[c] = v.Value
-				case *scalar.String:
-					params[c] = string(v.Value.Bytes())
-				case *scalar.Binary:
-					params[c] = v.Value.Bytes()
-				default:
-					return nil, fmt.Errorf("unsupported type: %s", v)
+				invokeParams[c], err = scalarToIFace(sc)
+				if err != nil {
+					return nil, err
 				}
 			}
+			params = append(params, invokeParams)
 		}
 	}
 
@@ -475,13 +539,30 @@ func (s *SQLiteFlightSQLServer) DoPutPreparedStatementUpdate(ctx context.Context
 		return 0, status.Errorf(codes.Internal, "error gathering parameters for prepared statement: %s", err.Error())
 	}
 
-	stmt.params = args
-	result, err := stmt.stmt.ExecContext(ctx, args...)
-	if err != nil {
-		return 0, err
+	if len(args) == 0 {
+		result, err := stmt.stmt.ExecContext(ctx)
+		if err != nil {
+			return 0, err
+		}
+
+		return result.RowsAffected()
+	}
+
+	var totalAffected int64
+	for _, p := range args {
+		result, err := stmt.stmt.ExecContext(ctx, p...)
+		if err != nil {
+			return totalAffected, err
+		}
+
+		n, err := result.RowsAffected()
+		if err != nil {
+			return totalAffected, err
+		}
+		totalAffected += n
 	}
 
-	return result.RowsAffected()
+	return totalAffected, nil
 }
 
 func (s *SQLiteFlightSQLServer) GetFlightInfoPrimaryKeys(_ context.Context, cmd flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
diff --git a/go/arrow/flight/record_batch_reader.go b/go/arrow/flight/record_batch_reader.go
index 277f1eb4cc..477c05df34 100644
--- a/go/arrow/flight/record_batch_reader.go
+++ b/go/arrow/flight/record_batch_reader.go
@@ -234,3 +234,30 @@ func StreamChunksFromReader(rdr array.RecordReader, ch chan<- StreamChunk) {
 		}
 	}
 }
+
+func ConcatenateReaders(rdrs []array.RecordReader, ch chan<- StreamChunk) {
+	defer close(ch)
+	defer func() {
+		for _, r := range rdrs {
+			r.Release()
+		}
+
+		if err := recover(); err != nil {
+			ch <- StreamChunk{Err: fmt.Errorf("panic while reading: %s", err)}
+		}
+	}()
+
+	for _, r := range rdrs {
+		for r.Next() {
+			rec := r.Record()
+			rec.Retain()
+			ch <- StreamChunk{Data: rec}
+		}
+		if e, ok := r.(haserr); ok {
+			if e.Err() != nil {
+				ch <- StreamChunk{Err: e.Err()}
+				return
+			}
+		}
+	}
+}
diff --git a/go/arrow/scalar/temporal.go b/go/arrow/scalar/temporal.go
index a19cd49e3c..f76200ee32 100644
--- a/go/arrow/scalar/temporal.go
+++ b/go/arrow/scalar/temporal.go
@@ -85,12 +85,14 @@ func NewDurationScalar(val arrow.Duration, typ arrow.DataType) *Duration {
 
 type DateScalar interface {
 	TemporalScalar
+	ToTime() time.Time
 	date()
 }
 
 type TimeScalar interface {
 	TemporalScalar
 	Unit() arrow.TimeUnit
+	ToTime() time.Time
 	time()
 }
 
@@ -197,6 +199,9 @@ func (s *Date32) String() string {
 	}
 	return string(val.(*String).Value.Bytes())
 }
+func (s *Date32) ToTime() time.Time {
+	return s.Value.ToTime()
+}
 
 func NewDate32Scalar(val arrow.Date32) *Date32 {
 	return &Date32{scalar{arrow.FixedWidthTypes.Date32, true}, val}
@@ -227,6 +232,9 @@ func (s *Date64) String() string {
 	}
 	return string(val.(*String).Value.Bytes())
 }
+func (s *Date64) ToTime() time.Time {
+	return s.Value.ToTime()
+}
 
 func NewDate64Scalar(val arrow.Date64) *Date64 {
 	return &Date64{scalar{arrow.FixedWidthTypes.Date64, true}, val}
@@ -262,6 +270,10 @@ func (s *Time32) Data() []byte {
 	return (*[arrow.Time32SizeBytes]byte)(unsafe.Pointer(&s.Value))[:]
 }
 
+func (s *Time32) ToTime() time.Time {
+	return s.Value.ToTime(s.Unit())
+}
+
 func NewTime32Scalar(val arrow.Time32, typ arrow.DataType) *Time32 {
 	return &Time32{scalar{typ, true}, val}
 }
@@ -295,6 +307,10 @@ func (s *Time64) String() string {
 	return string(val.(*String).Value.Bytes())
 }
 
+func (s *Time64) ToTime() time.Time {
+	return s.Value.ToTime(s.Unit())
+}
+
 func NewTime64Scalar(val arrow.Time64, typ arrow.DataType) *Time64 {
 	return &Time64{scalar{typ, true}, val}
 }
@@ -328,6 +344,10 @@ func (s *Timestamp) String() string {
 	return string(val.(*String).Value.Bytes())
 }
 
+func (s *Timestamp) ToTime() time.Time {
+	return s.Value.ToTime(s.Unit())
+}
+
 func NewTimestampScalar(val arrow.Timestamp, typ arrow.DataType) *Timestamp {
 	return &Timestamp{scalar{typ, true}, val}
 }