You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by "srebhan (via GitHub)" <gi...@apache.org> on 2023/04/11 08:34:07 UTC

[GitHub] [arrow] srebhan commented on a diff in pull request #34331: GH-34332: [Go][FlightRPC] Add driver for `database/sql` framework

srebhan commented on code in PR #34331:
URL: https://github.com/apache/arrow/pull/34331#discussion_r1162478621


##########
go/arrow/flight/flightsql/driver.go:
##########
@@ -0,0 +1,772 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package flightsql
+
+import (
+	"context"
+	"database/sql"
+	"database/sql/driver"
+	"errors"
+	"fmt"
+	"io"
+	"net/url"
+	"sort"
+	"strings"
+	"time"
+
+	"github.com/apache/arrow/go/v12/arrow"
+	"github.com/apache/arrow/go/v12/arrow/array"
+	"github.com/apache/arrow/go/v12/arrow/memory"
+
+	"google.golang.org/grpc"
+)
+
+// Sentinel errors returned by the driver. Some call sites wrap them with
+// additional context, so compare using errors.Is.
+var (
+	// ErrNotSupported is returned when a requested value or feature is not
+	// available, e.g. by Result.LastInsertId and Result.RowsAffected.
+	ErrNotSupported          = errors.New("not supported")
+	// ErrOutOfRange is returned by Rows.Next when the current row index lies
+	// beyond the end of the current record.
+	ErrOutOfRange            = errors.New("index out of range")
+	// ErrTransactionInProgress signals that a transaction is still in progress.
+	ErrTransactionInProgress = errors.New("transaction still in progress")
+)
+
+// Rows iterates row by row over the Arrow records of a query result. It is
+// handed out as a driver.Rows by Stmt.QueryContext.
+type Rows struct {
+	// schema is the schema shared by all records.
+	schema        *arrow.Schema
+	// records holds the record batches making up the result.
+	records       []arrow.Record
+	// currentRecord is the index into records of the batch being iterated.
+	currentRecord int
+	// currentRow is the index of the next row to return within that batch.
+	currentRow    int
+}
+
+// Columns returns the names of the columns.
+func (r *Rows) Columns() []string {
+	if len(r.records) == 0 {
+		return nil
+	}
+
+	// All records have the same columns
+	var cols []string
+	for _, c := range r.schema.Fields() {
+		cols = append(cols, c.Name)
+	}
+
+	return cols
+}
+
+// Close closes the rows iterator and releases all records held by it.
+//
+// Close may be called more than once; only the first call releases records.
+// It always returns nil.
+func (r *Rows) Close() error {
+	for _, rec := range r.records {
+		rec.Release()
+	}
+
+	// Forget the released records so that a repeated Close cannot release
+	// them a second time and a later Next reports io.EOF instead of reading
+	// released memory.
+	r.records = nil
+	r.currentRecord = 0
+	r.currentRow = 0
+
+	return nil
+}
+
+// Next is called to populate the next row of data into
+// the provided slice. The provided slice will be the same
+// size as the Columns() are wide.
+//
+// Next returns io.EOF when there are no more rows. Records that do not
+// contain any rows are skipped. An error of the value conversion is returned
+// as is and leaves the iterator position unchanged.
+//
+// The dest should not be written to outside of Next. Care
+// should be taken when closing Rows not to modify
+// a buffer held in dest.
+func (r *Rows) Next(dest []driver.Value) error {
+	// An empty record batch is valid and must not end the iteration with an
+	// error, so move on to the next batch that actually carries rows.
+	for r.currentRecord < len(r.records) && r.records[r.currentRecord].NumRows() == 0 {
+		r.currentRecord++
+		r.currentRow = 0
+	}
+
+	if r.currentRecord >= len(r.records) {
+		return io.EOF
+	}
+	record := r.records[r.currentRecord]
+
+	// Defensive check; the bookkeeping below keeps currentRow within range.
+	if int64(r.currentRow) >= record.NumRows() {
+		return ErrOutOfRange
+	}
+
+	for i, arr := range record.Columns() {
+		v, err := fromArrowType(arr, r.currentRow)
+		if err != nil {
+			return err
+		}
+		dest[i] = v
+	}
+
+	// Advance to the next row and step to the next record once the current
+	// one is exhausted.
+	r.currentRow++
+	if int64(r.currentRow) >= record.NumRows() {
+		r.currentRecord++
+		r.currentRow = 0
+	}
+
+	return nil
+}
+
+// Result carries the outcome of a statement that does not return rows. It is
+// handed out as a driver.Result by Stmt.ExecContext.
+type Result struct {
+	// affected is the number of affected rows; a negative value marks it as
+	// not available.
+	affected   int64
+	// lastinsert is the last auto-generated ID; a negative value marks it as
+	// not available.
+	lastinsert int64
+}
+
+// LastInsertId reports the ID auto-generated by the database, for example
+// after an INSERT into a table with a primary key. If no such ID is
+// available (the stored value is negative) it returns -1 and ErrNotSupported.
+func (r *Result) LastInsertId() (int64, error) {
+	if id := r.lastinsert; id >= 0 {
+		return id, nil
+	}
+	return -1, ErrNotSupported
+}
+
+// RowsAffected returns the number of rows affected by the query. If the
+// number is not available (the stored value is negative) it returns -1 and
+// ErrNotSupported.
+func (r *Result) RowsAffected() (int64, error) {
+	if r.affected < 0 {
+		return -1, ErrNotSupported
+	}
+	// A known count is a success; it must not be paired with an error.
+	return r.affected, nil
+}
+
+// Stmt is a prepared statement bound to a Flight SQL client. Its methods
+// cover Close, NumInput, Exec/ExecContext and Query/QueryContext.
+type Stmt struct {
+	// stmt is the underlying prepared statement that gets executed.
+	stmt   *PreparedStatement
+	// client is used to fetch the result data of a query.
+	client *Client
+
+	// timeout bounds operations whose context carries no deadline; values
+	// less than or equal to zero disable it.
+	timeout time.Duration
+}
+
+// Close closes the statement.
+func (s *Stmt) Close() error {
+	ctx := context.Background()
+	if s.timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, s.timeout)
+		defer cancel()
+	}
+
+	return s.stmt.Close(ctx)
+}
+
+// NumInput returns the number of placeholder parameters, taken from the
+// parameter schema of the prepared statement, or -1 if that schema is not
+// known.
+func (s *Stmt) NumInput() int {
+	if schema := s.stmt.ParameterSchema(); schema != nil {
+		// With a count >= 0 the sql package sanity checks the argument
+		// counts from callers and returns errors to the caller before the
+		// statement's Exec or Query methods are called.
+		return len(schema.Fields())
+	}
+
+	// -1 signals that the driver doesn't know its number of placeholders.
+	// In that case the sql package will not sanity check Exec or Query
+	// argument counts.
+	return -1
+}
+
+// Exec executes a query that doesn't return rows, such
+// as an INSERT or UPDATE.
+func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
+	var params []driver.NamedValue
+	for i, arg := range args {
+		params = append(params, driver.NamedValue{
+			Name:    fmt.Sprintf("arg_%d", i),
+			Ordinal: i,
+			Value:   arg,
+		})
+	}
+
+	return s.ExecContext(context.Background(), params)
+}
+
+// ExecContext executes a query that doesn't return rows, such as an INSERT or UPDATE.
+func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+	if err := s.setParameters(args); err != nil {
+		return nil, err
+	}
+
+	if _, set := ctx.Deadline(); !set && s.timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, s.timeout)
+		defer cancel()
+	}
+
+	n, err := s.stmt.ExecuteUpdate(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	// FIXME: For now we ignore the number of affected records as it seems like
+	// the returned value is always one.
+	_ = n
+
+	return &Result{affected: -1, lastinsert: -1}, nil
+}
+
+// Query executes a query that may return rows, such as a SELECT.
+func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
+	var params []driver.NamedValue
+	for i, arg := range args {
+		params = append(params, driver.NamedValue{
+			Name:    fmt.Sprintf("arg_%d", i),
+			Ordinal: i,
+			Value:   arg,
+		})
+	}
+
+	return s.QueryContext(context.Background(), params)
+}
+
+// QueryContext executes a query that may return rows, such as a SELECT.
+//
+// The arguments are bound as statement parameters, the statement is executed
+// and the data of every returned endpoint is read completely before the
+// rows are handed out. All endpoints have to deliver the same schema,
+// otherwise an error wrapping ErrNotSupported is returned. If ctx carries no
+// deadline and a statement timeout is configured, the timeout bounds the
+// whole operation.
+func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+	if err := s.setParameters(args); err != nil {
+		return nil, err
+	}
+
+	if _, set := ctx.Deadline(); !set && s.timeout > 0 {
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, s.timeout)
+		defer cancel()
+	}
+
+	info, err := s.stmt.Execute(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	rows := &Rows{}
+	for _, endpoint := range info.Endpoint {
+		reader, err := s.client.DoGet(ctx, endpoint.GetTicket())
+		if err != nil {
+			// Do not leak the records collected from previous endpoints.
+			_ = rows.Close()
+			return nil, fmt.Errorf("getting ticket failed: %w", err)
+		}
+
+		schema := reader.Schema()
+		if rows.schema == nil {
+			rows.schema = schema
+		}
+		if !rows.schema.Equal(schema) {
+			reader.Release()
+			_ = rows.Close()
+			return nil, fmt.Errorf("mixed schemas %w", ErrNotSupported)
+		}
+
+		// Drain the whole stream, an endpoint may deliver any number of
+		// records including none at all.
+		for reader.Next() {
+			record := reader.Record()
+			// The reader only keeps the record alive until the next call
+			// to Next, so take our own reference before storing it.
+			record.Retain()
+			rows.records = append(rows.records, record)
+		}
+
+		// Fetch the stream error before giving up the reader; reaching the
+		// end of the stream is the regular way to finish.
+		err = reader.Err()
+		reader.Release()
+		if err != nil && !errors.Is(err, io.EOF) {
+			_ = rows.Close()
+			return nil, fmt.Errorf("reading record failed: %w", err)
+		}
+	}
+
+	return rows, nil
+}
+
+func (s *Stmt) setParameters(args []driver.NamedValue) error {
+	if len(args) == 0 {
+		s.stmt.SetParameters(nil)
+		return nil
+	}
+
+	var fields []arrow.Field
+	sort.SliceStable(args, func(i, j int) bool {
+		return args[i].Ordinal < args[j].Ordinal
+	})
+
+	for _, arg := range args {
+		dt, err := toArrowDataType(arg.Value)
+		if err != nil {
+			return fmt.Errorf("schema: %w", err)
+		}
+		fields = append(fields, arrow.Field{
+			Name:     arg.Name,
+			Type:     dt,
+			Nullable: true,
+		})
+	}
+
+	schema := s.stmt.ParameterSchema()

Review Comment:
   Added unit-tests to show that you don't get a panic if you use the wrong type or wrong number of params. IMO the only way to panic is if the server does not check the parameters and panics there, but this is the server's fault, isn't it? Can you construct a concrete test-case @zeroshade that panics in the driver? 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org