You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/25 01:26:49 UTC
[spark] branch master updated: [SPARK-41238][CONNECT][PYTHON] Support more built-in datatypes
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4006d195111 [SPARK-41238][CONNECT][PYTHON] Support more built-in datatypes
4006d195111 is described below
commit 4006d195111334b4b795680e547dea9dd0acda22
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Fri Nov 25 09:26:25 2022 +0800
[SPARK-41238][CONNECT][PYTHON] Support more built-in datatypes
### What changes were proposed in this pull request?
1. On the server side, make the `proto_datatype` <-> `catalyst_datatype` conversion support all the built-in SQL datatypes;
2. On the client side, make the `proto_datatype` <-> `pyspark_catalyst_datatype` conversion support [all the datatypes that are currently supported in PySpark.](https://github.com/apache/spark/blob/master/python/pyspark/sql/types.py#L60-L83)
### Why are the changes needed?
Right now, only `long`, `string`, and `struct` are supported; for example, a `float` column fails with:
```
grpc._channel._InactiveRpcError: <_InactiveRpcError of RPC that terminated with:
status = StatusCode.UNKNOWN
details = "Does not support convert float to connect proto types."
debug_error_string = "{"created":"1669206685.760099000","description":"Error received from peer ipv6:[::1]:15002","file":"src/core/lib/surface/call.cc","file_line":1064,"grpc_message":"Does not support convert float to connect proto types.","grpc_status":2}"
```
This PR makes the schema and literal expressions support more datatypes.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit tests.
Closes #38770 from zhengruifeng/connect_support_more_datatypes.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../main/protobuf/spark/connect/expressions.proto | 2 +-
.../src/main/protobuf/spark/connect/types.proto | 123 +++--
.../org/apache/spark/sql/connect/dsl/package.scala | 5 +-
.../connect/planner/DataTypeProtoConverter.scala | 281 ++++++++++--
.../connect/planner/SparkConnectServiceSuite.scala | 4 +-
python/pyspark/sql/connect/client.py | 108 ++++-
.../pyspark/sql/connect/proto/expressions_pb2.py | 66 +--
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 16 +-
python/pyspark/sql/connect/proto/types_pb2.py | 261 ++++++-----
python/pyspark/sql/connect/proto/types_pb2.pyi | 507 +++++++++++++--------
.../sql/tests/connect/test_connect_basic.py | 67 +++
11 files changed, 972 insertions(+), 468 deletions(-)
diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
index ac5fe24d349..7ff06aeb196 100644
--- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
@@ -68,7 +68,7 @@ message Expression {
bytes uuid = 28;
DataType null = 29; // a typed null literal
List list = 30;
- DataType.List empty_list = 31;
+ DataType.Array empty_array = 31;
DataType.Map empty_map = 32;
UserDefined user_defined = 33;
}
diff --git a/connector/connect/src/main/protobuf/spark/connect/types.proto b/connector/connect/src/main/protobuf/spark/connect/types.proto
index ad043d85947..56dbf28665e 100644
--- a/connector/connect/src/main/protobuf/spark/connect/types.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/types.proto
@@ -26,31 +26,46 @@ option java_package = "org.apache.spark.connect.proto";
// itself but only describes it.
message DataType {
oneof kind {
- Boolean bool = 1;
- I8 i8 = 2;
- I16 i16 = 3;
- I32 i32 = 5;
- I64 i64 = 7;
- FP32 fp32 = 10;
- FP64 fp64 = 11;
- String string = 12;
- Binary binary = 13;
- Timestamp timestamp = 14;
- Date date = 16;
- Time time = 17;
- IntervalYear interval_year = 19;
- IntervalDay interval_day = 20;
- TimestampTZ timestamp_tz = 29;
- UUID uuid = 32;
-
- FixedChar fixed_char = 21;
- VarChar varchar = 22;
- FixedBinary fixed_binary = 23;
- Decimal decimal = 24;
-
- Struct struct = 25;
- List list = 27;
- Map map = 28;
+ NULL null = 1;
+
+ Binary binary = 2;
+
+ Boolean boolean = 3;
+
+ // Numeric types
+ Byte byte = 4;
+ Short short = 5;
+ Integer integer = 6;
+ Long long = 7;
+
+ Float float = 8;
+ Double double = 9;
+ Decimal decimal = 10;
+
+ // String types
+ String string = 11;
+ Char char = 12;
+ VarChar var_char = 13;
+
+ // Datatime types
+ Date date = 14;
+ Timestamp timestamp = 15;
+ TimestampNTZ timestamp_ntz = 16;
+
+ // Interval types
+ CalendarInterval calendar_interval = 17;
+ YearMonthInterval year_month_interval = 18;
+ DayTimeInterval day_time_interval = 19;
+
+ // Complex types
+ Array array = 20;
+ Struct struct = 21;
+ Map map = 22;
+
+
+ UUID uuid = 25;
+
+ FixedBinary fixed_binary = 26;
uint32 user_defined_type_reference = 31;
}
@@ -59,27 +74,27 @@ message DataType {
uint32 type_variation_reference = 1;
}
- message I8 {
+ message Byte {
uint32 type_variation_reference = 1;
}
- message I16 {
+ message Short {
uint32 type_variation_reference = 1;
}
- message I32 {
+ message Integer {
uint32 type_variation_reference = 1;
}
- message I64 {
+ message Long {
uint32 type_variation_reference = 1;
}
- message FP32 {
+ message Float {
uint32 type_variation_reference = 1;
}
- message FP64 {
+ message Double {
uint32 type_variation_reference = 1;
}
@@ -91,6 +106,10 @@ message DataType {
uint32 type_variation_reference = 1;
}
+ message NULL {
+ uint32 type_variation_reference = 1;
+ }
+
message Timestamp {
uint32 type_variation_reference = 1;
}
@@ -99,20 +118,24 @@ message DataType {
uint32 type_variation_reference = 1;
}
- message Time {
+ message TimestampNTZ {
uint32 type_variation_reference = 1;
}
- message TimestampTZ {
+ message CalendarInterval {
uint32 type_variation_reference = 1;
}
- message IntervalYear {
- uint32 type_variation_reference = 1;
+ message YearMonthInterval {
+ optional int32 start_field = 1;
+ optional int32 end_field = 2;
+ uint32 type_variation_reference = 3;
}
- message IntervalDay {
- uint32 type_variation_reference = 1;
+ message DayTimeInterval {
+ optional int32 start_field = 1;
+ optional int32 end_field = 2;
+ uint32 type_variation_reference = 3;
}
message UUID {
@@ -120,7 +143,7 @@ message DataType {
}
// Start compound types.
- message FixedChar {
+ message Char {
int32 length = 1;
uint32 type_variation_reference = 2;
}
@@ -136,14 +159,14 @@ message DataType {
}
message Decimal {
- int32 scale = 1;
- int32 precision = 2;
+ optional int32 scale = 1;
+ optional int32 precision = 2;
uint32 type_variation_reference = 3;
}
message StructField {
- DataType type = 1;
- string name = 2;
+ string name = 1;
+ DataType data_type = 2;
bool nullable = 3;
map<string, string> metadata = 4;
}
@@ -153,16 +176,16 @@ message DataType {
uint32 type_variation_reference = 2;
}
- message List {
- DataType DataType = 1;
- uint32 type_variation_reference = 2;
- bool element_nullable = 3;
+ message Array {
+ DataType element_type = 1;
+ bool contains_null = 2;
+ uint32 type_variation_reference = 3;
}
message Map {
- DataType key = 1;
- DataType value = 2;
- uint32 type_variation_reference = 3;
- bool value_nullable = 4;
+ DataType key_type = 1;
+ DataType value_type = 2;
+ bool value_contains_null = 3;
+ uint32 type_variation_reference = 4;
}
}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 1827aa4e3c0..efebb67aeda 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -54,7 +54,8 @@ package object dsl {
for (attr <- attrs) {
val structField = DataType.StructField.newBuilder()
structField.setName(attr.getName)
- structField.setType(attr.getType)
+ structField.setDataType(attr.getType)
+ structField.setNullable(true)
structExpr.addFields(structField)
}
Expression.QualifiedAttribute
@@ -66,7 +67,7 @@ package object dsl {
/** Creates a new AttributeReference of type int */
def int: Expression.QualifiedAttribute = protoQualifiedAttrWithType(
- DataType.newBuilder().setI32(DataType.I32.newBuilder()).build())
+ DataType.newBuilder().setInteger(DataType.Integer.newBuilder()).build())
private def protoQualifiedAttrWithType(dataType: DataType): Expression.QualifiedAttribute =
Expression.QualifiedAttribute
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
index 088030b2dbc..0b8d79596c3 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
@@ -21,7 +21,7 @@ import scala.collection.convert.ImplicitConversions._
import org.apache.spark.connect.proto
import org.apache.spark.sql.SaveMode
-import org.apache.spark.sql.types.{DataType, IntegerType, LongType, MapType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
/**
* This object offers methods to convert to/from connect proto to catalyst types.
@@ -29,65 +29,264 @@ import org.apache.spark.sql.types.{DataType, IntegerType, LongType, MapType, Str
object DataTypeProtoConverter {
def toCatalystType(t: proto.DataType): DataType = {
t.getKindCase match {
- case proto.DataType.KindCase.I32 => IntegerType
+ case proto.DataType.KindCase.NULL => NullType
+
+ case proto.DataType.KindCase.BINARY => BinaryType
+
+ case proto.DataType.KindCase.BOOLEAN => BooleanType
+
+ case proto.DataType.KindCase.BYTE => ByteType
+ case proto.DataType.KindCase.SHORT => ShortType
+ case proto.DataType.KindCase.INTEGER => IntegerType
+ case proto.DataType.KindCase.LONG => LongType
+
+ case proto.DataType.KindCase.FLOAT => FloatType
+ case proto.DataType.KindCase.DOUBLE => DoubleType
+ case proto.DataType.KindCase.DECIMAL => toCatalystDecimalType(t.getDecimal)
+
case proto.DataType.KindCase.STRING => StringType
- case proto.DataType.KindCase.STRUCT => convertProtoDataTypeToCatalyst(t.getStruct)
- case proto.DataType.KindCase.MAP => convertProtoDataTypeToCatalyst(t.getMap)
+ case proto.DataType.KindCase.CHAR => CharType(t.getChar.getLength)
+ case proto.DataType.KindCase.VAR_CHAR => VarcharType(t.getVarChar.getLength)
+
+ case proto.DataType.KindCase.DATE => DateType
+ case proto.DataType.KindCase.TIMESTAMP => TimestampType
+ case proto.DataType.KindCase.TIMESTAMP_NTZ => TimestampNTZType
+
+ case proto.DataType.KindCase.CALENDAR_INTERVAL => CalendarIntervalType
+ case proto.DataType.KindCase.YEAR_MONTH_INTERVAL =>
+ toCatalystYearMonthIntervalType(t.getYearMonthInterval)
+ case proto.DataType.KindCase.DAY_TIME_INTERVAL =>
+ toCatalystDayTimeIntervalType(t.getDayTimeInterval)
+
+ case proto.DataType.KindCase.ARRAY => toCatalystArrayType(t.getArray)
+ case proto.DataType.KindCase.STRUCT => toCatalystStructType(t.getStruct)
+ case proto.DataType.KindCase.MAP => toCatalystMapType(t.getMap)
case _ =>
throw InvalidPlanInput(s"Does not support convert ${t.getKindCase} to catalyst types.")
}
}
- private def convertProtoDataTypeToCatalyst(t: proto.DataType.Struct): StructType = {
- // TODO: handle nullability
- val structFields =
- t.getFieldsList.map(f => StructField(f.getName, toCatalystType(f.getType))).toList
- StructType.apply(structFields)
+ private def toCatalystDecimalType(t: proto.DataType.Decimal): DecimalType = {
+ (t.hasPrecision, t.hasScale) match {
+ case (true, true) => DecimalType(t.getPrecision, t.getScale)
+ case (true, false) => new DecimalType(t.getPrecision)
+ case _ => new DecimalType()
+ }
+ }
+
+ private def toCatalystYearMonthIntervalType(t: proto.DataType.YearMonthInterval) = {
+ (t.hasStartField, t.hasEndField) match {
+ case (true, true) => YearMonthIntervalType(t.getStartField.toByte, t.getEndField.toByte)
+ case (true, false) => YearMonthIntervalType(t.getStartField.toByte)
+ case _ => YearMonthIntervalType()
+ }
}
- private def convertProtoDataTypeToCatalyst(t: proto.DataType.Map): MapType = {
- MapType(toCatalystType(t.getKey), toCatalystType(t.getValue))
+ private def toCatalystDayTimeIntervalType(t: proto.DataType.DayTimeInterval) = {
+ (t.hasStartField, t.hasEndField) match {
+ case (true, true) => DayTimeIntervalType(t.getStartField.toByte, t.getEndField.toByte)
+ case (true, false) => DayTimeIntervalType(t.getStartField.toByte)
+ case _ => DayTimeIntervalType()
+ }
+ }
+
+ private def toCatalystArrayType(t: proto.DataType.Array): ArrayType = {
+ ArrayType(toCatalystType(t.getElementType), t.getContainsNull)
+ }
+
+ private def toCatalystStructType(t: proto.DataType.Struct): StructType = {
+ // TODO: support metadata
+ val fields = t.getFieldsList.toSeq.map { protoField =>
+ StructField(
+ name = protoField.getName,
+ dataType = toCatalystType(protoField.getDataType),
+ nullable = protoField.getNullable,
+ metadata = Metadata.empty)
+ }
+ StructType.apply(fields)
+ }
+
+ private def toCatalystMapType(t: proto.DataType.Map): MapType = {
+ MapType(toCatalystType(t.getKeyType), toCatalystType(t.getValueType), t.getValueContainsNull)
}
def toConnectProtoType(t: DataType): proto.DataType = {
t match {
+ case NullType =>
+ proto.DataType
+ .newBuilder()
+ .setNull(proto.DataType.NULL.getDefaultInstance)
+ .build()
+
+ case BooleanType =>
+ proto.DataType
+ .newBuilder()
+ .setBoolean(proto.DataType.Boolean.getDefaultInstance)
+ .build()
+
+ case BinaryType =>
+ proto.DataType
+ .newBuilder()
+ .setBinary(proto.DataType.Binary.getDefaultInstance)
+ .build()
+
+ case ByteType =>
+ proto.DataType
+ .newBuilder()
+ .setByte(proto.DataType.Byte.getDefaultInstance)
+ .build()
+
+ case ShortType =>
+ proto.DataType
+ .newBuilder()
+ .setShort(proto.DataType.Short.getDefaultInstance)
+ .build()
+
case IntegerType =>
- proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance).build()
- case StringType =>
- proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build()
+ proto.DataType
+ .newBuilder()
+ .setInteger(proto.DataType.Integer.getDefaultInstance)
+ .build()
+
case LongType =>
- proto.DataType.newBuilder().setI64(proto.DataType.I64.getDefaultInstance).build()
- case struct: StructType =>
- toConnectProtoStructType(struct)
- case map: MapType => toConnectProtoMapType(map)
- case _ =>
- throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
- }
- }
+ proto.DataType
+ .newBuilder()
+ .setLong(proto.DataType.Long.getDefaultInstance)
+ .build()
- def toConnectProtoMapType(schema: MapType): proto.DataType = {
- proto.DataType
- .newBuilder()
- .setMap(
- proto.DataType.Map
+ case FloatType =>
+ proto.DataType
.newBuilder()
- .setKey(toConnectProtoType(schema.keyType))
- .setValue(toConnectProtoType(schema.valueType))
- .build())
- .build()
- }
+ .setFloat(proto.DataType.Float.getDefaultInstance)
+ .build()
+
+ case DoubleType =>
+ proto.DataType
+ .newBuilder()
+ .setDouble(proto.DataType.Double.getDefaultInstance)
+ .build()
+
+ case DecimalType.Fixed(precision, scale) =>
+ proto.DataType
+ .newBuilder()
+ .setDecimal(
+ proto.DataType.Decimal.newBuilder().setPrecision(precision).setScale(scale).build())
+ .build()
+
+ case StringType =>
+ proto.DataType
+ .newBuilder()
+ .setString(proto.DataType.String.getDefaultInstance)
+ .build()
+
+ case CharType(length) =>
+ proto.DataType
+ .newBuilder()
+ .setChar(proto.DataType.Char.newBuilder().setLength(length).build())
+ .build()
+
+ case VarcharType(length) =>
+ proto.DataType
+ .newBuilder()
+ .setVarChar(proto.DataType.VarChar.newBuilder().setLength(length).build())
+ .build()
+
+ case DateType =>
+ proto.DataType
+ .newBuilder()
+ .setDate(proto.DataType.Date.getDefaultInstance)
+ .build()
- def toConnectProtoStructType(schema: StructType): proto.DataType = {
- val struct = proto.DataType.Struct.newBuilder()
- for (structField <- schema.fields) {
- struct.addFields(
- proto.DataType.StructField
+ case TimestampType =>
+ proto.DataType
.newBuilder()
- .setName(structField.name)
- .setType(toConnectProtoType(structField.dataType))
- .setNullable(structField.nullable))
+ .setTimestamp(proto.DataType.Timestamp.getDefaultInstance)
+ .build()
+
+ case TimestampNTZType =>
+ proto.DataType
+ .newBuilder()
+ .setTimestampNtz(proto.DataType.TimestampNTZ.getDefaultInstance)
+ .build()
+
+ case CalendarIntervalType =>
+ proto.DataType
+ .newBuilder()
+ .setCalendarInterval(proto.DataType.CalendarInterval.getDefaultInstance)
+ .build()
+
+ case YearMonthIntervalType(startField, endField) =>
+ proto.DataType
+ .newBuilder()
+ .setYearMonthInterval(
+ proto.DataType.YearMonthInterval
+ .newBuilder()
+ .setStartField(startField)
+ .setEndField(endField)
+ .build())
+ .build()
+
+ case DayTimeIntervalType(startField, endField) =>
+ proto.DataType
+ .newBuilder()
+ .setDayTimeInterval(
+ proto.DataType.DayTimeInterval
+ .newBuilder()
+ .setStartField(startField)
+ .setEndField(endField)
+ .build())
+ .build()
+
+ case ArrayType(elementType: DataType, containsNull: Boolean) =>
+ proto.DataType
+ .newBuilder()
+ .setArray(
+ proto.DataType.Array
+ .newBuilder()
+ .setElementType(toConnectProtoType(elementType))
+ .setContainsNull(containsNull)
+ .build())
+ .build()
+
+ case StructType(fields: Array[StructField]) =>
+ // TODO: support metadata
+ val protoFields = fields.toSeq.map {
+ case StructField(
+ name: String,
+ dataType: DataType,
+ nullable: Boolean,
+ metadata: Metadata) =>
+ proto.DataType.StructField
+ .newBuilder()
+ .setName(name)
+ .setDataType(toConnectProtoType(dataType))
+ .setNullable(nullable)
+ .build()
+ }
+ proto.DataType
+ .newBuilder()
+ .setStruct(
+ proto.DataType.Struct
+ .newBuilder()
+ .addAllFields(protoFields)
+ .build())
+ .build()
+
+ case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) =>
+ proto.DataType
+ .newBuilder()
+ .setMap(
+ proto.DataType.Map
+ .newBuilder()
+ .setKeyType(toConnectProtoType(keyType))
+ .setValueType(toConnectProtoType(valueType))
+ .setValueContainsNull(valueContainsNull)
+ .build())
+ .build()
+
+ case _ =>
+ throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
}
- proto.DataType.newBuilder().setStruct(struct).build()
}
def toSaveMode(mode: proto.WriteOperation.SaveMode): SaveMode = {
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 5f18b0d45c5..6ca3c2430c4 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -61,10 +61,10 @@ class SparkConnectServiceSuite extends SharedSparkSession {
assert(schema.getFieldsCount == 2)
assert(
schema.getFields(0).getName == "col1"
- && schema.getFields(0).getType.getKindCase == proto.DataType.KindCase.I32)
+ && schema.getFields(0).getDataType.getKindCase == proto.DataType.KindCase.INTEGER)
assert(
schema.getFields(1).getName == "col2"
- && schema.getFields(1).getType.getKindCase == proto.DataType.KindCase.STRING)
+ && schema.getFields(1).getDataType.getKindCase == proto.DataType.KindCase.STRING)
}
}
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 24d104a0418..eb2e2227fb9 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -32,7 +32,29 @@ from pyspark import cloudpickle
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.plan import SQL, Range
-from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType
+from pyspark.sql.types import (
+ DataType,
+ ByteType,
+ ShortType,
+ IntegerType,
+ FloatType,
+ DateType,
+ TimestampType,
+ DayTimeIntervalType,
+ MapType,
+ StringType,
+ CharType,
+ VarcharType,
+ StructType,
+ StructField,
+ ArrayType,
+ DoubleType,
+ LongType,
+ DecimalType,
+ BinaryType,
+ BooleanType,
+ NullType,
+)
from typing import Iterable, Optional, Any, Union, List, Tuple, Dict
@@ -356,38 +378,78 @@ class RemoteSparkSession(object):
return self._execute_and_fetch(req)
def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
- if schema.HasField("struct"):
- structFields = []
- for proto_field in schema.struct.fields:
- structFields.append(
- StructField(
- proto_field.name,
- self._proto_schema_to_pyspark_schema(proto_field.type),
- proto_field.nullable,
- )
- )
- return StructType(structFields)
- elif schema.HasField("i64"):
+ if schema.HasField("null"):
+ return NullType()
+ elif schema.HasField("boolean"):
+ return BooleanType()
+ elif schema.HasField("binary"):
+ return BinaryType()
+ elif schema.HasField("byte"):
+ return ByteType()
+ elif schema.HasField("short"):
+ return ShortType()
+ elif schema.HasField("integer"):
+ return IntegerType()
+ elif schema.HasField("long"):
return LongType()
+ elif schema.HasField("float"):
+ return FloatType()
+ elif schema.HasField("double"):
+ return DoubleType()
+ elif schema.HasField("decimal"):
+ p = schema.decimal.precision if schema.decimal.HasField("precision") else 10
+ s = schema.decimal.scale if schema.decimal.HasField("scale") else 0
+ return DecimalType(precision=p, scale=s)
elif schema.HasField("string"):
return StringType()
+ elif schema.HasField("char"):
+ return CharType(schema.char.length)
+ elif schema.HasField("var_char"):
+ return VarcharType(schema.var_char.length)
+ elif schema.HasField("date"):
+ return DateType()
+ elif schema.HasField("timestamp"):
+ return TimestampType()
+ elif schema.HasField("day_time_interval"):
+ return DayTimeIntervalType()
+ elif schema.HasField("array"):
+ return ArrayType(
+ self._proto_schema_to_pyspark_schema(schema.array.element_type),
+ schema.array.contains_null,
+ )
+ elif schema.HasField("struct"):
+ fields = [
+ StructField(
+ f.name,
+ self._proto_schema_to_pyspark_schema(f.data_type),
+ f.nullable,
+ )
+ for f in schema.struct.fields
+ ]
+ return StructType(fields)
+ elif schema.HasField("map"):
+ return MapType(
+ self._proto_schema_to_pyspark_schema(schema.map.key_type),
+ self._proto_schema_to_pyspark_schema(schema.map.value_type),
+ schema.map.value_contains_null,
+ )
else:
- raise Exception("Only support long, string, struct conversion")
+ raise Exception(f"Unsupported data type {schema}")
def schema(self, plan: pb2.Plan) -> StructType:
proto_schema = self._analyze(plan).schema
# Server side should populate the struct field which is the schema.
assert proto_schema.HasField("struct")
- structFields = []
- for proto_field in proto_schema.struct.fields:
- structFields.append(
- StructField(
- proto_field.name,
- self._proto_schema_to_pyspark_schema(proto_field.type),
- proto_field.nullable,
- )
+
+ fields = [
+ StructField(
+ f.name,
+ self._proto_schema_to_pyspark_schema(f.data_type),
+ f.nullable,
)
- return StructType(structFields)
+ for f in proto_schema.struct.fields
+ ]
+ return StructType(fields)
def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str:
result = self._analyze(plan, explain_mode)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index c372df7d324..7435a54d7ec 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xf0\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFu [...]
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xf3\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFu [...]
)
@@ -235,37 +235,37 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_EXPRESSION._serialized_start = 105
- _EXPRESSION._serialized_end = 3161
+ _EXPRESSION._serialized_end = 3164
_EXPRESSION_LITERAL._serialized_start = 613
- _EXPRESSION_LITERAL._serialized_end = 2696
- _EXPRESSION_LITERAL_VARCHAR._serialized_start = 1923
- _EXPRESSION_LITERAL_VARCHAR._serialized_end = 1978
- _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1980
- _EXPRESSION_LITERAL_DECIMAL._serialized_end = 2063
- _EXPRESSION_LITERAL_MAP._serialized_start = 2066
- _EXPRESSION_LITERAL_MAP._serialized_end = 2272
- _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_start = 2152
- _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_end = 2272
- _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_start = 2274
- _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_end = 2341
- _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_start = 2343
- _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_end = 2446
- _EXPRESSION_LITERAL_STRUCT._serialized_start = 2448
- _EXPRESSION_LITERAL_STRUCT._serialized_end = 2515
- _EXPRESSION_LITERAL_LIST._serialized_start = 2517
- _EXPRESSION_LITERAL_LIST._serialized_end = 2582
- _EXPRESSION_LITERAL_USERDEFINED._serialized_start = 2584
- _EXPRESSION_LITERAL_USERDEFINED._serialized_end = 2680
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2698
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2768
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2770
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2869
- _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2871
- _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2921
- _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2923
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2939
- _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_start = 2941
- _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_end = 3026
- _EXPRESSION_ALIAS._serialized_start = 3028
- _EXPRESSION_ALIAS._serialized_end = 3148
+ _EXPRESSION_LITERAL._serialized_end = 2699
+ _EXPRESSION_LITERAL_VARCHAR._serialized_start = 1926
+ _EXPRESSION_LITERAL_VARCHAR._serialized_end = 1981
+ _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1983
+ _EXPRESSION_LITERAL_DECIMAL._serialized_end = 2066
+ _EXPRESSION_LITERAL_MAP._serialized_start = 2069
+ _EXPRESSION_LITERAL_MAP._serialized_end = 2275
+ _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_start = 2155
+ _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_end = 2275
+ _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_start = 2277
+ _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_end = 2344
+ _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_start = 2346
+ _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_end = 2449
+ _EXPRESSION_LITERAL_STRUCT._serialized_start = 2451
+ _EXPRESSION_LITERAL_STRUCT._serialized_end = 2518
+ _EXPRESSION_LITERAL_LIST._serialized_start = 2520
+ _EXPRESSION_LITERAL_LIST._serialized_end = 2585
+ _EXPRESSION_LITERAL_USERDEFINED._serialized_start = 2587
+ _EXPRESSION_LITERAL_USERDEFINED._serialized_end = 2683
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2701
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2771
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2773
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2872
+ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2874
+ _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2924
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2926
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2942
+ _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_start = 2944
+ _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_end = 3029
+ _EXPRESSION_ALIAS._serialized_start = 3031
+ _EXPRESSION_ALIAS._serialized_end = 3151
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index ea538b2ebec..05c8cbe6385 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -280,7 +280,7 @@ class Expression(google.protobuf.message.Message):
UUID_FIELD_NUMBER: builtins.int
NULL_FIELD_NUMBER: builtins.int
LIST_FIELD_NUMBER: builtins.int
- EMPTY_LIST_FIELD_NUMBER: builtins.int
+ EMPTY_ARRAY_FIELD_NUMBER: builtins.int
EMPTY_MAP_FIELD_NUMBER: builtins.int
USER_DEFINED_FIELD_NUMBER: builtins.int
NULLABLE_FIELD_NUMBER: builtins.int
@@ -323,7 +323,7 @@ class Expression(google.protobuf.message.Message):
@property
def list(self) -> global___Expression.Literal.List: ...
@property
- def empty_list(self) -> pyspark.sql.connect.proto.types_pb2.DataType.List: ...
+ def empty_array(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Array: ...
@property
def empty_map(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Map: ...
@property
@@ -365,7 +365,7 @@ class Expression(google.protobuf.message.Message):
uuid: builtins.bytes = ...,
null: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
list: global___Expression.Literal.List | None = ...,
- empty_list: pyspark.sql.connect.proto.types_pb2.DataType.List | None = ...,
+ empty_array: pyspark.sql.connect.proto.types_pb2.DataType.Array | None = ...,
empty_map: pyspark.sql.connect.proto.types_pb2.DataType.Map | None = ...,
user_defined: global___Expression.Literal.UserDefined | None = ...,
nullable: builtins.bool = ...,
@@ -382,8 +382,8 @@ class Expression(google.protobuf.message.Message):
b"date",
"decimal",
b"decimal",
- "empty_list",
- b"empty_list",
+ "empty_array",
+ b"empty_array",
"empty_map",
b"empty_map",
"fixed_binary",
@@ -443,8 +443,8 @@ class Expression(google.protobuf.message.Message):
b"date",
"decimal",
b"decimal",
- "empty_list",
- b"empty_list",
+ "empty_array",
+ b"empty_array",
"empty_map",
b"empty_map",
"fixed_binary",
@@ -524,7 +524,7 @@ class Expression(google.protobuf.message.Message):
"uuid",
"null",
"list",
- "empty_list",
+ "empty_array",
"empty_map",
"user_defined",
] | None: ...
diff --git a/python/pyspark/sql/connect/proto/types_pb2.py b/python/pyspark/sql/connect/proto/types_pb2.py
index 3507b03602c..dd6567d96a2 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.py
+++ b/python/pyspark/sql/connect/proto/types_pb2.py
@@ -30,35 +30,36 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xc1\x1c\n\x08\x44\x61taType\x12\x35\n\x04\x62ool\x18\x01 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x04\x62ool\x12,\n\x02i8\x18\x02 \x01(\x0b\x32\x1a.spark.connect.DataType.I8H\x00R\x02i8\x12/\n\x03i16\x18\x03 \x01(\x0b\x32\x1b.spark.connect.DataType.I16H\x00R\x03i16\x12/\n\x03i32\x18\x05 \x01(\x0b\x32\x1b.spark.connect.DataType.I32H\x00R\x03i32\x12/\n\x03i64\x18\x07 \x01(\x0b\x32\x1b.spark.connect.DataType.I64H\x00R\x [...]
+ b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xce \n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01(\x0 [...]
)
_DATATYPE = DESCRIPTOR.message_types_by_name["DataType"]
_DATATYPE_BOOLEAN = _DATATYPE.nested_types_by_name["Boolean"]
-_DATATYPE_I8 = _DATATYPE.nested_types_by_name["I8"]
-_DATATYPE_I16 = _DATATYPE.nested_types_by_name["I16"]
-_DATATYPE_I32 = _DATATYPE.nested_types_by_name["I32"]
-_DATATYPE_I64 = _DATATYPE.nested_types_by_name["I64"]
-_DATATYPE_FP32 = _DATATYPE.nested_types_by_name["FP32"]
-_DATATYPE_FP64 = _DATATYPE.nested_types_by_name["FP64"]
+_DATATYPE_BYTE = _DATATYPE.nested_types_by_name["Byte"]
+_DATATYPE_SHORT = _DATATYPE.nested_types_by_name["Short"]
+_DATATYPE_INTEGER = _DATATYPE.nested_types_by_name["Integer"]
+_DATATYPE_LONG = _DATATYPE.nested_types_by_name["Long"]
+_DATATYPE_FLOAT = _DATATYPE.nested_types_by_name["Float"]
+_DATATYPE_DOUBLE = _DATATYPE.nested_types_by_name["Double"]
_DATATYPE_STRING = _DATATYPE.nested_types_by_name["String"]
_DATATYPE_BINARY = _DATATYPE.nested_types_by_name["Binary"]
+_DATATYPE_NULL = _DATATYPE.nested_types_by_name["NULL"]
_DATATYPE_TIMESTAMP = _DATATYPE.nested_types_by_name["Timestamp"]
_DATATYPE_DATE = _DATATYPE.nested_types_by_name["Date"]
-_DATATYPE_TIME = _DATATYPE.nested_types_by_name["Time"]
-_DATATYPE_TIMESTAMPTZ = _DATATYPE.nested_types_by_name["TimestampTZ"]
-_DATATYPE_INTERVALYEAR = _DATATYPE.nested_types_by_name["IntervalYear"]
-_DATATYPE_INTERVALDAY = _DATATYPE.nested_types_by_name["IntervalDay"]
+_DATATYPE_TIMESTAMPNTZ = _DATATYPE.nested_types_by_name["TimestampNTZ"]
+_DATATYPE_CALENDARINTERVAL = _DATATYPE.nested_types_by_name["CalendarInterval"]
+_DATATYPE_YEARMONTHINTERVAL = _DATATYPE.nested_types_by_name["YearMonthInterval"]
+_DATATYPE_DAYTIMEINTERVAL = _DATATYPE.nested_types_by_name["DayTimeInterval"]
_DATATYPE_UUID = _DATATYPE.nested_types_by_name["UUID"]
-_DATATYPE_FIXEDCHAR = _DATATYPE.nested_types_by_name["FixedChar"]
+_DATATYPE_CHAR = _DATATYPE.nested_types_by_name["Char"]
_DATATYPE_VARCHAR = _DATATYPE.nested_types_by_name["VarChar"]
_DATATYPE_FIXEDBINARY = _DATATYPE.nested_types_by_name["FixedBinary"]
_DATATYPE_DECIMAL = _DATATYPE.nested_types_by_name["Decimal"]
_DATATYPE_STRUCTFIELD = _DATATYPE.nested_types_by_name["StructField"]
_DATATYPE_STRUCTFIELD_METADATAENTRY = _DATATYPE_STRUCTFIELD.nested_types_by_name["MetadataEntry"]
_DATATYPE_STRUCT = _DATATYPE.nested_types_by_name["Struct"]
-_DATATYPE_LIST = _DATATYPE.nested_types_by_name["List"]
+_DATATYPE_ARRAY = _DATATYPE.nested_types_by_name["Array"]
_DATATYPE_MAP = _DATATYPE.nested_types_by_name["Map"]
DataType = _reflection.GeneratedProtocolMessageType(
"DataType",
@@ -73,58 +74,58 @@ DataType = _reflection.GeneratedProtocolMessageType(
# @@protoc_insertion_point(class_scope:spark.connect.DataType.Boolean)
},
),
- "I8": _reflection.GeneratedProtocolMessageType(
- "I8",
+ "Byte": _reflection.GeneratedProtocolMessageType(
+ "Byte",
(_message.Message,),
{
- "DESCRIPTOR": _DATATYPE_I8,
+ "DESCRIPTOR": _DATATYPE_BYTE,
"__module__": "spark.connect.types_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.DataType.I8)
+ # @@protoc_insertion_point(class_scope:spark.connect.DataType.Byte)
},
),
- "I16": _reflection.GeneratedProtocolMessageType(
- "I16",
+ "Short": _reflection.GeneratedProtocolMessageType(
+ "Short",
(_message.Message,),
{
- "DESCRIPTOR": _DATATYPE_I16,
+ "DESCRIPTOR": _DATATYPE_SHORT,
"__module__": "spark.connect.types_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.DataType.I16)
+ # @@protoc_insertion_point(class_scope:spark.connect.DataType.Short)
},
),
- "I32": _reflection.GeneratedProtocolMessageType(
- "I32",
+ "Integer": _reflection.GeneratedProtocolMessageType(
+ "Integer",
(_message.Message,),
{
- "DESCRIPTOR": _DATATYPE_I32,
+ "DESCRIPTOR": _DATATYPE_INTEGER,
"__module__": "spark.connect.types_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.DataType.I32)
+ # @@protoc_insertion_point(class_scope:spark.connect.DataType.Integer)
},
),
- "I64": _reflection.GeneratedProtocolMessageType(
- "I64",
+ "Long": _reflection.GeneratedProtocolMessageType(
+ "Long",
(_message.Message,),
{
- "DESCRIPTOR": _DATATYPE_I64,
+ "DESCRIPTOR": _DATATYPE_LONG,
"__module__": "spark.connect.types_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.DataType.I64)
+ # @@protoc_insertion_point(class_scope:spark.connect.DataType.Long)
},
),
- "FP32": _reflection.GeneratedProtocolMessageType(
- "FP32",
+ "Float": _reflection.GeneratedProtocolMessageType(
+ "Float",
(_message.Message,),
{
- "DESCRIPTOR": _DATATYPE_FP32,
+ "DESCRIPTOR": _DATATYPE_FLOAT,
"__module__": "spark.connect.types_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.DataType.FP32)
+ # @@protoc_insertion_point(class_scope:spark.connect.DataType.Float)
},
),
- "FP64": _reflection.GeneratedProtocolMessageType(
- "FP64",
+ "Double": _reflection.GeneratedProtocolMessageType(
+ "Double",
(_message.Message,),
{
- "DESCRIPTOR": _DATATYPE_FP64,
+ "DESCRIPTOR": _DATATYPE_DOUBLE,
"__module__": "spark.connect.types_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.DataType.FP64)
+ # @@protoc_insertion_point(class_scope:spark.connect.DataType.Double)
},
),
"String": _reflection.GeneratedProtocolMessageType(
@@ -145,6 +146,15 @@ DataType = _reflection.GeneratedProtocolMessageType(
# @@protoc_insertion_point(class_scope:spark.connect.DataType.Binary)
},
),
+ "NULL": _reflection.GeneratedProtocolMessageType(
+ "NULL",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _DATATYPE_NULL,
+ "__module__": "spark.connect.types_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.DataType.NULL)
+ },
+ ),
"Timestamp": _reflection.GeneratedProtocolMessageType(
"Timestamp",
(_message.Message,),
@@ -163,40 +173,40 @@ DataType = _reflection.GeneratedProtocolMessageType(
# @@protoc_insertion_point(class_scope:spark.connect.DataType.Date)
},
),
- "Time": _reflection.GeneratedProtocolMessageType(
- "Time",
+ "TimestampNTZ": _reflection.GeneratedProtocolMessageType(
+ "TimestampNTZ",
(_message.Message,),
{
- "DESCRIPTOR": _DATATYPE_TIME,
+ "DESCRIPTOR": _DATATYPE_TIMESTAMPNTZ,
"__module__": "spark.connect.types_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.DataType.Time)
+ # @@protoc_insertion_point(class_scope:spark.connect.DataType.TimestampNTZ)
},
),
- "TimestampTZ": _reflection.GeneratedProtocolMessageType(
- "TimestampTZ",
+ "CalendarInterval": _reflection.GeneratedProtocolMessageType(
+ "CalendarInterval",
(_message.Message,),
{
- "DESCRIPTOR": _DATATYPE_TIMESTAMPTZ,
+ "DESCRIPTOR": _DATATYPE_CALENDARINTERVAL,
"__module__": "spark.connect.types_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.DataType.TimestampTZ)
+ # @@protoc_insertion_point(class_scope:spark.connect.DataType.CalendarInterval)
},
),
- "IntervalYear": _reflection.GeneratedProtocolMessageType(
- "IntervalYear",
+ "YearMonthInterval": _reflection.GeneratedProtocolMessageType(
+ "YearMonthInterval",
(_message.Message,),
{
- "DESCRIPTOR": _DATATYPE_INTERVALYEAR,
+ "DESCRIPTOR": _DATATYPE_YEARMONTHINTERVAL,
"__module__": "spark.connect.types_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.DataType.IntervalYear)
+ # @@protoc_insertion_point(class_scope:spark.connect.DataType.YearMonthInterval)
},
),
- "IntervalDay": _reflection.GeneratedProtocolMessageType(
- "IntervalDay",
+ "DayTimeInterval": _reflection.GeneratedProtocolMessageType(
+ "DayTimeInterval",
(_message.Message,),
{
- "DESCRIPTOR": _DATATYPE_INTERVALDAY,
+ "DESCRIPTOR": _DATATYPE_DAYTIMEINTERVAL,
"__module__": "spark.connect.types_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.DataType.IntervalDay)
+ # @@protoc_insertion_point(class_scope:spark.connect.DataType.DayTimeInterval)
},
),
"UUID": _reflection.GeneratedProtocolMessageType(
@@ -208,13 +218,13 @@ DataType = _reflection.GeneratedProtocolMessageType(
# @@protoc_insertion_point(class_scope:spark.connect.DataType.UUID)
},
),
- "FixedChar": _reflection.GeneratedProtocolMessageType(
- "FixedChar",
+ "Char": _reflection.GeneratedProtocolMessageType(
+ "Char",
(_message.Message,),
{
- "DESCRIPTOR": _DATATYPE_FIXEDCHAR,
+ "DESCRIPTOR": _DATATYPE_CHAR,
"__module__": "spark.connect.types_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.DataType.FixedChar)
+ # @@protoc_insertion_point(class_scope:spark.connect.DataType.Char)
},
),
"VarChar": _reflection.GeneratedProtocolMessageType(
@@ -271,13 +281,13 @@ DataType = _reflection.GeneratedProtocolMessageType(
# @@protoc_insertion_point(class_scope:spark.connect.DataType.Struct)
},
),
- "List": _reflection.GeneratedProtocolMessageType(
- "List",
+ "Array": _reflection.GeneratedProtocolMessageType(
+ "Array",
(_message.Message,),
{
- "DESCRIPTOR": _DATATYPE_LIST,
+ "DESCRIPTOR": _DATATYPE_ARRAY,
"__module__": "spark.connect.types_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.DataType.List)
+ # @@protoc_insertion_point(class_scope:spark.connect.DataType.Array)
},
),
"Map": _reflection.GeneratedProtocolMessageType(
@@ -296,29 +306,30 @@ DataType = _reflection.GeneratedProtocolMessageType(
)
_sym_db.RegisterMessage(DataType)
_sym_db.RegisterMessage(DataType.Boolean)
-_sym_db.RegisterMessage(DataType.I8)
-_sym_db.RegisterMessage(DataType.I16)
-_sym_db.RegisterMessage(DataType.I32)
-_sym_db.RegisterMessage(DataType.I64)
-_sym_db.RegisterMessage(DataType.FP32)
-_sym_db.RegisterMessage(DataType.FP64)
+_sym_db.RegisterMessage(DataType.Byte)
+_sym_db.RegisterMessage(DataType.Short)
+_sym_db.RegisterMessage(DataType.Integer)
+_sym_db.RegisterMessage(DataType.Long)
+_sym_db.RegisterMessage(DataType.Float)
+_sym_db.RegisterMessage(DataType.Double)
_sym_db.RegisterMessage(DataType.String)
_sym_db.RegisterMessage(DataType.Binary)
+_sym_db.RegisterMessage(DataType.NULL)
_sym_db.RegisterMessage(DataType.Timestamp)
_sym_db.RegisterMessage(DataType.Date)
-_sym_db.RegisterMessage(DataType.Time)
-_sym_db.RegisterMessage(DataType.TimestampTZ)
-_sym_db.RegisterMessage(DataType.IntervalYear)
-_sym_db.RegisterMessage(DataType.IntervalDay)
+_sym_db.RegisterMessage(DataType.TimestampNTZ)
+_sym_db.RegisterMessage(DataType.CalendarInterval)
+_sym_db.RegisterMessage(DataType.YearMonthInterval)
+_sym_db.RegisterMessage(DataType.DayTimeInterval)
_sym_db.RegisterMessage(DataType.UUID)
-_sym_db.RegisterMessage(DataType.FixedChar)
+_sym_db.RegisterMessage(DataType.Char)
_sym_db.RegisterMessage(DataType.VarChar)
_sym_db.RegisterMessage(DataType.FixedBinary)
_sym_db.RegisterMessage(DataType.Decimal)
_sym_db.RegisterMessage(DataType.StructField)
_sym_db.RegisterMessage(DataType.StructField.MetadataEntry)
_sym_db.RegisterMessage(DataType.Struct)
-_sym_db.RegisterMessage(DataType.List)
+_sym_db.RegisterMessage(DataType.Array)
_sym_db.RegisterMessage(DataType.Map)
if _descriptor._USE_C_DESCRIPTORS == False:
@@ -328,55 +339,57 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_DATATYPE_STRUCTFIELD_METADATAENTRY._options = None
_DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_options = b"8\001"
_DATATYPE._serialized_start = 45
- _DATATYPE._serialized_end = 3694
- _DATATYPE_BOOLEAN._serialized_start = 1461
- _DATATYPE_BOOLEAN._serialized_end = 1528
- _DATATYPE_I8._serialized_start = 1530
- _DATATYPE_I8._serialized_end = 1592
- _DATATYPE_I16._serialized_start = 1594
- _DATATYPE_I16._serialized_end = 1657
- _DATATYPE_I32._serialized_start = 1659
- _DATATYPE_I32._serialized_end = 1722
- _DATATYPE_I64._serialized_start = 1724
- _DATATYPE_I64._serialized_end = 1787
- _DATATYPE_FP32._serialized_start = 1789
- _DATATYPE_FP32._serialized_end = 1853
- _DATATYPE_FP64._serialized_start = 1855
- _DATATYPE_FP64._serialized_end = 1919
- _DATATYPE_STRING._serialized_start = 1921
- _DATATYPE_STRING._serialized_end = 1987
- _DATATYPE_BINARY._serialized_start = 1989
- _DATATYPE_BINARY._serialized_end = 2055
- _DATATYPE_TIMESTAMP._serialized_start = 2057
- _DATATYPE_TIMESTAMP._serialized_end = 2126
- _DATATYPE_DATE._serialized_start = 2128
- _DATATYPE_DATE._serialized_end = 2192
- _DATATYPE_TIME._serialized_start = 2194
- _DATATYPE_TIME._serialized_end = 2258
- _DATATYPE_TIMESTAMPTZ._serialized_start = 2260
- _DATATYPE_TIMESTAMPTZ._serialized_end = 2331
- _DATATYPE_INTERVALYEAR._serialized_start = 2333
- _DATATYPE_INTERVALYEAR._serialized_end = 2405
- _DATATYPE_INTERVALDAY._serialized_start = 2407
- _DATATYPE_INTERVALDAY._serialized_end = 2478
- _DATATYPE_UUID._serialized_start = 2480
- _DATATYPE_UUID._serialized_end = 2544
- _DATATYPE_FIXEDCHAR._serialized_start = 2546
- _DATATYPE_FIXEDCHAR._serialized_end = 2639
- _DATATYPE_VARCHAR._serialized_start = 2641
- _DATATYPE_VARCHAR._serialized_end = 2732
- _DATATYPE_FIXEDBINARY._serialized_start = 2734
- _DATATYPE_FIXEDBINARY._serialized_end = 2829
- _DATATYPE_DECIMAL._serialized_start = 2831
- _DATATYPE_DECIMAL._serialized_end = 2950
- _DATATYPE_STRUCTFIELD._serialized_start = 2953
- _DATATYPE_STRUCTFIELD._serialized_end = 3199
- _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_start = 3140
- _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_end = 3199
- _DATATYPE_STRUCT._serialized_start = 3201
- _DATATYPE_STRUCT._serialized_end = 3328
- _DATATYPE_LIST._serialized_start = 3331
- _DATATYPE_LIST._serialized_end = 3491
- _DATATYPE_MAP._serialized_start = 3494
- _DATATYPE_MAP._serialized_end = 3686
+ _DATATYPE._serialized_end = 4219
+ _DATATYPE_BOOLEAN._serialized_start = 1612
+ _DATATYPE_BOOLEAN._serialized_end = 1679
+ _DATATYPE_BYTE._serialized_start = 1681
+ _DATATYPE_BYTE._serialized_end = 1745
+ _DATATYPE_SHORT._serialized_start = 1747
+ _DATATYPE_SHORT._serialized_end = 1812
+ _DATATYPE_INTEGER._serialized_start = 1814
+ _DATATYPE_INTEGER._serialized_end = 1881
+ _DATATYPE_LONG._serialized_start = 1883
+ _DATATYPE_LONG._serialized_end = 1947
+ _DATATYPE_FLOAT._serialized_start = 1949
+ _DATATYPE_FLOAT._serialized_end = 2014
+ _DATATYPE_DOUBLE._serialized_start = 2016
+ _DATATYPE_DOUBLE._serialized_end = 2082
+ _DATATYPE_STRING._serialized_start = 2084
+ _DATATYPE_STRING._serialized_end = 2150
+ _DATATYPE_BINARY._serialized_start = 2152
+ _DATATYPE_BINARY._serialized_end = 2218
+ _DATATYPE_NULL._serialized_start = 2220
+ _DATATYPE_NULL._serialized_end = 2284
+ _DATATYPE_TIMESTAMP._serialized_start = 2286
+ _DATATYPE_TIMESTAMP._serialized_end = 2355
+ _DATATYPE_DATE._serialized_start = 2357
+ _DATATYPE_DATE._serialized_end = 2421
+ _DATATYPE_TIMESTAMPNTZ._serialized_start = 2423
+ _DATATYPE_TIMESTAMPNTZ._serialized_end = 2495
+ _DATATYPE_CALENDARINTERVAL._serialized_start = 2497
+ _DATATYPE_CALENDARINTERVAL._serialized_end = 2573
+ _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2576
+ _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2755
+ _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2758
+ _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2935
+ _DATATYPE_UUID._serialized_start = 2937
+ _DATATYPE_UUID._serialized_end = 3001
+ _DATATYPE_CHAR._serialized_start = 3003
+ _DATATYPE_CHAR._serialized_end = 3091
+ _DATATYPE_VARCHAR._serialized_start = 3093
+ _DATATYPE_VARCHAR._serialized_end = 3184
+ _DATATYPE_FIXEDBINARY._serialized_start = 3186
+ _DATATYPE_FIXEDBINARY._serialized_end = 3281
+ _DATATYPE_DECIMAL._serialized_start = 3284
+ _DATATYPE_DECIMAL._serialized_end = 3437
+ _DATATYPE_STRUCTFIELD._serialized_start = 3440
+ _DATATYPE_STRUCTFIELD._serialized_end = 3695
+ _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_start = 3636
+ _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_end = 3695
+ _DATATYPE_STRUCT._serialized_start = 3697
+ _DATATYPE_STRUCT._serialized_end = 3824
+ _DATATYPE_ARRAY._serialized_start = 3827
+ _DATATYPE_ARRAY._serialized_end = 3989
+ _DATATYPE_MAP._serialized_start = 3992
+ _DATATYPE_MAP._serialized_end = 4211
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/types_pb2.pyi b/python/pyspark/sql/connect/proto/types_pb2.pyi
index 3bf36fc790c..647f625659b 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/types_pb2.pyi
@@ -39,6 +39,7 @@ import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
import sys
+import typing
if sys.version_info >= (3, 8):
import typing as typing_extensions
@@ -71,7 +72,7 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...
- class I8(google.protobuf.message.Message):
+ class Byte(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -88,7 +89,7 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...
- class I16(google.protobuf.message.Message):
+ class Short(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -105,7 +106,7 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...
- class I32(google.protobuf.message.Message):
+ class Integer(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -122,7 +123,7 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...
- class I64(google.protobuf.message.Message):
+ class Long(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -139,7 +140,7 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...
- class FP32(google.protobuf.message.Message):
+ class Float(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -156,7 +157,7 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...
- class FP64(google.protobuf.message.Message):
+ class Double(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -207,6 +208,23 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...
+ class NULL(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
+ type_variation_reference: builtins.int
+ def __init__(
+ self,
+ *,
+ type_variation_reference: builtins.int = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "type_variation_reference", b"type_variation_reference"
+ ],
+ ) -> None: ...
+
class Timestamp(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -241,7 +259,7 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...
- class Time(google.protobuf.message.Message):
+ class TimestampNTZ(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -258,7 +276,7 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...
- class TimestampTZ(google.protobuf.message.Message):
+ class CalendarInterval(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -275,39 +293,111 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...
- class IntervalYear(google.protobuf.message.Message):
+ class YearMonthInterval(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
+ START_FIELD_FIELD_NUMBER: builtins.int
+ END_FIELD_FIELD_NUMBER: builtins.int
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
+ start_field: builtins.int
+ end_field: builtins.int
type_variation_reference: builtins.int
def __init__(
self,
*,
+ start_field: builtins.int | None = ...,
+ end_field: builtins.int | None = ...,
type_variation_reference: builtins.int = ...,
) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_end_field",
+ b"_end_field",
+ "_start_field",
+ b"_start_field",
+ "end_field",
+ b"end_field",
+ "start_field",
+ b"start_field",
+ ],
+ ) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "type_variation_reference", b"type_variation_reference"
+ "_end_field",
+ b"_end_field",
+ "_start_field",
+ b"_start_field",
+ "end_field",
+ b"end_field",
+ "start_field",
+ b"start_field",
+ "type_variation_reference",
+ b"type_variation_reference",
],
) -> None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_end_field", b"_end_field"]
+ ) -> typing_extensions.Literal["end_field"] | None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_start_field", b"_start_field"]
+ ) -> typing_extensions.Literal["start_field"] | None: ...
- class IntervalDay(google.protobuf.message.Message):
+ class DayTimeInterval(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
+ START_FIELD_FIELD_NUMBER: builtins.int
+ END_FIELD_FIELD_NUMBER: builtins.int
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
+ start_field: builtins.int
+ end_field: builtins.int
type_variation_reference: builtins.int
def __init__(
self,
*,
+ start_field: builtins.int | None = ...,
+ end_field: builtins.int | None = ...,
type_variation_reference: builtins.int = ...,
) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_end_field",
+ b"_end_field",
+ "_start_field",
+ b"_start_field",
+ "end_field",
+ b"end_field",
+ "start_field",
+ b"start_field",
+ ],
+ ) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "type_variation_reference", b"type_variation_reference"
+ "_end_field",
+ b"_end_field",
+ "_start_field",
+ b"_start_field",
+ "end_field",
+ b"end_field",
+ "start_field",
+ b"start_field",
+ "type_variation_reference",
+ b"type_variation_reference",
],
) -> None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_end_field", b"_end_field"]
+ ) -> typing_extensions.Literal["end_field"] | None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_start_field", b"_start_field"]
+ ) -> typing_extensions.Literal["start_field"] | None: ...
class UUID(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -326,7 +416,7 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...
- class FixedChar(google.protobuf.message.Message):
+ class Char(google.protobuf.message.Message):
"""Start compound types."""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -400,13 +490,30 @@ class DataType(google.protobuf.message.Message):
def __init__(
self,
*,
- scale: builtins.int = ...,
- precision: builtins.int = ...,
+ scale: builtins.int | None = ...,
+ precision: builtins.int | None = ...,
type_variation_reference: builtins.int = ...,
) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_precision",
+ b"_precision",
+ "_scale",
+ b"_scale",
+ "precision",
+ b"precision",
+ "scale",
+ b"scale",
+ ],
+ ) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
+ "_precision",
+ b"_precision",
+ "_scale",
+ b"_scale",
"precision",
b"precision",
"scale",
@@ -415,6 +522,14 @@ class DataType(google.protobuf.message.Message):
b"type_variation_reference",
],
) -> None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_precision", b"_precision"]
+ ) -> typing_extensions.Literal["precision"] | None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_scale", b"_scale"]
+ ) -> typing_extensions.Literal["scale"] | None: ...
class StructField(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -436,13 +551,13 @@ class DataType(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
) -> None: ...
- TYPE_FIELD_NUMBER: builtins.int
NAME_FIELD_NUMBER: builtins.int
+ DATA_TYPE_FIELD_NUMBER: builtins.int
NULLABLE_FIELD_NUMBER: builtins.int
METADATA_FIELD_NUMBER: builtins.int
- @property
- def type(self) -> global___DataType: ...
name: builtins.str
+ @property
+ def data_type(self) -> global___DataType: ...
nullable: builtins.bool
@property
def metadata(
@@ -451,18 +566,25 @@ class DataType(google.protobuf.message.Message):
def __init__(
self,
*,
- type: global___DataType | None = ...,
name: builtins.str = ...,
+ data_type: global___DataType | None = ...,
nullable: builtins.bool = ...,
metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
) -> None: ...
def HasField(
- self, field_name: typing_extensions.Literal["type", b"type"]
+ self, field_name: typing_extensions.Literal["data_type", b"data_type"]
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "metadata", b"metadata", "name", b"name", "nullable", b"nullable", "type", b"type"
+ "data_type",
+ b"data_type",
+ "metadata",
+ b"metadata",
+ "name",
+ b"name",
+ "nullable",
+ b"nullable",
],
) -> None: ...
@@ -491,33 +613,33 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...
- class List(google.protobuf.message.Message):
+ class Array(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
- DATATYPE_FIELD_NUMBER: builtins.int
+ ELEMENT_TYPE_FIELD_NUMBER: builtins.int
+ CONTAINS_NULL_FIELD_NUMBER: builtins.int
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
- ELEMENT_NULLABLE_FIELD_NUMBER: builtins.int
@property
- def DataType(self) -> global___DataType: ...
+ def element_type(self) -> global___DataType: ...
+ contains_null: builtins.bool
type_variation_reference: builtins.int
- element_nullable: builtins.bool
def __init__(
self,
*,
- DataType: global___DataType | None = ...,
+ element_type: global___DataType | None = ...,
+ contains_null: builtins.bool = ...,
type_variation_reference: builtins.int = ...,
- element_nullable: builtins.bool = ...,
) -> None: ...
def HasField(
- self, field_name: typing_extensions.Literal["DataType", b"DataType"]
+ self, field_name: typing_extensions.Literal["element_type", b"element_type"]
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "DataType",
- b"DataType",
- "element_nullable",
- b"element_nullable",
+ "contains_null",
+ b"contains_null",
+ "element_type",
+ b"element_type",
"type_variation_reference",
b"type_variation_reference",
],
@@ -526,276 +648,293 @@ class DataType(google.protobuf.message.Message):
class Map(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
- KEY_FIELD_NUMBER: builtins.int
- VALUE_FIELD_NUMBER: builtins.int
+ KEY_TYPE_FIELD_NUMBER: builtins.int
+ VALUE_TYPE_FIELD_NUMBER: builtins.int
+ VALUE_CONTAINS_NULL_FIELD_NUMBER: builtins.int
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
- VALUE_NULLABLE_FIELD_NUMBER: builtins.int
@property
- def key(self) -> global___DataType: ...
+ def key_type(self) -> global___DataType: ...
@property
- def value(self) -> global___DataType: ...
+ def value_type(self) -> global___DataType: ...
+ value_contains_null: builtins.bool
type_variation_reference: builtins.int
- value_nullable: builtins.bool
def __init__(
self,
*,
- key: global___DataType | None = ...,
- value: global___DataType | None = ...,
+ key_type: global___DataType | None = ...,
+ value_type: global___DataType | None = ...,
+ value_contains_null: builtins.bool = ...,
type_variation_reference: builtins.int = ...,
- value_nullable: builtins.bool = ...,
) -> None: ...
def HasField(
- self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
+ self,
+ field_name: typing_extensions.Literal[
+ "key_type", b"key_type", "value_type", b"value_type"
+ ],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "key",
- b"key",
+ "key_type",
+ b"key_type",
"type_variation_reference",
b"type_variation_reference",
- "value",
- b"value",
- "value_nullable",
- b"value_nullable",
+ "value_contains_null",
+ b"value_contains_null",
+ "value_type",
+ b"value_type",
],
) -> None: ...
- BOOL_FIELD_NUMBER: builtins.int
- I8_FIELD_NUMBER: builtins.int
- I16_FIELD_NUMBER: builtins.int
- I32_FIELD_NUMBER: builtins.int
- I64_FIELD_NUMBER: builtins.int
- FP32_FIELD_NUMBER: builtins.int
- FP64_FIELD_NUMBER: builtins.int
- STRING_FIELD_NUMBER: builtins.int
+ NULL_FIELD_NUMBER: builtins.int
BINARY_FIELD_NUMBER: builtins.int
- TIMESTAMP_FIELD_NUMBER: builtins.int
- DATE_FIELD_NUMBER: builtins.int
- TIME_FIELD_NUMBER: builtins.int
- INTERVAL_YEAR_FIELD_NUMBER: builtins.int
- INTERVAL_DAY_FIELD_NUMBER: builtins.int
- TIMESTAMP_TZ_FIELD_NUMBER: builtins.int
- UUID_FIELD_NUMBER: builtins.int
- FIXED_CHAR_FIELD_NUMBER: builtins.int
- VARCHAR_FIELD_NUMBER: builtins.int
- FIXED_BINARY_FIELD_NUMBER: builtins.int
+ BOOLEAN_FIELD_NUMBER: builtins.int
+ BYTE_FIELD_NUMBER: builtins.int
+ SHORT_FIELD_NUMBER: builtins.int
+ INTEGER_FIELD_NUMBER: builtins.int
+ LONG_FIELD_NUMBER: builtins.int
+ FLOAT_FIELD_NUMBER: builtins.int
+ DOUBLE_FIELD_NUMBER: builtins.int
DECIMAL_FIELD_NUMBER: builtins.int
+ STRING_FIELD_NUMBER: builtins.int
+ CHAR_FIELD_NUMBER: builtins.int
+ VAR_CHAR_FIELD_NUMBER: builtins.int
+ DATE_FIELD_NUMBER: builtins.int
+ TIMESTAMP_FIELD_NUMBER: builtins.int
+ TIMESTAMP_NTZ_FIELD_NUMBER: builtins.int
+ CALENDAR_INTERVAL_FIELD_NUMBER: builtins.int
+ YEAR_MONTH_INTERVAL_FIELD_NUMBER: builtins.int
+ DAY_TIME_INTERVAL_FIELD_NUMBER: builtins.int
+ ARRAY_FIELD_NUMBER: builtins.int
STRUCT_FIELD_NUMBER: builtins.int
- LIST_FIELD_NUMBER: builtins.int
MAP_FIELD_NUMBER: builtins.int
+ UUID_FIELD_NUMBER: builtins.int
+ FIXED_BINARY_FIELD_NUMBER: builtins.int
USER_DEFINED_TYPE_REFERENCE_FIELD_NUMBER: builtins.int
@property
- def bool(self) -> global___DataType.Boolean: ...
+ def null(self) -> global___DataType.NULL: ...
@property
- def i8(self) -> global___DataType.I8: ...
+ def binary(self) -> global___DataType.Binary: ...
@property
- def i16(self) -> global___DataType.I16: ...
+ def boolean(self) -> global___DataType.Boolean: ...
@property
- def i32(self) -> global___DataType.I32: ...
+ def byte(self) -> global___DataType.Byte:
+ """Numeric types"""
@property
- def i64(self) -> global___DataType.I64: ...
+ def short(self) -> global___DataType.Short: ...
@property
- def fp32(self) -> global___DataType.FP32: ...
+ def integer(self) -> global___DataType.Integer: ...
@property
- def fp64(self) -> global___DataType.FP64: ...
+ def long(self) -> global___DataType.Long: ...
@property
- def string(self) -> global___DataType.String: ...
+ def float(self) -> global___DataType.Float: ...
@property
- def binary(self) -> global___DataType.Binary: ...
+ def double(self) -> global___DataType.Double: ...
@property
- def timestamp(self) -> global___DataType.Timestamp: ...
+ def decimal(self) -> global___DataType.Decimal: ...
@property
- def date(self) -> global___DataType.Date: ...
+ def string(self) -> global___DataType.String:
+ """String types"""
@property
- def time(self) -> global___DataType.Time: ...
+ def char(self) -> global___DataType.Char: ...
@property
- def interval_year(self) -> global___DataType.IntervalYear: ...
+ def var_char(self) -> global___DataType.VarChar: ...
@property
- def interval_day(self) -> global___DataType.IntervalDay: ...
+ def date(self) -> global___DataType.Date:
+ """Datetime types"""
@property
- def timestamp_tz(self) -> global___DataType.TimestampTZ: ...
+ def timestamp(self) -> global___DataType.Timestamp: ...
@property
- def uuid(self) -> global___DataType.UUID: ...
+ def timestamp_ntz(self) -> global___DataType.TimestampNTZ: ...
@property
- def fixed_char(self) -> global___DataType.FixedChar: ...
+ def calendar_interval(self) -> global___DataType.CalendarInterval:
+ """Interval types"""
@property
- def varchar(self) -> global___DataType.VarChar: ...
+ def year_month_interval(self) -> global___DataType.YearMonthInterval: ...
@property
- def fixed_binary(self) -> global___DataType.FixedBinary: ...
+ def day_time_interval(self) -> global___DataType.DayTimeInterval: ...
@property
- def decimal(self) -> global___DataType.Decimal: ...
+ def array(self) -> global___DataType.Array:
+ """Complex types"""
@property
def struct(self) -> global___DataType.Struct: ...
@property
- def list(self) -> global___DataType.List: ...
- @property
def map(self) -> global___DataType.Map: ...
+ @property
+ def uuid(self) -> global___DataType.UUID: ...
+ @property
+ def fixed_binary(self) -> global___DataType.FixedBinary: ...
user_defined_type_reference: builtins.int
def __init__(
self,
*,
- bool: global___DataType.Boolean | None = ...,
- i8: global___DataType.I8 | None = ...,
- i16: global___DataType.I16 | None = ...,
- i32: global___DataType.I32 | None = ...,
- i64: global___DataType.I64 | None = ...,
- fp32: global___DataType.FP32 | None = ...,
- fp64: global___DataType.FP64 | None = ...,
- string: global___DataType.String | None = ...,
+ null: global___DataType.NULL | None = ...,
binary: global___DataType.Binary | None = ...,
- timestamp: global___DataType.Timestamp | None = ...,
- date: global___DataType.Date | None = ...,
- time: global___DataType.Time | None = ...,
- interval_year: global___DataType.IntervalYear | None = ...,
- interval_day: global___DataType.IntervalDay | None = ...,
- timestamp_tz: global___DataType.TimestampTZ | None = ...,
- uuid: global___DataType.UUID | None = ...,
- fixed_char: global___DataType.FixedChar | None = ...,
- varchar: global___DataType.VarChar | None = ...,
- fixed_binary: global___DataType.FixedBinary | None = ...,
+ boolean: global___DataType.Boolean | None = ...,
+ byte: global___DataType.Byte | None = ...,
+ short: global___DataType.Short | None = ...,
+ integer: global___DataType.Integer | None = ...,
+ long: global___DataType.Long | None = ...,
+ float: global___DataType.Float | None = ...,
+ double: global___DataType.Double | None = ...,
decimal: global___DataType.Decimal | None = ...,
+ string: global___DataType.String | None = ...,
+ char: global___DataType.Char | None = ...,
+ var_char: global___DataType.VarChar | None = ...,
+ date: global___DataType.Date | None = ...,
+ timestamp: global___DataType.Timestamp | None = ...,
+ timestamp_ntz: global___DataType.TimestampNTZ | None = ...,
+ calendar_interval: global___DataType.CalendarInterval | None = ...,
+ year_month_interval: global___DataType.YearMonthInterval | None = ...,
+ day_time_interval: global___DataType.DayTimeInterval | None = ...,
+ array: global___DataType.Array | None = ...,
struct: global___DataType.Struct | None = ...,
- list: global___DataType.List | None = ...,
map: global___DataType.Map | None = ...,
+ uuid: global___DataType.UUID | None = ...,
+ fixed_binary: global___DataType.FixedBinary | None = ...,
user_defined_type_reference: builtins.int = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
+ "array",
+ b"array",
"binary",
b"binary",
- "bool",
- b"bool",
+ "boolean",
+ b"boolean",
+ "byte",
+ b"byte",
+ "calendar_interval",
+ b"calendar_interval",
+ "char",
+ b"char",
"date",
b"date",
+ "day_time_interval",
+ b"day_time_interval",
"decimal",
b"decimal",
+ "double",
+ b"double",
"fixed_binary",
b"fixed_binary",
- "fixed_char",
- b"fixed_char",
- "fp32",
- b"fp32",
- "fp64",
- b"fp64",
- "i16",
- b"i16",
- "i32",
- b"i32",
- "i64",
- b"i64",
- "i8",
- b"i8",
- "interval_day",
- b"interval_day",
- "interval_year",
- b"interval_year",
+ "float",
+ b"float",
+ "integer",
+ b"integer",
"kind",
b"kind",
- "list",
- b"list",
+ "long",
+ b"long",
"map",
b"map",
+ "null",
+ b"null",
+ "short",
+ b"short",
"string",
b"string",
"struct",
b"struct",
- "time",
- b"time",
"timestamp",
b"timestamp",
- "timestamp_tz",
- b"timestamp_tz",
+ "timestamp_ntz",
+ b"timestamp_ntz",
"user_defined_type_reference",
b"user_defined_type_reference",
"uuid",
b"uuid",
- "varchar",
- b"varchar",
+ "var_char",
+ b"var_char",
+ "year_month_interval",
+ b"year_month_interval",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
+ "array",
+ b"array",
"binary",
b"binary",
- "bool",
- b"bool",
+ "boolean",
+ b"boolean",
+ "byte",
+ b"byte",
+ "calendar_interval",
+ b"calendar_interval",
+ "char",
+ b"char",
"date",
b"date",
+ "day_time_interval",
+ b"day_time_interval",
"decimal",
b"decimal",
+ "double",
+ b"double",
"fixed_binary",
b"fixed_binary",
- "fixed_char",
- b"fixed_char",
- "fp32",
- b"fp32",
- "fp64",
- b"fp64",
- "i16",
- b"i16",
- "i32",
- b"i32",
- "i64",
- b"i64",
- "i8",
- b"i8",
- "interval_day",
- b"interval_day",
- "interval_year",
- b"interval_year",
+ "float",
+ b"float",
+ "integer",
+ b"integer",
"kind",
b"kind",
- "list",
- b"list",
+ "long",
+ b"long",
"map",
b"map",
+ "null",
+ b"null",
+ "short",
+ b"short",
"string",
b"string",
"struct",
b"struct",
- "time",
- b"time",
"timestamp",
b"timestamp",
- "timestamp_tz",
- b"timestamp_tz",
+ "timestamp_ntz",
+ b"timestamp_ntz",
"user_defined_type_reference",
b"user_defined_type_reference",
"uuid",
b"uuid",
- "varchar",
- b"varchar",
+ "var_char",
+ b"var_char",
+ "year_month_interval",
+ b"year_month_interval",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["kind", b"kind"]
) -> typing_extensions.Literal[
- "bool",
- "i8",
- "i16",
- "i32",
- "i64",
- "fp32",
- "fp64",
- "string",
+ "null",
"binary",
- "timestamp",
- "date",
- "time",
- "interval_year",
- "interval_day",
- "timestamp_tz",
- "uuid",
- "fixed_char",
- "varchar",
- "fixed_binary",
+ "boolean",
+ "byte",
+ "short",
+ "integer",
+ "long",
+ "float",
+ "double",
"decimal",
+ "string",
+ "char",
+ "var_char",
+ "date",
+ "timestamp",
+ "timestamp_ntz",
+ "calendar_interval",
+ "year_month_interval",
+ "day_time_interval",
+ "array",
"struct",
- "list",
"map",
+ "uuid",
+ "fixed_binary",
"user_defined_type_reference",
] | None: ...
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index fa27b609941..ab454c53491 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -144,6 +144,73 @@ class SparkConnectTests(SparkConnectSQLTestCase):
schema,
)
+ # test FloatType, DoubleType, DecimalType, StringType, BooleanType, NullType
+ query = """
+ SELECT * FROM VALUES
+ (float(1.0), double(1.0), 1.0, "1", true, NULL),
+ (float(2.0), double(2.0), 2.0, "2", false, NULL),
+ (float(3.0), double(3.0), NULL, "3", false, NULL)
+ AS tab(a, b, c, d, e, f)
+ """
+ self.assertEqual(
+ self.spark.sql(query).schema,
+ self.connect.sql(query).schema,
+ )
+
+ # test TimestampType, DateType
+ query = """
+ SELECT * FROM VALUES
+ (TIMESTAMP('2019-04-12 15:50:00'), DATE('2022-02-22')),
+ (TIMESTAMP('2019-04-12 15:50:00'), NULL),
+ (NULL, DATE('2022-02-22'))
+ AS tab(a, b)
+ """
+ self.assertEqual(
+ self.spark.sql(query).schema,
+ self.connect.sql(query).schema,
+ )
+
+ # test MapType
+ query = """
+ SELECT * FROM VALUES
+ (MAP('a', 'ab'), MAP('a', 'ab'), MAP(1, 2, 3, 4)),
+ (MAP('x', 'yz'), MAP('x', NULL), NULL),
+ (MAP('c', 'de'), NULL, MAP(-1, NULL, -3, -4))
+ AS tab(a, b, c)
+ """
+ self.assertEqual(
+ self.spark.sql(query).schema,
+ self.connect.sql(query).schema,
+ )
+
+ # test ArrayType
+ query = """
+ SELECT * FROM VALUES
+ (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3)),
+ (ARRAY('x', NULL), NULL, ARRAY(1, 3)),
+ (NULL, ARRAY(-1, -2, -3), Array())
+ AS tab(a, b, c)
+ """
+ self.assertEqual(
+ self.spark.sql(query).schema,
+ self.connect.sql(query).schema,
+ )
+
+ # test StructType
+ query = """
+ SELECT STRUCT(a, b, c, d), STRUCT(e, f, g), STRUCT(STRUCT(a, b), STRUCT(h)) FROM VALUES
+ (float(1.0), double(1.0), 1.0, "1", true, NULL, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)),
+ (float(2.0), double(2.0), 2.0, "2", false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)),
+ (float(3.0), double(3.0), NULL, "3", false, NULL, ARRAY(NULL), NULL)
+ AS tab(a, b, c, d, e, f, g, h)
+ """
+ # Compare the __repr__() output rather than the schema objects so that field
+ # metadata is ignored; metadata is not yet supported in Connect.
+ self.assertEqual(
+ self.spark.sql(query).schema.__repr__(),
+ self.connect.sql(query).schema.__repr__(),
+ )
+
def test_simple_binary_expressions(self):
"""Test complex expression"""
df = self.connect.read.table(self.tbl_name)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org