You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by fr...@apache.org on 2019/01/07 03:34:35 UTC

[calcite-avatica-go] branch master updated: [CALCITE-2763] Fix handling of nils (nulls) when executing queries and scanning query results

This is an automated email from the ASF dual-hosted git repository.

francischuang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite-avatica-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 8b69d6a  [CALCITE-2763] Fix handling of nils (nulls) when executing queries and scanning query results
8b69d6a is described below

commit 8b69d6a9b4809cfef41e10c625af96ac75165696
Author: Francis Chuang <fr...@apache.org>
AuthorDate: Mon Jan 7 14:25:13 2019 +1100

    [CALCITE-2763] Fix handling of nils (nulls) when executing queries and scanning query results
---
 driver_hsqldb_test.go  | 264 +++++++++++++++++++++++++++++++++++++
 driver_phoenix_test.go | 344 ++++++++++++++++++++++++++++++++++++++++++++++++-
 rows.go                |   6 +
 statement.go           |   1 +
 4 files changed, 612 insertions(+), 3 deletions(-)

diff --git a/driver_hsqldb_test.go b/driver_hsqldb_test.go
index 539a80f..39143ed 100644
--- a/driver_hsqldb_test.go
+++ b/driver_hsqldb_test.go
@@ -286,6 +286,270 @@ func TestHSQLDBDataTypes(t *testing.T) {
 	})
 }
 
+// TestHSQLDBSQLNullTypes inserts a row whose nullable columns are bound to
+// zero-value sql.Null* values or nil pointers, then checks that scanning the
+// row back yields those same values.
+func TestHSQLDBSQLNullTypes(t *testing.T) {
+
+	skipTestIfNotHSQLDB(t)
+
+	runTests(t, dsn, func(dbt *DBTest) {
+
+		// Create and seed table
+		dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
+				id INTEGER PRIMARY KEY,
+				int INTEGER,
+				tint TINYINT,
+				sint SMALLINT,
+				bint BIGINT,
+				num NUMERIC(10,3),
+				dec DECIMAL(10,3),
+				re REAL,
+				flt FLOAT,
+				dbl DOUBLE,
+				bool BOOLEAN,
+				ch CHAR(3),
+				var VARCHAR(128),
+				bin BINARY(20),
+				varbin VARBINARY(128),
+				dt DATE,
+				tmstmp TIMESTAMP,
+			    )`)
+
+		var (
+			idValue                 = time.Now().Unix()
+			integerValue            = sql.NullInt64{}
+			tintValue               = sql.NullInt64{}
+			sintValue               = sql.NullInt64{}
+			bintValue               = sql.NullInt64{}
+			numValue                = sql.NullString{}
+			decValue                = sql.NullString{}
+			reValue                 = sql.NullFloat64{}
+			fltValue                = sql.NullFloat64{}
+			dblValue                = sql.NullFloat64{}
+			booleanValue            = sql.NullBool{}
+			chValue                 = sql.NullString{}
+			varcharValue            = sql.NullString{}
+			binValue     *[]byte    = nil
+			varbinValue  *[]byte    = nil
+			dtValue      *time.Time = nil
+			tmstmpValue  *time.Time = nil
+		)
+
+		dbt.mustExec(`INSERT INTO `+dbt.tableName+` (id, int, tint, sint, bint, num, dec, re, flt, dbl, bool, ch, var, bin, varbin, dt, tmstmp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+			idValue,
+			integerValue,
+			tintValue,
+			sintValue,
+			bintValue,
+			numValue,
+			decValue,
+			reValue,
+			fltValue,
+			dblValue,
+			booleanValue,
+			chValue,
+			varcharValue,
+			binValue,
+			varbinValue,
+			dtValue,
+			tmstmpValue,
+		)
+
+		rows := dbt.mustQuery("SELECT * FROM "+dbt.tableName+" WHERE id = ?", idValue)
+		defer rows.Close()
+
+		var (
+			id      int64
+			integer sql.NullInt64
+			tint    sql.NullInt64
+			sint    sql.NullInt64
+			bint    sql.NullInt64
+			num     sql.NullString
+			dec     sql.NullString
+			re      sql.NullFloat64
+			flt     sql.NullFloat64
+			dbl     sql.NullFloat64
+			boolean sql.NullBool
+			ch      sql.NullString
+			varchar sql.NullString
+			bin     *[]byte
+			varbin  *[]byte
+			dt      *time.Time
+			tmstmp  *time.Time
+		)
+
+		// Count rows so that an empty result set cannot pass the checks below.
+		scanned := 0
+		for rows.Next() {
+			if err := rows.Scan(&id, &integer, &tint, &sint, &bint, &num, &dec, &re, &flt, &dbl, &boolean, &ch, &varchar, &bin, &varbin, &dt, &tmstmp); err != nil {
+				dbt.Fatal(err)
+			}
+			scanned++
+		}
+
+		if err := rows.Err(); err != nil {
+			dbt.Fatal(err)
+		}
+
+		if scanned == 0 {
+			dbt.Fatal("Expected the inserted row to be returned, got no rows.")
+		}
+
+		comparisons := []struct {
+			result   interface{}
+			expected interface{}
+		}{
+			{integer, integerValue},
+			{tint, tintValue},
+			{sint, sintValue},
+			{bint, bintValue},
+			{num, numValue},
+			{dec, decValue},
+			{re, reValue},
+			{flt, fltValue},
+			{dbl, dblValue},
+			{boolean, booleanValue},
+			{ch, chValue},
+			{varchar, varcharValue},
+			{bin, binValue},
+			{varbin, varbinValue},
+			{dt, dtValue},
+			{tmstmp, tmstmpValue},
+		}
+
+		// All expected values are comparable, so interface equality suffices.
+		for i, tt := range comparisons {
+			if tt.expected != tt.result {
+				dbt.Errorf("Expected %v for case %d, got %v.", tt.expected, i, tt.result)
+			}
+		}
+	})
+}
+
+// TestHSQLDBNulls inserts a row whose non-key parameters are all untyped nil,
+// then checks that every nullable column scans back as an invalid sql.Null*
+// value or a nil pointer.
+func TestHSQLDBNulls(t *testing.T) {
+
+	skipTestIfNotHSQLDB(t)
+
+	runTests(t, dsn, func(dbt *DBTest) {
+
+		// Create and seed table
+		dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
+				id INTEGER PRIMARY KEY,
+				int INTEGER,
+				tint TINYINT,
+				sint SMALLINT,
+				bint BIGINT,
+				num NUMERIC(10,3),
+				dec DECIMAL(10,3),
+				re REAL,
+				flt FLOAT,
+				dbl DOUBLE,
+				bool BOOLEAN,
+				ch CHAR(3),
+				var VARCHAR(128),
+				bin BINARY(20),
+				varbin VARBINARY(128),
+				dt DATE,
+				tmstmp TIMESTAMP,
+			    )`)
+
+		idValue := time.Now().Unix()
+
+		dbt.mustExec(`INSERT INTO `+dbt.tableName+` (id, int, tint, sint, bint, num, dec, re, flt, dbl, bool, ch, var, bin, varbin, dt, tmstmp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+			idValue,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+		)
+
+		rows := dbt.mustQuery("SELECT * FROM "+dbt.tableName+" WHERE id = ?", idValue)
+		defer rows.Close()
+
+		var (
+			id      int64
+			integer sql.NullInt64
+			tint    sql.NullInt64
+			sint    sql.NullInt64
+			bint    sql.NullInt64
+			num     sql.NullString
+			dec     sql.NullString
+			re      sql.NullFloat64
+			flt     sql.NullFloat64
+			dbl     sql.NullFloat64
+			boolean sql.NullBool
+			ch      sql.NullString
+			varchar sql.NullString
+			bin     *[]byte
+			varbin  *[]byte
+			dt      *time.Time
+			tmstmp  *time.Time
+		)
+
+		// Count rows so that an empty result set cannot pass the checks below.
+		scanned := 0
+		for rows.Next() {
+			if err := rows.Scan(&id, &integer, &tint, &sint, &bint, &num, &dec, &re, &flt, &dbl, &boolean, &ch, &varchar, &bin, &varbin, &dt, &tmstmp); err != nil {
+				dbt.Fatal(err)
+			}
+			scanned++
+		}
+
+		if err := rows.Err(); err != nil {
+			dbt.Fatal(err)
+		}
+
+		if scanned == 0 {
+			dbt.Fatal("Expected the inserted row to be returned, got no rows.")
+		}
+
+		comparisons := []struct {
+			result   interface{}
+			expected interface{}
+		}{
+			{integer, sql.NullInt64{}},
+			{tint, sql.NullInt64{}},
+			{sint, sql.NullInt64{}},
+			{bint, sql.NullInt64{}},
+			{num, sql.NullString{}},
+			{dec, sql.NullString{}},
+			{re, sql.NullFloat64{}},
+			{flt, sql.NullFloat64{}},
+			{dbl, sql.NullFloat64{}},
+			{boolean, sql.NullBool{}},
+			{ch, sql.NullString{}},
+			{varchar, sql.NullString{}},
+			{bin, (*[]byte)(nil)},
+			{varbin, (*[]byte)(nil)},
+			{dt, (*time.Time)(nil)},
+			{tmstmp, (*time.Time)(nil)},
+		}
+
+		// All expected values are comparable, so interface equality suffices.
+		for i, tt := range comparisons {
+			if tt.expected != tt.result {
+				dbt.Errorf("Expected %v for case %d, got %v.", tt.expected, i, tt.result)
+			}
+		}
+	})
+}
+
 // TODO: Test case commented out due to CALCITE-1951
 /*func TestHSQLDBLocations(t *testing.T) {
 
diff --git a/driver_phoenix_test.go b/driver_phoenix_test.go
index f004413..90892e2 100644
--- a/driver_phoenix_test.go
+++ b/driver_phoenix_test.go
@@ -93,7 +93,7 @@ func TestPhoenixZeroValues(t *testing.T) {
 			var i int
 			var flt float64
 			var b bool
-			var s string
+			var s sql.NullString
 
 			err := rows.Scan(&i, &flt, &b, &s)
 
@@ -113,8 +113,8 @@ func TestPhoenixZeroValues(t *testing.T) {
 				dbt.Fatalf("Boolean should be false, got %v", b)
 			}
 
-			if s != "" {
-				dbt.Fatalf("String should be \"\", got %v", s)
+			if val, _ := s.Value(); val != nil {
+				dbt.Fatalf("String should be nil, got %v", s)
 			}
 		}
 
@@ -301,6 +301,344 @@ func TestPhoenixDataTypes(t *testing.T) {
 	})
 }
 
+// TestPhoenixSQLNullTypes upserts a row of zero-value sql.Null* values and nil
+// pointers, then checks the values scanned back from Phoenix.
+func TestPhoenixSQLNullTypes(t *testing.T) {
+
+	skipTestIfNotPhoenix(t)
+
+	runTests(t, dsn, func(dbt *DBTest) {
+
+		// Create and seed table
+		dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
+				id INTEGER PRIMARY KEY,
+				int INTEGER,
+				uint UNSIGNED_INT,
+				bint BIGINT,
+				ulong UNSIGNED_LONG,
+				tint TINYINT,
+				utint UNSIGNED_TINYINT,
+				sint SMALLINT,
+				usint UNSIGNED_SMALLINT,
+				flt FLOAT,
+				uflt UNSIGNED_FLOAT,
+				dbl DOUBLE,
+				udbl UNSIGNED_DOUBLE,
+				dec DECIMAL,
+				bool BOOLEAN,
+				tm TIME,
+				dt DATE,
+				tmstmp TIMESTAMP,
+				utm UNSIGNED_TIME,
+				udt UNSIGNED_DATE,
+				utmstmp UNSIGNED_TIMESTAMP,
+				var VARCHAR,
+				ch CHAR(3),
+				bin BINARY(20),
+				varbin VARBINARY
+			    ) TRANSACTIONAL=false`)
+
+		var (
+			idValue                  = time.Now().Unix()
+			integerValue             = sql.NullInt64{}
+			uintegerValue            = sql.NullInt64{}
+			bintValue                = sql.NullInt64{}
+			ulongValue               = sql.NullInt64{}
+			tintValue                = sql.NullInt64{}
+			utintValue               = sql.NullInt64{}
+			sintValue                = sql.NullInt64{}
+			usintValue               = sql.NullInt64{}
+			fltValue                 = sql.NullFloat64{}
+			ufltValue                = sql.NullFloat64{}
+			dblValue                 = sql.NullFloat64{}
+			udblValue                = sql.NullFloat64{}
+			decValue                 = sql.NullString{}
+			booleanValue             = sql.NullBool{}
+			tmValue       *time.Time = nil
+			dtValue       *time.Time = nil
+			tmstmpValue   *time.Time = nil
+			utmValue      *time.Time = nil
+			udtValue      *time.Time = nil
+			utmstmpValue  *time.Time = nil
+			varcharValue             = sql.NullString{}
+			chValue                  = sql.NullString{}
+			binValue      *[]byte    = nil
+			varbinValue   *[]byte    = nil
+		)
+
+		dbt.mustExec(`UPSERT INTO `+dbt.tableName+` VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+			idValue,
+			integerValue,
+			uintegerValue,
+			bintValue,
+			ulongValue,
+			tintValue,
+			utintValue,
+			sintValue,
+			usintValue,
+			fltValue,
+			ufltValue,
+			dblValue,
+			udblValue,
+			decValue,
+			booleanValue,
+			tmValue,
+			dtValue,
+			tmstmpValue,
+			utmValue,
+			udtValue,
+			utmstmpValue,
+			varcharValue,
+			chValue,
+			binValue,
+			varbinValue,
+		)
+
+		rows := dbt.mustQuery("SELECT * FROM "+dbt.tableName+" WHERE id = ?", idValue)
+		defer rows.Close()
+
+		var (
+			id       int64
+			integer  sql.NullInt64
+			uinteger sql.NullInt64
+			bint     sql.NullInt64
+			ulong    sql.NullInt64
+			tint     sql.NullInt64
+			utint    sql.NullInt64
+			sint     sql.NullInt64
+			usint    sql.NullInt64
+			flt      sql.NullFloat64
+			uflt     sql.NullFloat64
+			dbl      sql.NullFloat64
+			udbl     sql.NullFloat64
+			dec      sql.NullString
+			boolean  sql.NullBool
+			tm       *time.Time
+			dt       *time.Time
+			tmstmp   *time.Time
+			utm      *time.Time
+			udt      *time.Time
+			utmstmp  *time.Time
+			varchar  sql.NullString
+			ch       sql.NullString
+			bin      *[]byte
+			varbin   *[]byte
+		)
+
+		for rows.Next() {
+			if err := rows.Scan(&id, &integer, &uinteger, &bint, &ulong, &tint, &utint, &sint, &usint, &flt, &uflt, &dbl, &udbl, &dec, &boolean, &tm, &dt, &tmstmp, &utm, &udt, &utmstmp, &varchar, &ch, &bin, &varbin); err != nil {
+				dbt.Fatal(err)
+			}
+		}
+
+		if err := rows.Err(); err != nil {
+			dbt.Fatal(err)
+		}
+
+		// bin is dereferenced below; it stays nil when no row was scanned, so
+		// fail with a message rather than panic.
+		if bin == nil {
+			dbt.Fatal("Expected bin to be scanned, got nil (no row returned or NULL value).")
+		}
+
+		comparisons := []struct {
+			result   interface{}
+			expected interface{}
+		}{
+			{integer, integerValue},
+			{uinteger, uintegerValue},
+			{bint, bintValue},
+			{ulong, ulongValue},
+			{tint, tintValue},
+			{utint, utintValue},
+			{sint, sintValue},
+			{usint, usintValue},
+			{flt, fltValue},
+			{uflt, ufltValue},
+			{dbl, dblValue},
+			{udbl, udblValue},
+			{dec, decValue},
+			{boolean, booleanValue},
+			{tm, tmValue},
+			{dt, dtValue},
+			{tmstmp, tmstmpValue},
+			{utm, utmValue},
+			{udt, udtValue},
+			{utmstmp, utmstmpValue},
+			{varchar, varcharValue},
+			{ch, chValue},
+			{*bin, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
+			{varbin, varbinValue},
+		}
+
+		for i, tt := range comparisons {
+			if want, ok := tt.expected.([]byte); ok {
+				if got, gotOK := tt.result.([]byte); !gotOK || !bytes.Equal(want, got) {
+					dbt.Fatalf("Expected %v for case %d, got %v.", tt.expected, i, tt.result)
+				}
+			} else if tt.expected != tt.result {
+				dbt.Errorf("Expected %v for case %d, got %v.", tt.expected, i, tt.result)
+			}
+		}
+	})
+}
+
+// TestPhoenixNulls upserts a row whose non-key parameters are all untyped
+// nil, then checks the values scanned back from Phoenix.
+func TestPhoenixNulls(t *testing.T) {
+
+	skipTestIfNotPhoenix(t)
+
+	runTests(t, dsn, func(dbt *DBTest) {
+
+		// Create and seed table
+		dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
+				id INTEGER PRIMARY KEY,
+				int INTEGER,
+				uint UNSIGNED_INT,
+				bint BIGINT,
+				ulong UNSIGNED_LONG,
+				tint TINYINT,
+				utint UNSIGNED_TINYINT,
+				sint SMALLINT,
+				usint UNSIGNED_SMALLINT,
+				flt FLOAT,
+				uflt UNSIGNED_FLOAT,
+				dbl DOUBLE,
+				udbl UNSIGNED_DOUBLE,
+				dec DECIMAL,
+				bool BOOLEAN,
+				tm TIME,
+				dt DATE,
+				tmstmp TIMESTAMP,
+				utm UNSIGNED_TIME,
+				udt UNSIGNED_DATE,
+				utmstmp UNSIGNED_TIMESTAMP,
+				var VARCHAR,
+				ch CHAR(3),
+				bin BINARY(20),
+				varbin VARBINARY
+			    ) TRANSACTIONAL=false`)
+
+		idValue := time.Now().Unix()
+
+		dbt.mustExec(`UPSERT INTO `+dbt.tableName+` VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+			idValue,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+			nil,
+		)
+
+		rows := dbt.mustQuery("SELECT * FROM "+dbt.tableName+" WHERE id = ?", idValue)
+		defer rows.Close()
+
+		var (
+			id       int64
+			integer  sql.NullInt64
+			uinteger sql.NullInt64
+			bint     sql.NullInt64
+			ulong    sql.NullInt64
+			tint     sql.NullInt64
+			utint    sql.NullInt64
+			sint     sql.NullInt64
+			usint    sql.NullInt64
+			flt      sql.NullFloat64
+			uflt     sql.NullFloat64
+			dbl      sql.NullFloat64
+			udbl     sql.NullFloat64
+			dec      sql.NullString
+			boolean  sql.NullBool
+			tm       *time.Time
+			dt       *time.Time
+			tmstmp   *time.Time
+			utm      *time.Time
+			udt      *time.Time
+			utmstmp  *time.Time
+			varchar  sql.NullString
+			ch       sql.NullString
+			bin      *[]byte
+			varbin   *[]byte
+		)
+
+		for rows.Next() {
+			if err := rows.Scan(&id, &integer, &uinteger, &bint, &ulong, &tint, &utint, &sint, &usint, &flt, &uflt, &dbl, &udbl, &dec, &boolean, &tm, &dt, &tmstmp, &utm, &udt, &utmstmp, &varchar, &ch, &bin, &varbin); err != nil {
+				dbt.Fatal(err)
+			}
+		}
+
+		if err := rows.Err(); err != nil {
+			dbt.Fatal(err)
+		}
+
+		// bin is dereferenced below; it stays nil when no row was scanned, so
+		// fail with a message rather than panic.
+		if bin == nil {
+			dbt.Fatal("Expected bin to be scanned, got nil (no row returned or NULL value).")
+		}
+
+		comparisons := []struct {
+			result   interface{}
+			expected interface{}
+		}{
+			{integer, sql.NullInt64{}},
+			{uinteger, sql.NullInt64{}},
+			{bint, sql.NullInt64{}},
+			{ulong, sql.NullInt64{}},
+			{tint, sql.NullInt64{}},
+			{utint, sql.NullInt64{}},
+			{sint, sql.NullInt64{}},
+			{usint, sql.NullInt64{}},
+			{flt, sql.NullFloat64{}},
+			{uflt, sql.NullFloat64{}},
+			{dbl, sql.NullFloat64{}},
+			{udbl, sql.NullFloat64{}},
+			{dec, sql.NullString{}},
+			{boolean, sql.NullBool{}},
+			{tm, (*time.Time)(nil)},
+			{dt, (*time.Time)(nil)},
+			{tmstmp, (*time.Time)(nil)},
+			{utm, (*time.Time)(nil)},
+			{udt, (*time.Time)(nil)},
+			{utmstmp, (*time.Time)(nil)},
+			{varchar, sql.NullString{}},
+			{ch, sql.NullString{}},
+			{*bin, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
+			{varbin, (*[]byte)(nil)},
+		}
+
+		for i, tt := range comparisons {
+			if want, ok := tt.expected.([]byte); ok {
+				if got, gotOK := tt.result.([]byte); !gotOK || !bytes.Equal(want, got) {
+					dbt.Fatalf("Expected %v for case %d, got %v.", tt.expected, i, tt.result)
+				}
+			} else if tt.expected != tt.result {
+				dbt.Errorf("Expected %v for case %d, got %v.", tt.expected, i, tt.result)
+			}
+		}
+	})
+}
+
 func TestPhoenixLocations(t *testing.T) {
 
 	skipTestIfNotPhoenix(t)
diff --git a/rows.go b/rows.go
index fa3eddb..8fe7f1e 100644
--- a/rows.go
+++ b/rows.go
@@ -180,7 +180,13 @@ func newRows(conn *conn, statementID uint32, resultSets []*message.ResultSetResp
 
 // typedValueToNative converts values from avatica's types to Go's native types
 func typedValueToNative(rep message.Rep, v *message.TypedValue, config *Config) interface{} {
+
+	if v.Type == message.Rep_NULL {
+		return nil
+	}
+
 	switch rep {
+
 	case message.Rep_BOOLEAN, message.Rep_PRIMITIVE_BOOLEAN:
 		return v.BoolValue
 
diff --git a/statement.go b/statement.go
index 3e3cab2..af2e2ce 100644
--- a/statement.go
+++ b/statement.go
@@ -156,6 +156,7 @@ func (s *stmt) parametersToTypedValues(vals []namedValue) []*message.TypedValue
 		typed := message.TypedValue{}
 		if val.Value == nil {
 			typed.Null = true
+			typed.Type = message.Rep_NULL
 		} else {
 
 			switch v := val.Value.(type) {