You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@arrow.apache.org by ze...@apache.org on 2023/01/23 22:01:44 UTC
[arrow] branch master updated: GH-33840: [Go] Improve SQLite Flight SQL Example and provide mainprog (#33841)
This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 8449c55371 GH-33840: [Go] Improve SQLite Flight SQL Example and provide mainprog (#33841)
8449c55371 is described below
commit 8449c553710a6c03d2f21a657c91c87742fa8be2
Author: Matt Topol <zo...@gmail.com>
AuthorDate: Mon Jan 23 17:01:37 2023 -0500
GH-33840: [Go] Improve SQLite Flight SQL Example and provide mainprog (#33841)
### Rationale for this change
Make Flight SQL easier to test by providing a simple main program that can be installed with `go install` and used both for local testing and as a server for CI to run against.
### What changes are included in this PR?
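Improvements to the Flight SQL SQLite Go example, and a new main program at github.com/apache/arrow/go/v11/arrow/flight/flightsql/example/cmd/sqlite_flightsql_server. Supporting changes: C Data Interface export of union and Decimal256 types, a `flight.ConcatenateReaders` helper, and `ToTime` methods on the date/time scalar types.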
### Are these changes tested?
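Covered by existing tests in sqlite_server_test.go; the new C Data Interface union export is covered by new sparse- and dense-union cases in cdata_test.go.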
* Closes: #33840
Authored-by: Matt Topol <zo...@gmail.com>
Signed-off-by: Matt Topol <zo...@gmail.com>
---
go/arrow/cdata/cdata_exports.go | 58 +++++---
go/arrow/cdata/cdata_fulltest.c | 21 +++
go/arrow/cdata/cdata_test.go | 46 +++++++
go/arrow/cdata/cdata_test_framework.go | 26 ++++
go/arrow/datatype_nested.go | 4 +-
.../example/cmd/sqlite_flightsql_server/main.go | 58 ++++++++
.../flight/flightsql/example/sql_batch_reader.go | 74 +++++++---
go/arrow/flight/flightsql/example/sqlite_server.go | 151 ++++++++++++++++-----
go/arrow/flight/record_batch_reader.go | 27 ++++
go/arrow/scalar/temporal.go | 20 +++
10 files changed, 416 insertions(+), 69 deletions(-)
diff --git a/go/arrow/cdata/cdata_exports.go b/go/arrow/cdata/cdata_exports.go
index b998decf0b..722cac71b0 100644
--- a/go/arrow/cdata/cdata_exports.go
+++ b/go/arrow/cdata/cdata_exports.go
@@ -38,6 +38,7 @@ import (
"fmt"
"reflect"
"runtime/cgo"
+ "strconv"
"strings"
"unsafe"
@@ -152,6 +153,8 @@ func (exp *schemaExporter) exportFormat(dt arrow.DataType) string {
return fmt.Sprintf("w:%d", dt.ByteWidth)
case *arrow.Decimal128Type:
return fmt.Sprintf("d:%d,%d", dt.Precision, dt.Scale)
+ case *arrow.Decimal256Type:
+ return fmt.Sprintf("d:%d,%d,256", dt.Precision, dt.Scale)
case *arrow.BinaryType:
return "z"
case *arrow.LargeBinaryType:
@@ -235,6 +238,20 @@ func (exp *schemaExporter) exportFormat(dt arrow.DataType) string {
exp.flags |= C.ARROW_FLAG_DICTIONARY_ORDERED
}
return exp.exportFormat(dt.IndexType)
+ case arrow.UnionType:
+ var b strings.Builder
+ if dt.Mode() == arrow.SparseMode {
+ b.WriteString("+us:")
+ } else {
+ b.WriteString("+ud:")
+ }
+ for i, c := range dt.TypeCodes() {
+ if i != 0 {
+ b.WriteByte(',')
+ }
+ b.WriteString(strconv.Itoa(int(c)))
+ }
+ return b.String()
}
panic("unsupported data type for export")
}
@@ -250,23 +267,11 @@ func (exp *schemaExporter) export(field arrow.Field) {
case *arrow.DictionaryType:
exp.dict = new(schemaExporter)
exp.dict.export(arrow.Field{Type: dt.ValueType})
- case *arrow.ListType:
- exp.children = make([]schemaExporter, 1)
- exp.children[0].export(dt.ElemField())
- case *arrow.LargeListType:
- exp.children = make([]schemaExporter, 1)
- exp.children[0].export(dt.ElemField())
- case *arrow.StructType:
+ case arrow.NestedType:
exp.children = make([]schemaExporter, len(dt.Fields()))
for i, f := range dt.Fields() {
exp.children[i].export(f)
}
- case *arrow.MapType:
- exp.children = make([]schemaExporter, 1)
- exp.children[0].export(dt.ValueField())
- case *arrow.FixedSizeListType:
- exp.children = make([]schemaExporter, 1)
- exp.children[0].export(dt.ElemField())
}
exp.exportMeta(&field.Metadata)
@@ -364,9 +369,21 @@ func exportArray(arr arrow.Array, out *CArrowArray, outSchema *CArrowSchema) {
out.n_buffers = C.int64_t(len(arr.Data().Buffers()))
if out.n_buffers > 0 {
- buffers := allocateBufferPtrArr(len(arr.Data().Buffers()))
- for i := range arr.Data().Buffers() {
- buf := arr.Data().Buffers()[i]
+ var (
+ nbuffers = len(arr.Data().Buffers())
+ bufs = arr.Data().Buffers()
+ )
+ // unions don't have validity bitmaps, but we keep them shifted
+ // to make processing easier in other contexts. This means that
+ // we have to adjust for union arrays
+ if arr.DataType().ID() == arrow.DENSE_UNION || arr.DataType().ID() == arrow.SPARSE_UNION {
+ out.n_buffers--
+ nbuffers--
+ bufs = bufs[1:]
+ }
+ buffers := allocateBufferPtrArr(nbuffers)
+ for i := range bufs {
+ buf := bufs[i]
if buf == nil || buf.Len() == 0 {
buffers[i] = nil
continue
@@ -408,6 +425,15 @@ func exportArray(arr arrow.Array, out *CArrowArray, outSchema *CArrowSchema) {
case *array.Dictionary:
out.dictionary = (*CArrowArray)(C.malloc(C.sizeof_struct_ArrowArray))
exportArray(arr.Dictionary(), out.dictionary, nil)
+ case array.Union:
+ out.n_children = C.int64_t(arr.NumFields())
+ childPtrs := allocateArrowArrayPtrArr(arr.NumFields())
+ children := allocateArrowArrayArr(arr.NumFields())
+ for i := 0; i < arr.NumFields(); i++ {
+ exportArray(arr.Field(i), &children[i], nil)
+ childPtrs[i] = &children[i]
+ }
+ out.children = (**CArrowArray)(unsafe.Pointer(&childPtrs[0]))
default:
out.n_children = 0
out.children = nil
diff --git a/go/arrow/cdata/cdata_fulltest.c b/go/arrow/cdata/cdata_fulltest.c
index 837d347d53..4731c0ef39 100644
--- a/go/arrow/cdata/cdata_fulltest.c
+++ b/go/arrow/cdata/cdata_fulltest.c
@@ -244,6 +244,27 @@ struct ArrowSchema** test_map(const char** fmts, const char** names, int64_t* fl
return schemas;
}
+/* test_union allocates n ArrowSchema structs from the given formats, names
+ * and flags, then attaches schemas[1..n-1] as the children of schemas[0]
+ * (the union schema itself). Every struct uses release_nested_dynamic as
+ * its release callback. */
+struct ArrowSchema** test_union(const char** fmts, const char** names, int64_t* flags, const int n) {
+ struct ArrowSchema** out = malloc(n * sizeof(*out));
+
+ for (int i = 0; i < n; i++) {
+ struct ArrowSchema* s = malloc(sizeof(*s));
+ /* zero every member first: metadata, children, n_children,
+ * dictionary and private_data all start out NULL/0 */
+ *s = (struct ArrowSchema){0};
+ s->format = fmts[i];
+ s->name = names[i];
+ s->flags = flags[i];
+ s->release = &release_nested_dynamic;
+ out[i] = s;
+ }
+
+ out[0]->n_children = n - 1;
+ out[0]->children = out + 1;
+ return out;
+}
+
struct streamcounter {
int n;
int max;
diff --git a/go/arrow/cdata/cdata_test.go b/go/arrow/cdata/cdata_test.go
index 8976f377f2..32d41aa087 100644
--- a/go/arrow/cdata/cdata_test.go
+++ b/go/arrow/cdata/cdata_test.go
@@ -579,6 +579,50 @@ func createTestMapArr() arrow.Array {
return bld.NewArray()
}
+// createTestSparseUnion returns the sample union array built in sparse mode.
+func createTestSparseUnion() arrow.Array { return createTestUnionArr(arrow.SparseMode) }
+
+// createTestDenseUnion returns the sample union array built in dense mode.
+func createTestDenseUnion() arrow.Array { return createTestUnionArr(arrow.DenseMode) }
+
+func createTestUnionArr(mode arrow.UnionMode) arrow.Array {
+ fields := []arrow.Field{
+ arrow.Field{Name: "u0", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
+ arrow.Field{Name: "u1", Type: arrow.PrimitiveTypes.Uint8, Nullable: true},
+ }
+ typeCodes := []arrow.UnionTypeCode{5, 10}
+ bld := array.NewBuilder(memory.DefaultAllocator, arrow.UnionOf(mode, fields, typeCodes)).(array.UnionBuilder)
+ defer bld.Release()
+
+ u0Bld := bld.Child(0).(*array.Int32Builder)
+ u1Bld := bld.Child(1).(*array.Uint8Builder)
+
+ bld.Append(5)
+ if mode == arrow.SparseMode {
+ u1Bld.AppendNull()
+ }
+ u0Bld.Append(128)
+ bld.Append(5)
+ if mode == arrow.SparseMode {
+ u1Bld.AppendNull()
+ }
+ u0Bld.Append(256)
+ bld.Append(10)
+ if mode == arrow.SparseMode {
+ u0Bld.AppendNull()
+ }
+ u1Bld.Append(127)
+ bld.Append(10)
+ if mode == arrow.SparseMode {
+ u0Bld.AppendNull()
+ }
+ u1Bld.Append(25)
+
+ return bld.NewArray()
+}
+
func TestNestedArrays(t *testing.T) {
tests := []struct {
name string
@@ -589,6 +633,8 @@ func TestNestedArrays(t *testing.T) {
{"fixed size list", createTestFixedSizeList},
{"struct", createTestStructArr},
{"map", createTestMapArr},
+ {"sparse union", createTestSparseUnion},
+ {"dense union", createTestDenseUnion},
}
for _, tt := range tests {
diff --git a/go/arrow/cdata/cdata_test_framework.go b/go/arrow/cdata/cdata_test_framework.go
index 0ddda26938..3f0b81f90a 100644
--- a/go/arrow/cdata/cdata_test_framework.go
+++ b/go/arrow/cdata/cdata_test_framework.go
@@ -53,6 +53,7 @@ package cdata
// struct ArrowSchema** test_struct(const char** fmts, const char** names, int64_t* flags, const int n);
// struct ArrowSchema** test_map(const char** fmts, const char** names, int64_t* flags, const int n);
// struct ArrowSchema** test_schema(const char** fmts, const char** names, int64_t* flags, const int n);
+// struct ArrowSchema** test_union(const char** fmts, const char** names, int64_t* flags, const int n);
// int test_exported_stream(struct ArrowArrayStream* stream);
import "C"
import (
@@ -182,6 +183,24 @@ func testMap(fmts, names []string, flags []int64) **CArrowSchema {
return C.test_map((**C.char)(unsafe.Pointer(&cfmts[0])), (**C.char)(unsafe.Pointer(&cnames[0])), (*C.int64_t)(unsafe.Pointer(&cflags[0])), C.int(len(fmts)))
}
+func testUnion(fmts, names []string, flags []int64) **CArrowSchema {
+ if len(fmts) != len(names) || len(names) != len(flags) {
+ panic("testing unions must all have the same size slices in args")
+ }
+
+ cfmts := make([]*C.char, len(fmts))
+ cnames := make([]*C.char, len(names))
+ cflags := make([]C.int64_t, len(flags))
+
+ for i := range fmts {
+ cfmts[i] = C.CString(fmts[i])
+ cnames[i] = C.CString(names[i])
+ cflags[i] = C.int64_t(flags[i])
+ }
+
+ return C.test_union((**C.char)(unsafe.Pointer(&cfmts[0])), (**C.char)(unsafe.Pointer(&cnames[0])), (*C.int64_t)(unsafe.Pointer(&cflags[0])), C.int(len(fmts)))
+}
+
func testSchema(fmts, names []string, flags []int64) **CArrowSchema {
if len(fmts) != len(names) || len(names) != len(flags) {
panic("testing structs must all have the same size slices in args")
@@ -235,6 +254,13 @@ func createCArr(arr arrow.Array) *CArrowArray {
clist := []*CArrowArray{createCArr(arr.ListValues())}
children = (**CArrowArray)(unsafe.Pointer(&clist[0]))
nchildren += 1
+ case array.Union:
+ clist := []*CArrowArray{}
+ for i := 0; i < arr.NumFields(); i++ {
+ clist = append(clist, createCArr(arr.Field(i)))
+ nchildren += 1
+ }
+ children = (**CArrowArray)(unsafe.Pointer(&clist[0]))
}
carr.children = children
diff --git a/go/arrow/datatype_nested.go b/go/arrow/datatype_nested.go
index 2fd9779cf5..8966df90bb 100644
--- a/go/arrow/datatype_nested.go
+++ b/go/arrow/datatype_nested.go
@@ -596,7 +596,7 @@ func (t *SparseUnionType) Fingerprint() string {
return typeFingerprint(t) + "[s" + t.fingerprint()
}
+// Layout reports a single fixed-width 1-byte buffer for sparse unions. No
+// validity-bitmap slot is listed, since union arrays have no top-level
+// validity bitmap.
+// NOTE(review): union ArrayData still keeps a leading (nil) buffer slot
+// (see the buffer shift in cdata exportArray), so len(Layout().Buffers) is
+// one less than len(Data().Buffers()) — confirm no caller assumes they match.
func (SparseUnionType) Layout() DataTypeLayout {
- return DataTypeLayout{Buffers: []BufferSpec{SpecAlwaysNull(), SpecFixedWidth(Uint8SizeBytes)}}
+ return DataTypeLayout{Buffers: []BufferSpec{SpecFixedWidth(Uint8SizeBytes)}}
}
func (t *SparseUnionType) String() string {
return t.Name() + t.unionType.String()
@@ -659,7 +659,7 @@ func (t *DenseUnionType) Fingerprint() string {
}
+// Layout reports two fixed-width buffers for dense unions: 1-byte entries
+// (type codes) followed by 4-byte entries (int32 offsets into the selected
+// child). As with sparse unions, no validity-bitmap slot is listed.
func (DenseUnionType) Layout() DataTypeLayout {
- return DataTypeLayout{Buffers: []BufferSpec{SpecAlwaysNull(), SpecFixedWidth(Uint8SizeBytes), SpecFixedWidth(Int32SizeBytes)}}
+ return DataTypeLayout{Buffers: []BufferSpec{SpecFixedWidth(Uint8SizeBytes), SpecFixedWidth(Int32SizeBytes)}}
}
+// OffsetTypeTraits reports that dense-union offsets use the Int32 traits.
func (DenseUnionType) OffsetTypeTraits() OffsetTraits { return Int32Traits }
diff --git a/go/arrow/flight/flightsql/example/cmd/sqlite_flightsql_server/main.go b/go/arrow/flight/flightsql/example/cmd/sqlite_flightsql_server/main.go
new file mode 100644
index 0000000000..eb01050262
--- /dev/null
+++ b/go/arrow/flight/flightsql/example/cmd/sqlite_flightsql_server/main.go
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+// +build go1.18
+
+package main
+
+import (
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "os"
+ "strconv"
+ "syscall"
+
+ "github.com/apache/arrow/go/v11/arrow/flight"
+ "github.com/apache/arrow/go/v11/arrow/flight/flightsql"
+ "github.com/apache/arrow/go/v11/arrow/flight/flightsql/example"
+)
+
+// main starts the example SQLite Flight SQL server on -host:-port (port 0
+// lets the OS pick a free port; the chosen address is printed) and serves
+// until it receives an interrupt or SIGTERM.
+func main() {
+ var (
+ host = flag.String("host", "localhost", "hostname to bind to")
+ port = flag.Int("port", 0, "port to bind to")
+ )
+
+ flag.Parse()
+
+ srv, err := example.NewSQLiteFlightSQLServer()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ server := flight.NewServerWithMiddleware(nil)
+ server.RegisterFlightService(flightsql.NewFlightServer(srv))
+ // Init sets up the listener for the address; if that fails (e.g. the port
+ // is already in use) there is nothing to serve, so stop here instead of
+ // calling Serve on an uninitialized server.
+ if err := server.Init(net.JoinHostPort(*host, strconv.Itoa(*port))); err != nil {
+ log.Fatal(err)
+ }
+ // SIGKILL (os.Kill) can never be caught, so registering it has no effect;
+ // SIGTERM is the catchable termination signal process managers send.
+ server.SetShutdownOnSignals(os.Interrupt, syscall.SIGTERM)
+
+ fmt.Println("Starting SQLite Flight SQL Server on", server.Addr(), "...")
+
+ if err := server.Serve(); err != nil {
+ log.Fatal(err)
+ }
+}
diff --git a/go/arrow/flight/flightsql/example/sql_batch_reader.go b/go/arrow/flight/flightsql/example/sql_batch_reader.go
index cc4249a83b..d7b05822db 100644
--- a/go/arrow/flight/flightsql/example/sql_batch_reader.go
+++ b/go/arrow/flight/flightsql/example/sql_batch_reader.go
@@ -22,6 +22,7 @@ package example
import (
"database/sql"
"reflect"
+ "strconv"
"strings"
"sync/atomic"
@@ -39,9 +40,15 @@ func getArrowTypeFromString(dbtype string) arrow.DataType {
}
switch dbtype {
+ case "tinyint":
+ return arrow.PrimitiveTypes.Int8
+ case "mediumint":
+ return arrow.PrimitiveTypes.Int32
case "int", "integer":
return arrow.PrimitiveTypes.Int64
- case "real":
+ case "float":
+ return arrow.PrimitiveTypes.Float32
+ case "real", "double":
return arrow.PrimitiveTypes.Float64
case "blob":
return arrow.BinaryTypes.Binary
@@ -52,14 +59,31 @@ func getArrowTypeFromString(dbtype string) arrow.DataType {
}
}
+// sqliteDenseUnion is the Arrow type getArrowType falls back to for a result
+// column with neither a declared database type nor a scan type: a dense union
+// of nullable int64 ("int", code 0), float64 ("float", code 1) and string
+// ("string", code 2) children.
+var sqliteDenseUnion = arrow.DenseUnionOf(
+ []arrow.Field{
+ {Name: "int", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
+ {Name: "float", Type: arrow.PrimitiveTypes.Float64, Nullable: true},
+ {Name: "string", Type: arrow.BinaryTypes.String, Nullable: true},
+ },
+ []arrow.UnionTypeCode{0, 1, 2},
+)
+
func getArrowType(c *sql.ColumnType) arrow.DataType {
dbtype := strings.ToLower(c.DatabaseTypeName())
if dbtype == "" {
+ if c.ScanType() == nil {
+ return sqliteDenseUnion
+ }
switch c.ScanType().Kind() {
+ case reflect.Int8, reflect.Uint8:
+ return arrow.PrimitiveTypes.Int8
+ case reflect.Int32, reflect.Uint32:
+ return arrow.PrimitiveTypes.Int32
case reflect.Int, reflect.Int64, reflect.Uint64:
return arrow.PrimitiveTypes.Int64
- case reflect.Float32, reflect.Float64:
+ case reflect.Float32:
+ return arrow.PrimitiveTypes.Float32
+ case reflect.Float64:
return arrow.PrimitiveTypes.Float64
+ case reflect.String:
+ return arrow.BinaryTypes.String
}
}
return getArrowTypeFromString(dbtype)
@@ -83,9 +107,11 @@ func NewSqlBatchReaderWithSchema(mem memory.Allocator, schema *arrow.Schema, row
rowdest := make([]interface{}, len(schema.Fields()))
for i, f := range schema.Fields() {
switch f.Type.ID() {
- case arrow.UINT8:
+ case arrow.DENSE_UNION, arrow.SPARSE_UNION:
+ rowdest[i] = new(interface{})
+ case arrow.UINT8, arrow.INT8:
if f.Nullable {
- rowdest[i] = &sql.NullInt32{}
+ rowdest[i] = &sql.NullByte{}
} else {
rowdest[i] = new(uint8)
}
@@ -101,7 +127,7 @@ func NewSqlBatchReaderWithSchema(mem memory.Allocator, schema *arrow.Schema, row
} else {
rowdest[i] = new(int64)
}
- case arrow.FLOAT64:
+ case arrow.FLOAT32, arrow.FLOAT64:
if f.Nullable {
rowdest[i] = &sql.NullFloat64{}
} else {
@@ -140,13 +166,18 @@ func NewSqlBatchReader(mem memory.Allocator, rows *sql.Rows) (*SqlBatchReader, e
fields := make([]arrow.Field, len(cols))
for i, c := range cols {
fields[i].Name = c.Name()
+ if c.Name() == "?" {
+ fields[i].Name += ":" + strconv.Itoa(i)
+ }
fields[i].Nullable, _ = c.Nullable()
fields[i].Type = getArrowType(c)
fields[i].Metadata = getColumnMetadata(bldr, getSqlTypeFromTypeName(c.DatabaseTypeName()), "")
switch fields[i].Type.ID() {
- case arrow.UINT8:
+ case arrow.DENSE_UNION, arrow.SPARSE_UNION:
+ rowdest[i] = new(interface{})
+ case arrow.UINT8, arrow.INT8:
if fields[i].Nullable {
- rowdest[i] = &sql.NullInt32{}
+ rowdest[i] = &sql.NullByte{}
} else {
rowdest[i] = new(uint8)
}
@@ -162,7 +193,7 @@ func NewSqlBatchReader(mem memory.Allocator, rows *sql.Rows) (*SqlBatchReader, e
} else {
rowdest[i] = new(int64)
}
- case arrow.FLOAT64:
+ case arrow.FLOAT64, arrow.FLOAT32:
if fields[i].Nullable {
rowdest[i] = &sql.NullFloat64{}
} else {
@@ -231,6 +262,12 @@ func (r *SqlBatchReader) Next() bool {
switch v := v.(type) {
case *uint8:
fb.(*array.Uint8Builder).Append(*v)
+ case *sql.NullByte:
+ if !v.Valid {
+ fb.AppendNull()
+ } else {
+ fb.(*array.Uint8Builder).Append(v.Byte)
+ }
case *int64:
fb.(*array.Int64Builder).Append(*v)
case *sql.NullInt64:
@@ -245,20 +282,25 @@ func (r *SqlBatchReader) Next() bool {
if !v.Valid {
fb.AppendNull()
} else {
- switch b := fb.(type) {
- case *array.Int32Builder:
- b.Append(v.Int32)
- case *array.Uint8Builder:
- b.Append(uint8(v.Int32))
- }
+ fb.(*array.Int32Builder).Append(v.Int32)
}
case *float64:
- fb.(*array.Float64Builder).Append(*v)
+ switch b := fb.(type) {
+ case *array.Float64Builder:
+ b.Append(*v)
+ case *array.Float32Builder:
+ b.Append(float32(*v))
+ }
case *sql.NullFloat64:
if !v.Valid {
fb.AppendNull()
} else {
- fb.(*array.Float64Builder).Append(v.Float64)
+ switch b := fb.(type) {
+ case *array.Float64Builder:
+ b.Append(v.Float64)
+ case *array.Float32Builder:
+ b.Append(float32(v.Float64))
+ }
}
case *[]byte:
if v == nil {
diff --git a/go/arrow/flight/flightsql/example/sqlite_server.go b/go/arrow/flight/flightsql/example/sqlite_server.go
index 1b1707aa79..c6b0990312 100644
--- a/go/arrow/flight/flightsql/example/sqlite_server.go
+++ b/go/arrow/flight/flightsql/example/sqlite_server.go
@@ -142,7 +142,7 @@ func prepareQueryForGetKeys(filter string) string {
+// Statement pairs a prepared SQL statement with the parameter sets bound to it.
type Statement struct {
stmt *sql.Stmt
- params []interface{}
+ // params holds one slice of bound values per parameter row; queries and
+ // updates execute the statement once per row.
+ params [][]interface{}
}
type SQLiteFlightSQLServer struct {
@@ -183,6 +183,7 @@ func NewSQLiteFlightSQLServer() (*SQLiteFlightSQLServer, error) {
return nil, err
}
ret := &SQLiteFlightSQLServer{db: db}
+ ret.Alloc = memory.DefaultAllocator
for k, v := range SqlInfoResultMap() {
ret.RegisterSqlInfo(flightsql.SqlInfo(k), v)
}
@@ -380,41 +381,113 @@ func doGetQuery(ctx context.Context, mem memory.Allocator, db *sql.DB, query str
return schema, ch, nil
}
-func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (*arrow.Schema, <-chan flight.StreamChunk, error) {
+// DoGetPreparedStatement executes the prepared statement referenced by the
+// command's handle and streams the results. Without bound parameters the
+// statement runs once; otherwise it runs once per bound parameter row and the
+// result sets are concatenated, all read with the schema of the first result.
+func (s *SQLiteFlightSQLServer) DoGetPreparedStatement(ctx context.Context, cmd flightsql.PreparedStatementQuery) (schema *arrow.Schema, out <-chan flight.StreamChunk, err error) {
val, ok := s.prepared.Load(string(cmd.GetPreparedStatementHandle()))
if !ok {
return nil, nil, status.Error(codes.InvalidArgument, "prepared statement not found")
}
stmt := val.(Statement)
- rows, err := stmt.stmt.QueryContext(ctx, stmt.params...)
- if err != nil {
- return nil, nil, err
- }
+ readers := make([]array.RecordReader, 0, len(stmt.params))
+ if len(stmt.params) == 0 {
+ rows, err := stmt.stmt.QueryContext(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
- rdr, err := NewSqlBatchReader(s.Alloc, rows)
- if err != nil {
- return nil, nil, err
+ rdr, err := NewSqlBatchReader(s.Alloc, rows)
+ if err != nil {
+ // no reader owns rows, so close it to release its connection
+ // (sql.Rows.Close is idempotent, so this is safe either way)
+ rows.Close()
+ return nil, nil, err
+ }
+
+ schema = rdr.schema
+ readers = append(readers, rdr)
+ } else {
+ defer func() {
+ if err != nil {
+ for _, r := range readers {
+ r.Release()
+ }
+ }
+ }()
+ var (
+ rows *sql.Rows
+ rdr *SqlBatchReader
+ )
+ // if we have multiple rows of bound params, execute the query
+ // multiple times and concatenate the result sets.
+ for _, p := range stmt.params {
+ rows, err = stmt.stmt.QueryContext(ctx, p...)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // the first result set fixes the schema; later ones are read with
+ // it so every chunk in the concatenated stream agrees
+ if schema == nil {
+ rdr, err = NewSqlBatchReader(s.Alloc, rows)
+ } else {
+ rdr, err = NewSqlBatchReaderWithSchema(s.Alloc, schema, rows)
+ }
+ if err != nil {
+ // rows is not tracked in readers, so the deferred cleanup
+ // cannot reach it; close it here (Close is idempotent)
+ rows.Close()
+ return nil, nil, err
+ }
+ if schema == nil {
+ schema = rdr.schema
+ }
+
+ readers = append(readers, rdr)
+ }
}
- schema := rdr.schema
ch := make(chan flight.StreamChunk)
- go flight.StreamChunksFromReader(rdr, ch)
- return schema, ch, nil
+ // ConcatenateReaders releases the readers and closes ch when it finishes
+ go flight.ConcatenateReaders(readers, ch)
+ out = ch
+ return
}
-func getParamsForStatement(rdr flight.MessageReader) (params []interface{}, err error) {
+// scalarToIFace converts one Arrow scalar (a bound parameter value) into a Go
+// value suitable for database/sql. Null scalars become nil, dense-union
+// scalars are unwrapped to their child value, and date, time and timestamp
+// scalars become time.Time. Any other type yields an error naming the
+// scalar's data type.
+func scalarToIFace(s scalar.Scalar) (interface{}, error) {
+ if !s.IsValid() {
+ return nil, nil
+ }
+
+ switch val := s.(type) {
+ case *scalar.Boolean:
+ return val.Value, nil
+ case *scalar.Int8:
+ return val.Value, nil
+ case *scalar.Int16:
+ return val.Value, nil
+ case *scalar.Int32:
+ return val.Value, nil
+ case *scalar.Int64:
+ return val.Value, nil
+ case *scalar.Uint8:
+ return val.Value, nil
+ case *scalar.Uint16:
+ return val.Value, nil
+ case *scalar.Uint32:
+ return val.Value, nil
+ case *scalar.Uint64:
+ // database/sql rejects uint64 values with the high bit set at bind time
+ return val.Value, nil
+ case *scalar.Float32:
+ return val.Value, nil
+ case *scalar.Float64:
+ return val.Value, nil
+ case *scalar.String:
+ return string(val.Value.Bytes()), nil
+ case *scalar.Binary:
+ return val.Value.Bytes(), nil
+ case *scalar.Timestamp:
+ // listed before the interface cases so timestamps are always accepted
+ return val.ToTime(), nil
+ case scalar.DateScalar:
+ return val.ToTime(), nil
+ case scalar.TimeScalar:
+ return val.ToTime(), nil
+ case *scalar.DenseUnion:
+ return scalarToIFace(val.Value)
+ default:
+ // report the data type, not the value, so the message says what is
+ // unsupported
+ return nil, fmt.Errorf("unsupported type: %s", s.DataType())
+ }
+}
+
+func getParamsForStatement(rdr flight.MessageReader) (params [][]interface{}, err error) {
+ params = make([][]interface{}, 0)
for rdr.Next() {
rec := rdr.Record()
nrows := int(rec.NumRows())
ncols := int(rec.NumCols())
- if len(params) < int(ncols) {
- params = make([]interface{}, ncols)
- }
-
for i := 0; i < nrows; i++ {
+ invokeParams := make([]interface{}, ncols)
for c := 0; c < ncols; c++ {
col := rec.Column(c)
sc, err := scalar.GetScalar(col, i)
@@ -425,21 +498,12 @@ func getParamsForStatement(rdr flight.MessageReader) (params []interface{}, err
r.Release()
}
- switch v := sc.(*scalar.DenseUnion).Value.(type) {
- case *scalar.Int64:
- params[c] = v.Value
- case *scalar.Float32:
- params[c] = v.Value
- case *scalar.Float64:
- params[c] = v.Value
- case *scalar.String:
- params[c] = string(v.Value.Bytes())
- case *scalar.Binary:
- params[c] = v.Value.Bytes()
- default:
- return nil, fmt.Errorf("unsupported type: %s", v)
+ invokeParams[c], err = scalarToIFace(sc)
+ if err != nil {
+ return nil, err
}
}
+ params = append(params, invokeParams)
}
}
@@ -475,13 +539,30 @@ func (s *SQLiteFlightSQLServer) DoPutPreparedStatementUpdate(ctx context.Context
return 0, status.Errorf(codes.Internal, "error gathering parameters for prepared statement: %s", err.Error())
}
- stmt.params = args
- result, err := stmt.stmt.ExecContext(ctx, args...)
- if err != nil {
- return 0, err
+ if len(args) == 0 {
+ result, err := stmt.stmt.ExecContext(ctx)
+ if err != nil {
+ return 0, err
+ }
+
+ return result.RowsAffected()
+ }
+
+ var totalAffected int64
+ for _, p := range args {
+ result, err := stmt.stmt.ExecContext(ctx, p...)
+ if err != nil {
+ return totalAffected, err
+ }
+
+ n, err := result.RowsAffected()
+ if err != nil {
+ return totalAffected, err
+ }
+ totalAffected += n
}
- return result.RowsAffected()
+ return totalAffected, nil
}
func (s *SQLiteFlightSQLServer) GetFlightInfoPrimaryKeys(_ context.Context, cmd flightsql.TableRef, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
diff --git a/go/arrow/flight/record_batch_reader.go b/go/arrow/flight/record_batch_reader.go
index 277f1eb4cc..477c05df34 100644
--- a/go/arrow/flight/record_batch_reader.go
+++ b/go/arrow/flight/record_batch_reader.go
@@ -234,3 +234,30 @@ func StreamChunksFromReader(rdr array.RecordReader, ch chan<- StreamChunk) {
}
}
}
+
+// ConcatenateReaders sends the records of each reader in rdrs, in order, on
+// ch and closes ch when done. It takes ownership of the readers and releases
+// all of them before returning. Each record is retained before it is sent, so
+// the receiver owns that reference and must release it.
+//
+// If a reader exposes an Err() method reporting a non-nil error, that error
+// is sent as the final chunk and the remaining readers are skipped; a panic
+// while reading is recovered and likewise reported as the final chunk. Sends
+// block until received, so the caller must drain ch until it is closed.
+func ConcatenateReaders(rdrs []array.RecordReader, ch chan<- StreamChunk) {
+ defer close(ch)
+ defer func() {
+ for _, r := range rdrs {
+ r.Release()
+ }
+
+ if err := recover(); err != nil {
+ ch <- StreamChunk{Err: fmt.Errorf("panic while reading: %s", err)}
+ }
+ }()
+
+ for _, r := range rdrs {
+ for r.Next() {
+ rec := r.Record()
+ rec.Retain()
+ ch <- StreamChunk{Data: rec}
+ }
+ if e, ok := r.(haserr); ok && e.Err() != nil {
+ ch <- StreamChunk{Err: e.Err()}
+ return
+ }
+ }
+}
diff --git a/go/arrow/scalar/temporal.go b/go/arrow/scalar/temporal.go
index a19cd49e3c..f76200ee32 100644
--- a/go/arrow/scalar/temporal.go
+++ b/go/arrow/scalar/temporal.go
@@ -85,12 +85,14 @@ func NewDurationScalar(val arrow.Duration, typ arrow.DataType) *Duration {
+// DateScalar is implemented by the date scalar types (e.g. Date32, Date64).
+// ToTime converts the scalar's value to a time.Time.
type DateScalar interface {
TemporalScalar
+ ToTime() time.Time
date()
}
+// TimeScalar is implemented by the time-of-day scalar types (e.g. Time32,
+// Time64). Unit reports the scalar's time unit; ToTime converts its value to
+// a time.Time.
type TimeScalar interface {
TemporalScalar
Unit() arrow.TimeUnit
+ ToTime() time.Time
time()
}
@@ -197,6 +199,9 @@ func (s *Date32) String() string {
}
return string(val.(*String).Value.Bytes())
}
+// ToTime converts the scalar's arrow.Date32 value to a time.Time.
+func (s *Date32) ToTime() time.Time { return s.Value.ToTime() }
func NewDate32Scalar(val arrow.Date32) *Date32 {
return &Date32{scalar{arrow.FixedWidthTypes.Date32, true}, val}
@@ -227,6 +232,9 @@ func (s *Date64) String() string {
}
return string(val.(*String).Value.Bytes())
}
+// ToTime converts the scalar's arrow.Date64 value to a time.Time.
+func (s *Date64) ToTime() time.Time { return s.Value.ToTime() }
func NewDate64Scalar(val arrow.Date64) *Date64 {
return &Date64{scalar{arrow.FixedWidthTypes.Date64, true}, val}
@@ -262,6 +270,10 @@ func (s *Time32) Data() []byte {
return (*[arrow.Time32SizeBytes]byte)(unsafe.Pointer(&s.Value))[:]
}
+// ToTime converts the scalar's arrow.Time32 value to a time.Time,
+// interpreting it in the scalar's own time unit (s.Unit()).
+func (s *Time32) ToTime() time.Time {
+ return s.Value.ToTime(s.Unit())
+}
+
+// NewTime32Scalar wraps val in a valid (non-null) Time32 scalar of type typ.
func NewTime32Scalar(val arrow.Time32, typ arrow.DataType) *Time32 {
return &Time32{scalar{typ, true}, val}
}
@@ -295,6 +307,10 @@ func (s *Time64) String() string {
return string(val.(*String).Value.Bytes())
}
+// ToTime converts the scalar's arrow.Time64 value to a time.Time,
+// interpreting it in the scalar's own time unit (s.Unit()).
+func (s *Time64) ToTime() time.Time {
+ return s.Value.ToTime(s.Unit())
+}
+
+// NewTime64Scalar wraps val in a valid (non-null) Time64 scalar of type typ.
func NewTime64Scalar(val arrow.Time64, typ arrow.DataType) *Time64 {
return &Time64{scalar{typ, true}, val}
}
@@ -328,6 +344,10 @@ func (s *Timestamp) String() string {
return string(val.(*String).Value.Bytes())
}
+// ToTime converts the scalar's arrow.Timestamp value to a time.Time,
+// interpreting it in the scalar's own time unit (s.Unit()).
+func (s *Timestamp) ToTime() time.Time {
+ return s.Value.ToTime(s.Unit())
+}
+
+// NewTimestampScalar wraps val in a valid (non-null) Timestamp scalar of type typ.
func NewTimestampScalar(val arrow.Timestamp, typ arrow.DataType) *Timestamp {
return &Timestamp{scalar{typ, true}, val}
}