Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/25 01:26:49 UTC

[spark] branch master updated: [SPARK-41238][CONNECT][PYTHON] Support more built-in datatypes

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4006d195111 [SPARK-41238][CONNECT][PYTHON] Support more built-in datatypes
4006d195111 is described below

commit 4006d195111334b4b795680e547dea9dd0acda22
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Fri Nov 25 09:26:25 2022 +0800

    [SPARK-41238][CONNECT][PYTHON] Support more built-in datatypes
    
    ### What changes were proposed in this pull request?
    1. On the server side, make the `proto_datatype` <-> `catalyst_datatype` conversion support all the built-in SQL datatypes;
    
    2. On the client side, make the `proto_datatype` <-> `pyspark_catalyst_datatype` conversion support [all the datatypes currently supported in PySpark](https://github.com/apache/spark/blob/master/python/pyspark/sql/types.py#L60-L83); a sketch of this mapping follows.
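
    For illustration, a minimal sketch of the client-side mapping (it mirrors the updated `_proto_schema_to_pyspark_schema` in `python/pyspark/sql/connect/client.py`; the standalone `to_pyspark_type` helper name is hypothetical):

    ```python
    import pyspark.sql.connect.proto.types_pb2 as pb2
    from pyspark.sql.types import ArrayType, DataType, FloatType

    def to_pyspark_type(schema: pb2.DataType) -> DataType:
        # Each proto DataType sets exactly one member of its `kind` oneof;
        # dispatch on whichever member is present.
        if schema.HasField("float"):
            return FloatType()
        elif schema.HasField("array"):
            return ArrayType(
                to_pyspark_type(schema.array.element_type),
                schema.array.contains_null,
            )
        # ... the remaining built-in types follow the same pattern ...
        raise Exception(f"Unsupported data type {schema}")
    ```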
    
    ### Why are the changes needed?
    Right now, only `long`, `string`, and `struct` are supported; anything else fails on the server side:
    
    ```
    grpc._channel._InactiveRpcError: <_InactiveRpcError of RPC that terminated with:
            status = StatusCode.UNKNOWN
            details = "Does not support convert float to connect proto types."
            debug_error_string = "{"created":"1669206685.760099000","description":"Error received from peer ipv6:[::1]:15002","file":"src/core/lib/surface/call.cc","file_line":1064,"grpc_message":"Does not support convert float to connect proto types.","grpc_status":2}"
    ```
    
    This PR makes the schema and literal expressions support more datatypes.
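
    With that in place, schemas containing these types can be fetched over Connect. A hedged end-to-end sketch (assumes a `RemoteSparkSession` bound to `spark` against a local Connect server; the early Connect client API may differ in detail):

    ```python
    df = spark.sql("SELECT CAST(1.5 AS FLOAT) AS f, ARRAY(1, 2, 3) AS a")
    print(df.schema)
    # Expected shape (nullability flags may vary):
    # StructType([StructField('f', FloatType(), ...),
    #             StructField('a', ArrayType(IntegerType(), ...), ...)])
    ```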
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added unit tests.
    
    Closes #38770 from zhengruifeng/connect_support_more_datatypes.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../main/protobuf/spark/connect/expressions.proto  |   2 +-
 .../src/main/protobuf/spark/connect/types.proto    | 123 +++--
 .../org/apache/spark/sql/connect/dsl/package.scala |   5 +-
 .../connect/planner/DataTypeProtoConverter.scala   | 281 ++++++++++--
 .../connect/planner/SparkConnectServiceSuite.scala |   4 +-
 python/pyspark/sql/connect/client.py               | 108 ++++-
 .../pyspark/sql/connect/proto/expressions_pb2.py   |  66 +--
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  16 +-
 python/pyspark/sql/connect/proto/types_pb2.py      | 261 ++++++-----
 python/pyspark/sql/connect/proto/types_pb2.pyi     | 507 +++++++++++++--------
 .../sql/tests/connect/test_connect_basic.py        |  67 +++
 11 files changed, 972 insertions(+), 468 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
index ac5fe24d349..7ff06aeb196 100644
--- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
@@ -68,7 +68,7 @@ message Expression {
       bytes uuid = 28;
       DataType null = 29; // a typed null literal
       List list = 30;
-      DataType.List empty_list = 31;
+      DataType.Array empty_array = 31;
       DataType.Map empty_map = 32;
       UserDefined user_defined = 33;
     }
diff --git a/connector/connect/src/main/protobuf/spark/connect/types.proto b/connector/connect/src/main/protobuf/spark/connect/types.proto
index ad043d85947..56dbf28665e 100644
--- a/connector/connect/src/main/protobuf/spark/connect/types.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/types.proto
@@ -26,31 +26,46 @@ option java_package = "org.apache.spark.connect.proto";
 // itself but only describes it.
 message DataType {
   oneof kind {
-    Boolean bool = 1;
-    I8 i8 = 2;
-    I16 i16 = 3;
-    I32 i32 = 5;
-    I64 i64 = 7;
-    FP32 fp32 = 10;
-    FP64 fp64 = 11;
-    String string = 12;
-    Binary binary = 13;
-    Timestamp timestamp = 14;
-    Date date = 16;
-    Time time = 17;
-    IntervalYear interval_year = 19;
-    IntervalDay interval_day = 20;
-    TimestampTZ timestamp_tz = 29;
-    UUID uuid = 32;
-
-    FixedChar fixed_char = 21;
-    VarChar varchar = 22;
-    FixedBinary fixed_binary = 23;
-    Decimal decimal = 24;
-
-    Struct struct = 25;
-    List list = 27;
-    Map map = 28;
+    NULL null = 1;
+
+    Binary binary = 2;
+
+    Boolean boolean = 3;
+
+    // Numeric types
+    Byte byte = 4;
+    Short short = 5;
+    Integer integer = 6;
+    Long long = 7;
+
+    Float float = 8;
+    Double double = 9;
+    Decimal decimal = 10;
+
+    // String types
+    String string = 11;
+    Char char = 12;
+    VarChar var_char = 13;
+
+    // Datetime types
+    Date date = 14;
+    Timestamp timestamp = 15;
+    TimestampNTZ timestamp_ntz = 16;
+
+    // Interval types
+    CalendarInterval calendar_interval = 17;
+    YearMonthInterval year_month_interval = 18;
+    DayTimeInterval day_time_interval = 19;
+
+    // Complex types
+    Array array = 20;
+    Struct struct = 21;
+    Map map = 22;
+
+
+    UUID uuid = 25;
+
+    FixedBinary fixed_binary = 26;
 
     uint32 user_defined_type_reference = 31;
   }
@@ -59,27 +74,27 @@ message DataType {
     uint32 type_variation_reference = 1;
   }
 
-  message I8 {
+  message Byte {
     uint32 type_variation_reference = 1;
   }
 
-  message I16 {
+  message Short {
     uint32 type_variation_reference = 1;
   }
 
-  message I32 {
+  message Integer {
     uint32 type_variation_reference = 1;
   }
 
-  message I64 {
+  message Long {
     uint32 type_variation_reference = 1;
   }
 
-  message FP32 {
+  message Float {
     uint32 type_variation_reference = 1;
   }
 
-  message FP64 {
+  message Double {
     uint32 type_variation_reference = 1;
   }
 
@@ -91,6 +106,10 @@ message DataType {
     uint32 type_variation_reference = 1;
   }
 
+  message NULL {
+    uint32 type_variation_reference = 1;
+  }
+
   message Timestamp {
     uint32 type_variation_reference = 1;
   }
@@ -99,20 +118,24 @@ message DataType {
     uint32 type_variation_reference = 1;
   }
 
-  message Time {
+  message TimestampNTZ {
     uint32 type_variation_reference = 1;
   }
 
-  message TimestampTZ {
+  message CalendarInterval {
     uint32 type_variation_reference = 1;
   }
 
-  message IntervalYear {
-    uint32 type_variation_reference = 1;
+  message YearMonthInterval {
+    optional int32 start_field = 1;
+    optional int32 end_field = 2;
+    uint32 type_variation_reference = 3;
   }
 
-  message IntervalDay {
-    uint32 type_variation_reference = 1;
+  message DayTimeInterval {
+    optional int32 start_field = 1;
+    optional int32 end_field = 2;
+    uint32 type_variation_reference = 3;
   }
 
   message UUID {
@@ -120,7 +143,7 @@ message DataType {
   }
 
   // Start compound types.
-  message FixedChar {
+  message Char {
     int32 length = 1;
     uint32 type_variation_reference = 2;
   }
@@ -136,14 +159,14 @@ message DataType {
   }
 
   message Decimal {
-    int32 scale = 1;
-    int32 precision = 2;
+    optional int32 scale = 1;
+    optional int32 precision = 2;
     uint32 type_variation_reference = 3;
   }
 
   message StructField {
-    DataType type = 1;
-    string name = 2;
+    string name = 1;
+    DataType data_type = 2;
     bool nullable = 3;
     map<string, string> metadata = 4;
   }
@@ -153,16 +176,16 @@ message DataType {
     uint32 type_variation_reference = 2;
   }
 
-  message List {
-    DataType DataType = 1;
-    uint32 type_variation_reference = 2;
-    bool element_nullable = 3;
+  message Array {
+    DataType element_type = 1;
+    bool contains_null = 2;
+    uint32 type_variation_reference = 3;
   }
 
   message Map {
-    DataType key = 1;
-    DataType value = 2;
-    uint32 type_variation_reference = 3;
-    bool value_nullable = 4;
+    DataType key_type = 1;
+    DataType value_type = 2;
+    bool value_contains_null = 3;
+    uint32 type_variation_reference = 4;
   }
 }
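
The renamed messages above can be exercised directly from the regenerated Python stubs. A small illustrative sketch (field names are taken from the updated `types.proto`; the construction itself is not part of this commit):

```python
import pyspark.sql.connect.proto.types_pb2 as pb2

# array<int> with nullable elements: Array replaces List, and
# `DataType`/`element_nullable` become `element_type`/`contains_null`.
arr = pb2.DataType(
    array=pb2.DataType.Array(
        element_type=pb2.DataType(integer=pb2.DataType.Integer()),
        contains_null=True,
    )
)

# interval year to month: the new optional start/end fields
# (0 = YEAR, 1 = MONTH in Spark's encoding).
ym = pb2.DataType(
    year_month_interval=pb2.DataType.YearMonthInterval(start_field=0, end_field=1)
)
```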
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 1827aa4e3c0..efebb67aeda 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -54,7 +54,8 @@ package object dsl {
         for (attr <- attrs) {
           val structField = DataType.StructField.newBuilder()
           structField.setName(attr.getName)
-          structField.setType(attr.getType)
+          structField.setDataType(attr.getType)
+          structField.setNullable(true)
           structExpr.addFields(structField)
         }
         Expression.QualifiedAttribute
@@ -66,7 +67,7 @@ package object dsl {
 
       /** Creates a new AttributeReference of type int */
       def int: Expression.QualifiedAttribute = protoQualifiedAttrWithType(
-        DataType.newBuilder().setI32(DataType.I32.newBuilder()).build())
+        DataType.newBuilder().setInteger(DataType.Integer.newBuilder()).build())
 
       private def protoQualifiedAttrWithType(dataType: DataType): Expression.QualifiedAttribute =
         Expression.QualifiedAttribute
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
index 088030b2dbc..0b8d79596c3 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
@@ -21,7 +21,7 @@ import scala.collection.convert.ImplicitConversions._
 
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.SaveMode
-import org.apache.spark.sql.types.{DataType, IntegerType, LongType, MapType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 
 /**
  * This object offers methods to convert to/from connect proto to catalyst types.
@@ -29,65 +29,264 @@ import org.apache.spark.sql.types.{DataType, IntegerType, LongType, MapType, Str
 object DataTypeProtoConverter {
   def toCatalystType(t: proto.DataType): DataType = {
     t.getKindCase match {
-      case proto.DataType.KindCase.I32 => IntegerType
+      case proto.DataType.KindCase.NULL => NullType
+
+      case proto.DataType.KindCase.BINARY => BinaryType
+
+      case proto.DataType.KindCase.BOOLEAN => BooleanType
+
+      case proto.DataType.KindCase.BYTE => ByteType
+      case proto.DataType.KindCase.SHORT => ShortType
+      case proto.DataType.KindCase.INTEGER => IntegerType
+      case proto.DataType.KindCase.LONG => LongType
+
+      case proto.DataType.KindCase.FLOAT => FloatType
+      case proto.DataType.KindCase.DOUBLE => DoubleType
+      case proto.DataType.KindCase.DECIMAL => toCatalystDecimalType(t.getDecimal)
+
       case proto.DataType.KindCase.STRING => StringType
-      case proto.DataType.KindCase.STRUCT => convertProtoDataTypeToCatalyst(t.getStruct)
-      case proto.DataType.KindCase.MAP => convertProtoDataTypeToCatalyst(t.getMap)
+      case proto.DataType.KindCase.CHAR => CharType(t.getChar.getLength)
+      case proto.DataType.KindCase.VAR_CHAR => VarcharType(t.getVarChar.getLength)
+
+      case proto.DataType.KindCase.DATE => DateType
+      case proto.DataType.KindCase.TIMESTAMP => TimestampType
+      case proto.DataType.KindCase.TIMESTAMP_NTZ => TimestampNTZType
+
+      case proto.DataType.KindCase.CALENDAR_INTERVAL => CalendarIntervalType
+      case proto.DataType.KindCase.YEAR_MONTH_INTERVAL =>
+        toCatalystYearMonthIntervalType(t.getYearMonthInterval)
+      case proto.DataType.KindCase.DAY_TIME_INTERVAL =>
+        toCatalystDayTimeIntervalType(t.getDayTimeInterval)
+
+      case proto.DataType.KindCase.ARRAY => toCatalystArrayType(t.getArray)
+      case proto.DataType.KindCase.STRUCT => toCatalystStructType(t.getStruct)
+      case proto.DataType.KindCase.MAP => toCatalystMapType(t.getMap)
       case _ =>
         throw InvalidPlanInput(s"Does not support convert ${t.getKindCase} to catalyst types.")
     }
   }
 
-  private def convertProtoDataTypeToCatalyst(t: proto.DataType.Struct): StructType = {
-    // TODO: handle nullability
-    val structFields =
-      t.getFieldsList.map(f => StructField(f.getName, toCatalystType(f.getType))).toList
-    StructType.apply(structFields)
+  private def toCatalystDecimalType(t: proto.DataType.Decimal): DecimalType = {
+    (t.hasPrecision, t.hasScale) match {
+      case (true, true) => DecimalType(t.getPrecision, t.getScale)
+      case (true, false) => new DecimalType(t.getPrecision)
+      case _ => new DecimalType()
+    }
+  }
+
+  private def toCatalystYearMonthIntervalType(t: proto.DataType.YearMonthInterval) = {
+    (t.hasStartField, t.hasEndField) match {
+      case (true, true) => YearMonthIntervalType(t.getStartField.toByte, t.getEndField.toByte)
+      case (true, false) => YearMonthIntervalType(t.getStartField.toByte)
+      case _ => YearMonthIntervalType()
+    }
   }
 
-  private def convertProtoDataTypeToCatalyst(t: proto.DataType.Map): MapType = {
-    MapType(toCatalystType(t.getKey), toCatalystType(t.getValue))
+  private def toCatalystDayTimeIntervalType(t: proto.DataType.DayTimeInterval) = {
+    (t.hasStartField, t.hasEndField) match {
+      case (true, true) => DayTimeIntervalType(t.getStartField.toByte, t.getEndField.toByte)
+      case (true, false) => DayTimeIntervalType(t.getStartField.toByte)
+      case _ => DayTimeIntervalType()
+    }
+  }
+
+  private def toCatalystArrayType(t: proto.DataType.Array): ArrayType = {
+    ArrayType(toCatalystType(t.getElementType), t.getContainsNull)
+  }
+
+  private def toCatalystStructType(t: proto.DataType.Struct): StructType = {
+    // TODO: support metadata
+    val fields = t.getFieldsList.toSeq.map { protoField =>
+      StructField(
+        name = protoField.getName,
+        dataType = toCatalystType(protoField.getDataType),
+        nullable = protoField.getNullable,
+        metadata = Metadata.empty)
+    }
+    StructType.apply(fields)
+  }
+
+  private def toCatalystMapType(t: proto.DataType.Map): MapType = {
+    MapType(toCatalystType(t.getKeyType), toCatalystType(t.getValueType), t.getValueContainsNull)
   }
 
   def toConnectProtoType(t: DataType): proto.DataType = {
     t match {
+      case NullType =>
+        proto.DataType
+          .newBuilder()
+          .setNull(proto.DataType.NULL.getDefaultInstance)
+          .build()
+
+      case BooleanType =>
+        proto.DataType
+          .newBuilder()
+          .setBoolean(proto.DataType.Boolean.getDefaultInstance)
+          .build()
+
+      case BinaryType =>
+        proto.DataType
+          .newBuilder()
+          .setBinary(proto.DataType.Binary.getDefaultInstance)
+          .build()
+
+      case ByteType =>
+        proto.DataType
+          .newBuilder()
+          .setByte(proto.DataType.Byte.getDefaultInstance)
+          .build()
+
+      case ShortType =>
+        proto.DataType
+          .newBuilder()
+          .setShort(proto.DataType.Short.getDefaultInstance)
+          .build()
+
       case IntegerType =>
-        proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance).build()
-      case StringType =>
-        proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build()
+        proto.DataType
+          .newBuilder()
+          .setInteger(proto.DataType.Integer.getDefaultInstance)
+          .build()
+
       case LongType =>
-        proto.DataType.newBuilder().setI64(proto.DataType.I64.getDefaultInstance).build()
-      case struct: StructType =>
-        toConnectProtoStructType(struct)
-      case map: MapType => toConnectProtoMapType(map)
-      case _ =>
-        throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
-    }
-  }
+        proto.DataType
+          .newBuilder()
+          .setLong(proto.DataType.Long.getDefaultInstance)
+          .build()
 
-  def toConnectProtoMapType(schema: MapType): proto.DataType = {
-    proto.DataType
-      .newBuilder()
-      .setMap(
-        proto.DataType.Map
+      case FloatType =>
+        proto.DataType
           .newBuilder()
-          .setKey(toConnectProtoType(schema.keyType))
-          .setValue(toConnectProtoType(schema.valueType))
-          .build())
-      .build()
-  }
+          .setFloat(proto.DataType.Float.getDefaultInstance)
+          .build()
+
+      case DoubleType =>
+        proto.DataType
+          .newBuilder()
+          .setDouble(proto.DataType.Double.getDefaultInstance)
+          .build()
+
+      case DecimalType.Fixed(precision, scale) =>
+        proto.DataType
+          .newBuilder()
+          .setDecimal(
+            proto.DataType.Decimal.newBuilder().setPrecision(precision).setScale(scale).build())
+          .build()
+
+      case StringType =>
+        proto.DataType
+          .newBuilder()
+          .setString(proto.DataType.String.getDefaultInstance)
+          .build()
+
+      case CharType(length) =>
+        proto.DataType
+          .newBuilder()
+          .setChar(proto.DataType.Char.newBuilder().setLength(length).build())
+          .build()
+
+      case VarcharType(length) =>
+        proto.DataType
+          .newBuilder()
+          .setVarChar(proto.DataType.VarChar.newBuilder().setLength(length).build())
+          .build()
+
+      case DateType =>
+        proto.DataType
+          .newBuilder()
+          .setDate(proto.DataType.Date.getDefaultInstance)
+          .build()
 
-  def toConnectProtoStructType(schema: StructType): proto.DataType = {
-    val struct = proto.DataType.Struct.newBuilder()
-    for (structField <- schema.fields) {
-      struct.addFields(
-        proto.DataType.StructField
+      case TimestampType =>
+        proto.DataType
           .newBuilder()
-          .setName(structField.name)
-          .setType(toConnectProtoType(structField.dataType))
-          .setNullable(structField.nullable))
+          .setTimestamp(proto.DataType.Timestamp.getDefaultInstance)
+          .build()
+
+      case TimestampNTZType =>
+        proto.DataType
+          .newBuilder()
+          .setTimestampNtz(proto.DataType.TimestampNTZ.getDefaultInstance)
+          .build()
+
+      case CalendarIntervalType =>
+        proto.DataType
+          .newBuilder()
+          .setCalendarInterval(proto.DataType.CalendarInterval.getDefaultInstance)
+          .build()
+
+      case YearMonthIntervalType(startField, endField) =>
+        proto.DataType
+          .newBuilder()
+          .setYearMonthInterval(
+            proto.DataType.YearMonthInterval
+              .newBuilder()
+              .setStartField(startField)
+              .setEndField(endField)
+              .build())
+          .build()
+
+      case DayTimeIntervalType(startField, endField) =>
+        proto.DataType
+          .newBuilder()
+          .setDayTimeInterval(
+            proto.DataType.DayTimeInterval
+              .newBuilder()
+              .setStartField(startField)
+              .setEndField(endField)
+              .build())
+          .build()
+
+      case ArrayType(elementType: DataType, containsNull: Boolean) =>
+        proto.DataType
+          .newBuilder()
+          .setArray(
+            proto.DataType.Array
+              .newBuilder()
+              .setElementType(toConnectProtoType(elementType))
+              .setContainsNull(containsNull)
+              .build())
+          .build()
+
+      case StructType(fields: Array[StructField]) =>
+        // TODO: support metadata
+        val protoFields = fields.toSeq.map {
+          case StructField(
+                name: String,
+                dataType: DataType,
+                nullable: Boolean,
+                metadata: Metadata) =>
+            proto.DataType.StructField
+              .newBuilder()
+              .setName(name)
+              .setDataType(toConnectProtoType(dataType))
+              .setNullable(nullable)
+              .build()
+        }
+        proto.DataType
+          .newBuilder()
+          .setStruct(
+            proto.DataType.Struct
+              .newBuilder()
+              .addAllFields(protoFields)
+              .build())
+          .build()
+
+      case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) =>
+        proto.DataType
+          .newBuilder()
+          .setMap(
+            proto.DataType.Map
+              .newBuilder()
+              .setKeyType(toConnectProtoType(keyType))
+              .setValueType(toConnectProtoType(valueType))
+              .setValueContainsNull(valueContainsNull)
+              .build())
+          .build()
+
+      case _ =>
+        throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
     }
-    proto.DataType.newBuilder().setStruct(struct).build()
   }
 
   def toSaveMode(mode: proto.WriteOperation.SaveMode): SaveMode = {
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 5f18b0d45c5..6ca3c2430c4 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -61,10 +61,10 @@ class SparkConnectServiceSuite extends SharedSparkSession {
       assert(schema.getFieldsCount == 2)
       assert(
         schema.getFields(0).getName == "col1"
-          && schema.getFields(0).getType.getKindCase == proto.DataType.KindCase.I32)
+          && schema.getFields(0).getDataType.getKindCase == proto.DataType.KindCase.INTEGER)
       assert(
         schema.getFields(1).getName == "col2"
-          && schema.getFields(1).getType.getKindCase == proto.DataType.KindCase.STRING)
+          && schema.getFields(1).getDataType.getKindCase == proto.DataType.KindCase.STRING)
     }
   }
 
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 24d104a0418..eb2e2227fb9 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -32,7 +32,29 @@ from pyspark import cloudpickle
 from pyspark.sql.connect.dataframe import DataFrame
 from pyspark.sql.connect.readwriter import DataFrameReader
 from pyspark.sql.connect.plan import SQL, Range
-from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType
+from pyspark.sql.types import (
+    DataType,
+    ByteType,
+    ShortType,
+    IntegerType,
+    FloatType,
+    DateType,
+    TimestampType,
+    DayTimeIntervalType,
+    MapType,
+    StringType,
+    CharType,
+    VarcharType,
+    StructType,
+    StructField,
+    ArrayType,
+    DoubleType,
+    LongType,
+    DecimalType,
+    BinaryType,
+    BooleanType,
+    NullType,
+)
 
 from typing import Iterable, Optional, Any, Union, List, Tuple, Dict
 
@@ -356,38 +378,78 @@ class RemoteSparkSession(object):
         return self._execute_and_fetch(req)
 
     def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
-        if schema.HasField("struct"):
-            structFields = []
-            for proto_field in schema.struct.fields:
-                structFields.append(
-                    StructField(
-                        proto_field.name,
-                        self._proto_schema_to_pyspark_schema(proto_field.type),
-                        proto_field.nullable,
-                    )
-                )
-            return StructType(structFields)
-        elif schema.HasField("i64"):
+        if schema.HasField("null"):
+            return NullType()
+        elif schema.HasField("boolean"):
+            return BooleanType()
+        elif schema.HasField("binary"):
+            return BinaryType()
+        elif schema.HasField("byte"):
+            return ByteType()
+        elif schema.HasField("short"):
+            return ShortType()
+        elif schema.HasField("integer"):
+            return IntegerType()
+        elif schema.HasField("long"):
             return LongType()
+        elif schema.HasField("float"):
+            return FloatType()
+        elif schema.HasField("double"):
+            return DoubleType()
+        elif schema.HasField("decimal"):
+            p = schema.decimal.precision if schema.decimal.HasField("precision") else 10
+            s = schema.decimal.scale if schema.decimal.HasField("scale") else 0
+            return DecimalType(precision=p, scale=s)
         elif schema.HasField("string"):
             return StringType()
+        elif schema.HasField("char"):
+            return CharType(schema.char.length)
+        elif schema.HasField("var_char"):
+            return VarcharType(schema.var_char.length)
+        elif schema.HasField("date"):
+            return DateType()
+        elif schema.HasField("timestamp"):
+            return TimestampType()
+        elif schema.HasField("day_time_interval"):
+            return DayTimeIntervalType()
+        elif schema.HasField("array"):
+            return ArrayType(
+                self._proto_schema_to_pyspark_schema(schema.array.element_type),
+                schema.array.contains_null,
+            )
+        elif schema.HasField("struct"):
+            fields = [
+                StructField(
+                    f.name,
+                    self._proto_schema_to_pyspark_schema(f.data_type),
+                    f.nullable,
+                )
+                for f in schema.struct.fields
+            ]
+            return StructType(fields)
+        elif schema.HasField("map"):
+            return MapType(
+                self._proto_schema_to_pyspark_schema(schema.map.key_type),
+                self._proto_schema_to_pyspark_schema(schema.map.value_type),
+                schema.map.value_contains_null,
+            )
         else:
-            raise Exception("Only support long, string, struct conversion")
+            raise Exception(f"Unsupported data type {schema}")
 
     def schema(self, plan: pb2.Plan) -> StructType:
         proto_schema = self._analyze(plan).schema
         # Server side should populate the struct field which is the schema.
         assert proto_schema.HasField("struct")
-        structFields = []
-        for proto_field in proto_schema.struct.fields:
-            structFields.append(
-                StructField(
-                    proto_field.name,
-                    self._proto_schema_to_pyspark_schema(proto_field.type),
-                    proto_field.nullable,
-                )
+
+        fields = [
+            StructField(
+                f.name,
+                self._proto_schema_to_pyspark_schema(f.data_type),
+                f.nullable,
             )
-        return StructType(structFields)
+            for f in proto_schema.struct.fields
+        ]
+        return StructType(fields)
 
     def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str:
         result = self._analyze(plan, explain_mode)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index c372df7d324..7435a54d7ec 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xf0\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFu [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xf3\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFu [...]
 )
 
 
@@ -235,37 +235,37 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 3161
+    _EXPRESSION._serialized_end = 3164
     _EXPRESSION_LITERAL._serialized_start = 613
-    _EXPRESSION_LITERAL._serialized_end = 2696
-    _EXPRESSION_LITERAL_VARCHAR._serialized_start = 1923
-    _EXPRESSION_LITERAL_VARCHAR._serialized_end = 1978
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1980
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 2063
-    _EXPRESSION_LITERAL_MAP._serialized_start = 2066
-    _EXPRESSION_LITERAL_MAP._serialized_end = 2272
-    _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_start = 2152
-    _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_end = 2272
-    _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_start = 2274
-    _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_end = 2341
-    _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_start = 2343
-    _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_end = 2446
-    _EXPRESSION_LITERAL_STRUCT._serialized_start = 2448
-    _EXPRESSION_LITERAL_STRUCT._serialized_end = 2515
-    _EXPRESSION_LITERAL_LIST._serialized_start = 2517
-    _EXPRESSION_LITERAL_LIST._serialized_end = 2582
-    _EXPRESSION_LITERAL_USERDEFINED._serialized_start = 2584
-    _EXPRESSION_LITERAL_USERDEFINED._serialized_end = 2680
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2698
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2768
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2770
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2869
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2871
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2921
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2923
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2939
-    _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_start = 2941
-    _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_end = 3026
-    _EXPRESSION_ALIAS._serialized_start = 3028
-    _EXPRESSION_ALIAS._serialized_end = 3148
+    _EXPRESSION_LITERAL._serialized_end = 2699
+    _EXPRESSION_LITERAL_VARCHAR._serialized_start = 1926
+    _EXPRESSION_LITERAL_VARCHAR._serialized_end = 1981
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1983
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 2066
+    _EXPRESSION_LITERAL_MAP._serialized_start = 2069
+    _EXPRESSION_LITERAL_MAP._serialized_end = 2275
+    _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_start = 2155
+    _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_end = 2275
+    _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_start = 2277
+    _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_end = 2344
+    _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_start = 2346
+    _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_end = 2449
+    _EXPRESSION_LITERAL_STRUCT._serialized_start = 2451
+    _EXPRESSION_LITERAL_STRUCT._serialized_end = 2518
+    _EXPRESSION_LITERAL_LIST._serialized_start = 2520
+    _EXPRESSION_LITERAL_LIST._serialized_end = 2585
+    _EXPRESSION_LITERAL_USERDEFINED._serialized_start = 2587
+    _EXPRESSION_LITERAL_USERDEFINED._serialized_end = 2683
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2701
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2771
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2773
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2872
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2874
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2924
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2926
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2942
+    _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_start = 2944
+    _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_end = 3029
+    _EXPRESSION_ALIAS._serialized_start = 3031
+    _EXPRESSION_ALIAS._serialized_end = 3151
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index ea538b2ebec..05c8cbe6385 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -280,7 +280,7 @@ class Expression(google.protobuf.message.Message):
         UUID_FIELD_NUMBER: builtins.int
         NULL_FIELD_NUMBER: builtins.int
         LIST_FIELD_NUMBER: builtins.int
-        EMPTY_LIST_FIELD_NUMBER: builtins.int
+        EMPTY_ARRAY_FIELD_NUMBER: builtins.int
         EMPTY_MAP_FIELD_NUMBER: builtins.int
         USER_DEFINED_FIELD_NUMBER: builtins.int
         NULLABLE_FIELD_NUMBER: builtins.int
@@ -323,7 +323,7 @@ class Expression(google.protobuf.message.Message):
         @property
         def list(self) -> global___Expression.Literal.List: ...
         @property
-        def empty_list(self) -> pyspark.sql.connect.proto.types_pb2.DataType.List: ...
+        def empty_array(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Array: ...
         @property
         def empty_map(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Map: ...
         @property
@@ -365,7 +365,7 @@ class Expression(google.protobuf.message.Message):
             uuid: builtins.bytes = ...,
             null: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
             list: global___Expression.Literal.List | None = ...,
-            empty_list: pyspark.sql.connect.proto.types_pb2.DataType.List | None = ...,
+            empty_array: pyspark.sql.connect.proto.types_pb2.DataType.Array | None = ...,
             empty_map: pyspark.sql.connect.proto.types_pb2.DataType.Map | None = ...,
             user_defined: global___Expression.Literal.UserDefined | None = ...,
             nullable: builtins.bool = ...,
@@ -382,8 +382,8 @@ class Expression(google.protobuf.message.Message):
                 b"date",
                 "decimal",
                 b"decimal",
-                "empty_list",
-                b"empty_list",
+                "empty_array",
+                b"empty_array",
                 "empty_map",
                 b"empty_map",
                 "fixed_binary",
@@ -443,8 +443,8 @@ class Expression(google.protobuf.message.Message):
                 b"date",
                 "decimal",
                 b"decimal",
-                "empty_list",
-                b"empty_list",
+                "empty_array",
+                b"empty_array",
                 "empty_map",
                 b"empty_map",
                 "fixed_binary",
@@ -524,7 +524,7 @@ class Expression(google.protobuf.message.Message):
             "uuid",
             "null",
             "list",
-            "empty_list",
+            "empty_array",
             "empty_map",
             "user_defined",
         ] | None: ...
diff --git a/python/pyspark/sql/connect/proto/types_pb2.py b/python/pyspark/sql/connect/proto/types_pb2.py
index 3507b03602c..dd6567d96a2 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.py
+++ b/python/pyspark/sql/connect/proto/types_pb2.py
@@ -30,35 +30,36 @@ _sym_db = _symbol_database.Default()
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xc1\x1c\n\x08\x44\x61taType\x12\x35\n\x04\x62ool\x18\x01 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x04\x62ool\x12,\n\x02i8\x18\x02 \x01(\x0b\x32\x1a.spark.connect.DataType.I8H\x00R\x02i8\x12/\n\x03i16\x18\x03 \x01(\x0b\x32\x1b.spark.connect.DataType.I16H\x00R\x03i16\x12/\n\x03i32\x18\x05 \x01(\x0b\x32\x1b.spark.connect.DataType.I32H\x00R\x03i32\x12/\n\x03i64\x18\x07 \x01(\x0b\x32\x1b.spark.connect.DataType.I64H\x00R\x [...]
+    b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xce \n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01(\x0 [...]
 )
 
 
 _DATATYPE = DESCRIPTOR.message_types_by_name["DataType"]
 _DATATYPE_BOOLEAN = _DATATYPE.nested_types_by_name["Boolean"]
-_DATATYPE_I8 = _DATATYPE.nested_types_by_name["I8"]
-_DATATYPE_I16 = _DATATYPE.nested_types_by_name["I16"]
-_DATATYPE_I32 = _DATATYPE.nested_types_by_name["I32"]
-_DATATYPE_I64 = _DATATYPE.nested_types_by_name["I64"]
-_DATATYPE_FP32 = _DATATYPE.nested_types_by_name["FP32"]
-_DATATYPE_FP64 = _DATATYPE.nested_types_by_name["FP64"]
+_DATATYPE_BYTE = _DATATYPE.nested_types_by_name["Byte"]
+_DATATYPE_SHORT = _DATATYPE.nested_types_by_name["Short"]
+_DATATYPE_INTEGER = _DATATYPE.nested_types_by_name["Integer"]
+_DATATYPE_LONG = _DATATYPE.nested_types_by_name["Long"]
+_DATATYPE_FLOAT = _DATATYPE.nested_types_by_name["Float"]
+_DATATYPE_DOUBLE = _DATATYPE.nested_types_by_name["Double"]
 _DATATYPE_STRING = _DATATYPE.nested_types_by_name["String"]
 _DATATYPE_BINARY = _DATATYPE.nested_types_by_name["Binary"]
+_DATATYPE_NULL = _DATATYPE.nested_types_by_name["NULL"]
 _DATATYPE_TIMESTAMP = _DATATYPE.nested_types_by_name["Timestamp"]
 _DATATYPE_DATE = _DATATYPE.nested_types_by_name["Date"]
-_DATATYPE_TIME = _DATATYPE.nested_types_by_name["Time"]
-_DATATYPE_TIMESTAMPTZ = _DATATYPE.nested_types_by_name["TimestampTZ"]
-_DATATYPE_INTERVALYEAR = _DATATYPE.nested_types_by_name["IntervalYear"]
-_DATATYPE_INTERVALDAY = _DATATYPE.nested_types_by_name["IntervalDay"]
+_DATATYPE_TIMESTAMPNTZ = _DATATYPE.nested_types_by_name["TimestampNTZ"]
+_DATATYPE_CALENDARINTERVAL = _DATATYPE.nested_types_by_name["CalendarInterval"]
+_DATATYPE_YEARMONTHINTERVAL = _DATATYPE.nested_types_by_name["YearMonthInterval"]
+_DATATYPE_DAYTIMEINTERVAL = _DATATYPE.nested_types_by_name["DayTimeInterval"]
 _DATATYPE_UUID = _DATATYPE.nested_types_by_name["UUID"]
-_DATATYPE_FIXEDCHAR = _DATATYPE.nested_types_by_name["FixedChar"]
+_DATATYPE_CHAR = _DATATYPE.nested_types_by_name["Char"]
 _DATATYPE_VARCHAR = _DATATYPE.nested_types_by_name["VarChar"]
 _DATATYPE_FIXEDBINARY = _DATATYPE.nested_types_by_name["FixedBinary"]
 _DATATYPE_DECIMAL = _DATATYPE.nested_types_by_name["Decimal"]
 _DATATYPE_STRUCTFIELD = _DATATYPE.nested_types_by_name["StructField"]
 _DATATYPE_STRUCTFIELD_METADATAENTRY = _DATATYPE_STRUCTFIELD.nested_types_by_name["MetadataEntry"]
 _DATATYPE_STRUCT = _DATATYPE.nested_types_by_name["Struct"]
-_DATATYPE_LIST = _DATATYPE.nested_types_by_name["List"]
+_DATATYPE_ARRAY = _DATATYPE.nested_types_by_name["Array"]
 _DATATYPE_MAP = _DATATYPE.nested_types_by_name["Map"]
 DataType = _reflection.GeneratedProtocolMessageType(
     "DataType",
@@ -73,58 +74,58 @@ DataType = _reflection.GeneratedProtocolMessageType(
                 # @@protoc_insertion_point(class_scope:spark.connect.DataType.Boolean)
             },
         ),
-        "I8": _reflection.GeneratedProtocolMessageType(
-            "I8",
+        "Byte": _reflection.GeneratedProtocolMessageType(
+            "Byte",
             (_message.Message,),
             {
-                "DESCRIPTOR": _DATATYPE_I8,
+                "DESCRIPTOR": _DATATYPE_BYTE,
                 "__module__": "spark.connect.types_pb2"
-                # @@protoc_insertion_point(class_scope:spark.connect.DataType.I8)
+                # @@protoc_insertion_point(class_scope:spark.connect.DataType.Byte)
             },
         ),
-        "I16": _reflection.GeneratedProtocolMessageType(
-            "I16",
+        "Short": _reflection.GeneratedProtocolMessageType(
+            "Short",
             (_message.Message,),
             {
-                "DESCRIPTOR": _DATATYPE_I16,
+                "DESCRIPTOR": _DATATYPE_SHORT,
                 "__module__": "spark.connect.types_pb2"
-                # @@protoc_insertion_point(class_scope:spark.connect.DataType.I16)
+                # @@protoc_insertion_point(class_scope:spark.connect.DataType.Short)
             },
         ),
-        "I32": _reflection.GeneratedProtocolMessageType(
-            "I32",
+        "Integer": _reflection.GeneratedProtocolMessageType(
+            "Integer",
             (_message.Message,),
             {
-                "DESCRIPTOR": _DATATYPE_I32,
+                "DESCRIPTOR": _DATATYPE_INTEGER,
                 "__module__": "spark.connect.types_pb2"
-                # @@protoc_insertion_point(class_scope:spark.connect.DataType.I32)
+                # @@protoc_insertion_point(class_scope:spark.connect.DataType.Integer)
             },
         ),
-        "I64": _reflection.GeneratedProtocolMessageType(
-            "I64",
+        "Long": _reflection.GeneratedProtocolMessageType(
+            "Long",
             (_message.Message,),
             {
-                "DESCRIPTOR": _DATATYPE_I64,
+                "DESCRIPTOR": _DATATYPE_LONG,
                 "__module__": "spark.connect.types_pb2"
-                # @@protoc_insertion_point(class_scope:spark.connect.DataType.I64)
+                # @@protoc_insertion_point(class_scope:spark.connect.DataType.Long)
             },
         ),
-        "FP32": _reflection.GeneratedProtocolMessageType(
-            "FP32",
+        "Float": _reflection.GeneratedProtocolMessageType(
+            "Float",
             (_message.Message,),
             {
-                "DESCRIPTOR": _DATATYPE_FP32,
+                "DESCRIPTOR": _DATATYPE_FLOAT,
                 "__module__": "spark.connect.types_pb2"
-                # @@protoc_insertion_point(class_scope:spark.connect.DataType.FP32)
+                # @@protoc_insertion_point(class_scope:spark.connect.DataType.Float)
             },
         ),
-        "FP64": _reflection.GeneratedProtocolMessageType(
-            "FP64",
+        "Double": _reflection.GeneratedProtocolMessageType(
+            "Double",
             (_message.Message,),
             {
-                "DESCRIPTOR": _DATATYPE_FP64,
+                "DESCRIPTOR": _DATATYPE_DOUBLE,
                 "__module__": "spark.connect.types_pb2"
-                # @@protoc_insertion_point(class_scope:spark.connect.DataType.FP64)
+                # @@protoc_insertion_point(class_scope:spark.connect.DataType.Double)
             },
         ),
         "String": _reflection.GeneratedProtocolMessageType(
@@ -145,6 +146,15 @@ DataType = _reflection.GeneratedProtocolMessageType(
                 # @@protoc_insertion_point(class_scope:spark.connect.DataType.Binary)
             },
         ),
+        "NULL": _reflection.GeneratedProtocolMessageType(
+            "NULL",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _DATATYPE_NULL,
+                "__module__": "spark.connect.types_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.DataType.NULL)
+            },
+        ),
         "Timestamp": _reflection.GeneratedProtocolMessageType(
             "Timestamp",
             (_message.Message,),
@@ -163,40 +173,40 @@ DataType = _reflection.GeneratedProtocolMessageType(
                 # @@protoc_insertion_point(class_scope:spark.connect.DataType.Date)
             },
         ),
-        "Time": _reflection.GeneratedProtocolMessageType(
-            "Time",
+        "TimestampNTZ": _reflection.GeneratedProtocolMessageType(
+            "TimestampNTZ",
             (_message.Message,),
             {
-                "DESCRIPTOR": _DATATYPE_TIME,
+                "DESCRIPTOR": _DATATYPE_TIMESTAMPNTZ,
                 "__module__": "spark.connect.types_pb2"
-                # @@protoc_insertion_point(class_scope:spark.connect.DataType.Time)
+                # @@protoc_insertion_point(class_scope:spark.connect.DataType.TimestampNTZ)
             },
         ),
-        "TimestampTZ": _reflection.GeneratedProtocolMessageType(
-            "TimestampTZ",
+        "CalendarInterval": _reflection.GeneratedProtocolMessageType(
+            "CalendarInterval",
             (_message.Message,),
             {
-                "DESCRIPTOR": _DATATYPE_TIMESTAMPTZ,
+                "DESCRIPTOR": _DATATYPE_CALENDARINTERVAL,
                 "__module__": "spark.connect.types_pb2"
-                # @@protoc_insertion_point(class_scope:spark.connect.DataType.TimestampTZ)
+                # @@protoc_insertion_point(class_scope:spark.connect.DataType.CalendarInterval)
             },
         ),
-        "IntervalYear": _reflection.GeneratedProtocolMessageType(
-            "IntervalYear",
+        "YearMonthInterval": _reflection.GeneratedProtocolMessageType(
+            "YearMonthInterval",
             (_message.Message,),
             {
-                "DESCRIPTOR": _DATATYPE_INTERVALYEAR,
+                "DESCRIPTOR": _DATATYPE_YEARMONTHINTERVAL,
                 "__module__": "spark.connect.types_pb2"
-                # @@protoc_insertion_point(class_scope:spark.connect.DataType.IntervalYear)
+                # @@protoc_insertion_point(class_scope:spark.connect.DataType.YearMonthInterval)
             },
         ),
-        "IntervalDay": _reflection.GeneratedProtocolMessageType(
-            "IntervalDay",
+        "DayTimeInterval": _reflection.GeneratedProtocolMessageType(
+            "DayTimeInterval",
             (_message.Message,),
             {
-                "DESCRIPTOR": _DATATYPE_INTERVALDAY,
+                "DESCRIPTOR": _DATATYPE_DAYTIMEINTERVAL,
                 "__module__": "spark.connect.types_pb2"
-                # @@protoc_insertion_point(class_scope:spark.connect.DataType.IntervalDay)
+                # @@protoc_insertion_point(class_scope:spark.connect.DataType.DayTimeInterval)
             },
         ),
         "UUID": _reflection.GeneratedProtocolMessageType(
@@ -208,13 +218,13 @@ DataType = _reflection.GeneratedProtocolMessageType(
                 # @@protoc_insertion_point(class_scope:spark.connect.DataType.UUID)
             },
         ),
-        "FixedChar": _reflection.GeneratedProtocolMessageType(
-            "FixedChar",
+        "Char": _reflection.GeneratedProtocolMessageType(
+            "Char",
             (_message.Message,),
             {
-                "DESCRIPTOR": _DATATYPE_FIXEDCHAR,
+                "DESCRIPTOR": _DATATYPE_CHAR,
                 "__module__": "spark.connect.types_pb2"
-                # @@protoc_insertion_point(class_scope:spark.connect.DataType.FixedChar)
+                # @@protoc_insertion_point(class_scope:spark.connect.DataType.Char)
             },
         ),
         "VarChar": _reflection.GeneratedProtocolMessageType(
@@ -271,13 +281,13 @@ DataType = _reflection.GeneratedProtocolMessageType(
                 # @@protoc_insertion_point(class_scope:spark.connect.DataType.Struct)
             },
         ),
-        "List": _reflection.GeneratedProtocolMessageType(
-            "List",
+        "Array": _reflection.GeneratedProtocolMessageType(
+            "Array",
             (_message.Message,),
             {
-                "DESCRIPTOR": _DATATYPE_LIST,
+                "DESCRIPTOR": _DATATYPE_ARRAY,
                 "__module__": "spark.connect.types_pb2"
-                # @@protoc_insertion_point(class_scope:spark.connect.DataType.List)
+                # @@protoc_insertion_point(class_scope:spark.connect.DataType.Array)
             },
         ),
         "Map": _reflection.GeneratedProtocolMessageType(
@@ -296,29 +306,30 @@ DataType = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(DataType)
 _sym_db.RegisterMessage(DataType.Boolean)
-_sym_db.RegisterMessage(DataType.I8)
-_sym_db.RegisterMessage(DataType.I16)
-_sym_db.RegisterMessage(DataType.I32)
-_sym_db.RegisterMessage(DataType.I64)
-_sym_db.RegisterMessage(DataType.FP32)
-_sym_db.RegisterMessage(DataType.FP64)
+_sym_db.RegisterMessage(DataType.Byte)
+_sym_db.RegisterMessage(DataType.Short)
+_sym_db.RegisterMessage(DataType.Integer)
+_sym_db.RegisterMessage(DataType.Long)
+_sym_db.RegisterMessage(DataType.Float)
+_sym_db.RegisterMessage(DataType.Double)
 _sym_db.RegisterMessage(DataType.String)
 _sym_db.RegisterMessage(DataType.Binary)
+_sym_db.RegisterMessage(DataType.NULL)
 _sym_db.RegisterMessage(DataType.Timestamp)
 _sym_db.RegisterMessage(DataType.Date)
-_sym_db.RegisterMessage(DataType.Time)
-_sym_db.RegisterMessage(DataType.TimestampTZ)
-_sym_db.RegisterMessage(DataType.IntervalYear)
-_sym_db.RegisterMessage(DataType.IntervalDay)
+_sym_db.RegisterMessage(DataType.TimestampNTZ)
+_sym_db.RegisterMessage(DataType.CalendarInterval)
+_sym_db.RegisterMessage(DataType.YearMonthInterval)
+_sym_db.RegisterMessage(DataType.DayTimeInterval)
 _sym_db.RegisterMessage(DataType.UUID)
-_sym_db.RegisterMessage(DataType.FixedChar)
+_sym_db.RegisterMessage(DataType.Char)
 _sym_db.RegisterMessage(DataType.VarChar)
 _sym_db.RegisterMessage(DataType.FixedBinary)
 _sym_db.RegisterMessage(DataType.Decimal)
 _sym_db.RegisterMessage(DataType.StructField)
 _sym_db.RegisterMessage(DataType.StructField.MetadataEntry)
 _sym_db.RegisterMessage(DataType.Struct)
-_sym_db.RegisterMessage(DataType.List)
+_sym_db.RegisterMessage(DataType.Array)
 _sym_db.RegisterMessage(DataType.Map)
 
 if _descriptor._USE_C_DESCRIPTORS == False:
@@ -328,55 +339,57 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _DATATYPE_STRUCTFIELD_METADATAENTRY._options = None
     _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_options = b"8\001"
     _DATATYPE._serialized_start = 45
-    _DATATYPE._serialized_end = 3694
-    _DATATYPE_BOOLEAN._serialized_start = 1461
-    _DATATYPE_BOOLEAN._serialized_end = 1528
-    _DATATYPE_I8._serialized_start = 1530
-    _DATATYPE_I8._serialized_end = 1592
-    _DATATYPE_I16._serialized_start = 1594
-    _DATATYPE_I16._serialized_end = 1657
-    _DATATYPE_I32._serialized_start = 1659
-    _DATATYPE_I32._serialized_end = 1722
-    _DATATYPE_I64._serialized_start = 1724
-    _DATATYPE_I64._serialized_end = 1787
-    _DATATYPE_FP32._serialized_start = 1789
-    _DATATYPE_FP32._serialized_end = 1853
-    _DATATYPE_FP64._serialized_start = 1855
-    _DATATYPE_FP64._serialized_end = 1919
-    _DATATYPE_STRING._serialized_start = 1921
-    _DATATYPE_STRING._serialized_end = 1987
-    _DATATYPE_BINARY._serialized_start = 1989
-    _DATATYPE_BINARY._serialized_end = 2055
-    _DATATYPE_TIMESTAMP._serialized_start = 2057
-    _DATATYPE_TIMESTAMP._serialized_end = 2126
-    _DATATYPE_DATE._serialized_start = 2128
-    _DATATYPE_DATE._serialized_end = 2192
-    _DATATYPE_TIME._serialized_start = 2194
-    _DATATYPE_TIME._serialized_end = 2258
-    _DATATYPE_TIMESTAMPTZ._serialized_start = 2260
-    _DATATYPE_TIMESTAMPTZ._serialized_end = 2331
-    _DATATYPE_INTERVALYEAR._serialized_start = 2333
-    _DATATYPE_INTERVALYEAR._serialized_end = 2405
-    _DATATYPE_INTERVALDAY._serialized_start = 2407
-    _DATATYPE_INTERVALDAY._serialized_end = 2478
-    _DATATYPE_UUID._serialized_start = 2480
-    _DATATYPE_UUID._serialized_end = 2544
-    _DATATYPE_FIXEDCHAR._serialized_start = 2546
-    _DATATYPE_FIXEDCHAR._serialized_end = 2639
-    _DATATYPE_VARCHAR._serialized_start = 2641
-    _DATATYPE_VARCHAR._serialized_end = 2732
-    _DATATYPE_FIXEDBINARY._serialized_start = 2734
-    _DATATYPE_FIXEDBINARY._serialized_end = 2829
-    _DATATYPE_DECIMAL._serialized_start = 2831
-    _DATATYPE_DECIMAL._serialized_end = 2950
-    _DATATYPE_STRUCTFIELD._serialized_start = 2953
-    _DATATYPE_STRUCTFIELD._serialized_end = 3199
-    _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_start = 3140
-    _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_end = 3199
-    _DATATYPE_STRUCT._serialized_start = 3201
-    _DATATYPE_STRUCT._serialized_end = 3328
-    _DATATYPE_LIST._serialized_start = 3331
-    _DATATYPE_LIST._serialized_end = 3491
-    _DATATYPE_MAP._serialized_start = 3494
-    _DATATYPE_MAP._serialized_end = 3686
+    _DATATYPE._serialized_end = 4219
+    _DATATYPE_BOOLEAN._serialized_start = 1612
+    _DATATYPE_BOOLEAN._serialized_end = 1679
+    _DATATYPE_BYTE._serialized_start = 1681
+    _DATATYPE_BYTE._serialized_end = 1745
+    _DATATYPE_SHORT._serialized_start = 1747
+    _DATATYPE_SHORT._serialized_end = 1812
+    _DATATYPE_INTEGER._serialized_start = 1814
+    _DATATYPE_INTEGER._serialized_end = 1881
+    _DATATYPE_LONG._serialized_start = 1883
+    _DATATYPE_LONG._serialized_end = 1947
+    _DATATYPE_FLOAT._serialized_start = 1949
+    _DATATYPE_FLOAT._serialized_end = 2014
+    _DATATYPE_DOUBLE._serialized_start = 2016
+    _DATATYPE_DOUBLE._serialized_end = 2082
+    _DATATYPE_STRING._serialized_start = 2084
+    _DATATYPE_STRING._serialized_end = 2150
+    _DATATYPE_BINARY._serialized_start = 2152
+    _DATATYPE_BINARY._serialized_end = 2218
+    _DATATYPE_NULL._serialized_start = 2220
+    _DATATYPE_NULL._serialized_end = 2284
+    _DATATYPE_TIMESTAMP._serialized_start = 2286
+    _DATATYPE_TIMESTAMP._serialized_end = 2355
+    _DATATYPE_DATE._serialized_start = 2357
+    _DATATYPE_DATE._serialized_end = 2421
+    _DATATYPE_TIMESTAMPNTZ._serialized_start = 2423
+    _DATATYPE_TIMESTAMPNTZ._serialized_end = 2495
+    _DATATYPE_CALENDARINTERVAL._serialized_start = 2497
+    _DATATYPE_CALENDARINTERVAL._serialized_end = 2573
+    _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2576
+    _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2755
+    _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2758
+    _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2935
+    _DATATYPE_UUID._serialized_start = 2937
+    _DATATYPE_UUID._serialized_end = 3001
+    _DATATYPE_CHAR._serialized_start = 3003
+    _DATATYPE_CHAR._serialized_end = 3091
+    _DATATYPE_VARCHAR._serialized_start = 3093
+    _DATATYPE_VARCHAR._serialized_end = 3184
+    _DATATYPE_FIXEDBINARY._serialized_start = 3186
+    _DATATYPE_FIXEDBINARY._serialized_end = 3281
+    _DATATYPE_DECIMAL._serialized_start = 3284
+    _DATATYPE_DECIMAL._serialized_end = 3437
+    _DATATYPE_STRUCTFIELD._serialized_start = 3440
+    _DATATYPE_STRUCTFIELD._serialized_end = 3695
+    _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_start = 3636
+    _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_end = 3695
+    _DATATYPE_STRUCT._serialized_start = 3697
+    _DATATYPE_STRUCT._serialized_end = 3824
+    _DATATYPE_ARRAY._serialized_start = 3827
+    _DATATYPE_ARRAY._serialized_end = 3989
+    _DATATYPE_MAP._serialized_start = 3992
+    _DATATYPE_MAP._serialized_end = 4211
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/types_pb2.pyi b/python/pyspark/sql/connect/proto/types_pb2.pyi
index 3bf36fc790c..647f625659b 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/types_pb2.pyi
@@ -39,6 +39,7 @@ import google.protobuf.descriptor
 import google.protobuf.internal.containers
 import google.protobuf.message
 import sys
+import typing
 
 if sys.version_info >= (3, 8):
     import typing as typing_extensions
@@ -71,7 +72,7 @@ class DataType(google.protobuf.message.Message):
             ],
         ) -> None: ...
 
-    class I8(google.protobuf.message.Message):
+    class Byte(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
         TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -88,7 +89,7 @@ class DataType(google.protobuf.message.Message):
             ],
         ) -> None: ...
 
-    class I16(google.protobuf.message.Message):
+    class Short(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
         TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -105,7 +106,7 @@ class DataType(google.protobuf.message.Message):
             ],
         ) -> None: ...
 
-    class I32(google.protobuf.message.Message):
+    class Integer(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
         TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -122,7 +123,7 @@ class DataType(google.protobuf.message.Message):
             ],
         ) -> None: ...
 
-    class I64(google.protobuf.message.Message):
+    class Long(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
         TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -139,7 +140,7 @@ class DataType(google.protobuf.message.Message):
             ],
         ) -> None: ...
 
-    class FP32(google.protobuf.message.Message):
+    class Float(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
         TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -156,7 +157,7 @@ class DataType(google.protobuf.message.Message):
             ],
         ) -> None: ...
 
-    class FP64(google.protobuf.message.Message):
+    class Double(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
         TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -207,6 +208,23 @@ class DataType(google.protobuf.message.Message):
             ],
         ) -> None: ...
 
+    class NULL(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
+        type_variation_reference: builtins.int
+        def __init__(
+            self,
+            *,
+            type_variation_reference: builtins.int = ...,
+        ) -> None: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal[
+                "type_variation_reference", b"type_variation_reference"
+            ],
+        ) -> None: ...
+
     class Timestamp(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
@@ -241,7 +259,7 @@ class DataType(google.protobuf.message.Message):
             ],
         ) -> None: ...
 
-    class Time(google.protobuf.message.Message):
+    class TimestampNTZ(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
         TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -258,7 +276,7 @@ class DataType(google.protobuf.message.Message):
             ],
         ) -> None: ...
 
-    class TimestampTZ(google.protobuf.message.Message):
+    class CalendarInterval(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
         TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
@@ -275,39 +293,111 @@ class DataType(google.protobuf.message.Message):
             ],
         ) -> None: ...
 
-    class IntervalYear(google.protobuf.message.Message):
+    class YearMonthInterval(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
+        START_FIELD_FIELD_NUMBER: builtins.int
+        END_FIELD_FIELD_NUMBER: builtins.int
         TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
+        start_field: builtins.int
+        end_field: builtins.int
         type_variation_reference: builtins.int
         def __init__(
             self,
             *,
+            start_field: builtins.int | None = ...,
+            end_field: builtins.int | None = ...,
             type_variation_reference: builtins.int = ...,
         ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal[
+                "_end_field",
+                b"_end_field",
+                "_start_field",
+                b"_start_field",
+                "end_field",
+                b"end_field",
+                "start_field",
+                b"start_field",
+            ],
+        ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
-                "type_variation_reference", b"type_variation_reference"
+                "_end_field",
+                b"_end_field",
+                "_start_field",
+                b"_start_field",
+                "end_field",
+                b"end_field",
+                "start_field",
+                b"start_field",
+                "type_variation_reference",
+                b"type_variation_reference",
             ],
         ) -> None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_end_field", b"_end_field"]
+        ) -> typing_extensions.Literal["end_field"] | None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_start_field", b"_start_field"]
+        ) -> typing_extensions.Literal["start_field"] | None: ...
 
-    class IntervalDay(google.protobuf.message.Message):
+    class DayTimeInterval(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
+        START_FIELD_FIELD_NUMBER: builtins.int
+        END_FIELD_FIELD_NUMBER: builtins.int
         TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
+        start_field: builtins.int
+        end_field: builtins.int
         type_variation_reference: builtins.int
         def __init__(
             self,
             *,
+            start_field: builtins.int | None = ...,
+            end_field: builtins.int | None = ...,
             type_variation_reference: builtins.int = ...,
         ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal[
+                "_end_field",
+                b"_end_field",
+                "_start_field",
+                b"_start_field",
+                "end_field",
+                b"end_field",
+                "start_field",
+                b"start_field",
+            ],
+        ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
-                "type_variation_reference", b"type_variation_reference"
+                "_end_field",
+                b"_end_field",
+                "_start_field",
+                b"_start_field",
+                "end_field",
+                b"end_field",
+                "start_field",
+                b"start_field",
+                "type_variation_reference",
+                b"type_variation_reference",
             ],
         ) -> None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_end_field", b"_end_field"]
+        ) -> typing_extensions.Literal["end_field"] | None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_start_field", b"_start_field"]
+        ) -> typing_extensions.Literal["start_field"] | None: ...
 
     class UUID(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -326,7 +416,7 @@ class DataType(google.protobuf.message.Message):
             ],
         ) -> None: ...
 
-    class FixedChar(google.protobuf.message.Message):
+    class Char(google.protobuf.message.Message):
         """Start compound types."""
 
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -400,13 +490,30 @@ class DataType(google.protobuf.message.Message):
         def __init__(
             self,
             *,
-            scale: builtins.int = ...,
-            precision: builtins.int = ...,
+            scale: builtins.int | None = ...,
+            precision: builtins.int | None = ...,
             type_variation_reference: builtins.int = ...,
         ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal[
+                "_precision",
+                b"_precision",
+                "_scale",
+                b"_scale",
+                "precision",
+                b"precision",
+                "scale",
+                b"scale",
+            ],
+        ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
+                "_precision",
+                b"_precision",
+                "_scale",
+                b"_scale",
                 "precision",
                 b"precision",
                 "scale",
@@ -415,6 +522,14 @@ class DataType(google.protobuf.message.Message):
                 b"type_variation_reference",
             ],
         ) -> None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_precision", b"_precision"]
+        ) -> typing_extensions.Literal["precision"] | None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_scale", b"_scale"]
+        ) -> typing_extensions.Literal["scale"] | None: ...
 
     class StructField(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -436,13 +551,13 @@ class DataType(google.protobuf.message.Message):
                 self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
             ) -> None: ...
 
-        TYPE_FIELD_NUMBER: builtins.int
         NAME_FIELD_NUMBER: builtins.int
+        DATA_TYPE_FIELD_NUMBER: builtins.int
         NULLABLE_FIELD_NUMBER: builtins.int
         METADATA_FIELD_NUMBER: builtins.int
-        @property
-        def type(self) -> global___DataType: ...
         name: builtins.str
+        @property
+        def data_type(self) -> global___DataType: ...
         nullable: builtins.bool
         @property
         def metadata(
@@ -451,18 +566,25 @@ class DataType(google.protobuf.message.Message):
         def __init__(
             self,
             *,
-            type: global___DataType | None = ...,
             name: builtins.str = ...,
+            data_type: global___DataType | None = ...,
             nullable: builtins.bool = ...,
             metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
         ) -> None: ...
         def HasField(
-            self, field_name: typing_extensions.Literal["type", b"type"]
+            self, field_name: typing_extensions.Literal["data_type", b"data_type"]
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
-                "metadata", b"metadata", "name", b"name", "nullable", b"nullable", "type", b"type"
+                "data_type",
+                b"data_type",
+                "metadata",
+                b"metadata",
+                "name",
+                b"name",
+                "nullable",
+                b"nullable",
             ],
         ) -> None: ...
 
@@ -491,33 +613,33 @@ class DataType(google.protobuf.message.Message):
             ],
         ) -> None: ...
 
-    class List(google.protobuf.message.Message):
+    class Array(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-        DATATYPE_FIELD_NUMBER: builtins.int
+        ELEMENT_TYPE_FIELD_NUMBER: builtins.int
+        CONTAINS_NULL_FIELD_NUMBER: builtins.int
         TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
-        ELEMENT_NULLABLE_FIELD_NUMBER: builtins.int
         @property
-        def DataType(self) -> global___DataType: ...
+        def element_type(self) -> global___DataType: ...
+        contains_null: builtins.bool
         type_variation_reference: builtins.int
-        element_nullable: builtins.bool
         def __init__(
             self,
             *,
-            DataType: global___DataType | None = ...,
+            element_type: global___DataType | None = ...,
+            contains_null: builtins.bool = ...,
             type_variation_reference: builtins.int = ...,
-            element_nullable: builtins.bool = ...,
         ) -> None: ...
         def HasField(
-            self, field_name: typing_extensions.Literal["DataType", b"DataType"]
+            self, field_name: typing_extensions.Literal["element_type", b"element_type"]
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
-                "DataType",
-                b"DataType",
-                "element_nullable",
-                b"element_nullable",
+                "contains_null",
+                b"contains_null",
+                "element_type",
+                b"element_type",
                 "type_variation_reference",
                 b"type_variation_reference",
             ],
@@ -526,276 +648,293 @@ class DataType(google.protobuf.message.Message):
     class Map(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-        KEY_FIELD_NUMBER: builtins.int
-        VALUE_FIELD_NUMBER: builtins.int
+        KEY_TYPE_FIELD_NUMBER: builtins.int
+        VALUE_TYPE_FIELD_NUMBER: builtins.int
+        VALUE_CONTAINS_NULL_FIELD_NUMBER: builtins.int
         TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
-        VALUE_NULLABLE_FIELD_NUMBER: builtins.int
         @property
-        def key(self) -> global___DataType: ...
+        def key_type(self) -> global___DataType: ...
         @property
-        def value(self) -> global___DataType: ...
+        def value_type(self) -> global___DataType: ...
+        value_contains_null: builtins.bool
         type_variation_reference: builtins.int
-        value_nullable: builtins.bool
         def __init__(
             self,
             *,
-            key: global___DataType | None = ...,
-            value: global___DataType | None = ...,
+            key_type: global___DataType | None = ...,
+            value_type: global___DataType | None = ...,
+            value_contains_null: builtins.bool = ...,
             type_variation_reference: builtins.int = ...,
-            value_nullable: builtins.bool = ...,
         ) -> None: ...
         def HasField(
-            self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
+            self,
+            field_name: typing_extensions.Literal[
+                "key_type", b"key_type", "value_type", b"value_type"
+            ],
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
-                "key",
-                b"key",
+                "key_type",
+                b"key_type",
                 "type_variation_reference",
                 b"type_variation_reference",
-                "value",
-                b"value",
-                "value_nullable",
-                b"value_nullable",
+                "value_contains_null",
+                b"value_contains_null",
+                "value_type",
+                b"value_type",
             ],
         ) -> None: ...
 
-    BOOL_FIELD_NUMBER: builtins.int
-    I8_FIELD_NUMBER: builtins.int
-    I16_FIELD_NUMBER: builtins.int
-    I32_FIELD_NUMBER: builtins.int
-    I64_FIELD_NUMBER: builtins.int
-    FP32_FIELD_NUMBER: builtins.int
-    FP64_FIELD_NUMBER: builtins.int
-    STRING_FIELD_NUMBER: builtins.int
+    NULL_FIELD_NUMBER: builtins.int
     BINARY_FIELD_NUMBER: builtins.int
-    TIMESTAMP_FIELD_NUMBER: builtins.int
-    DATE_FIELD_NUMBER: builtins.int
-    TIME_FIELD_NUMBER: builtins.int
-    INTERVAL_YEAR_FIELD_NUMBER: builtins.int
-    INTERVAL_DAY_FIELD_NUMBER: builtins.int
-    TIMESTAMP_TZ_FIELD_NUMBER: builtins.int
-    UUID_FIELD_NUMBER: builtins.int
-    FIXED_CHAR_FIELD_NUMBER: builtins.int
-    VARCHAR_FIELD_NUMBER: builtins.int
-    FIXED_BINARY_FIELD_NUMBER: builtins.int
+    BOOLEAN_FIELD_NUMBER: builtins.int
+    BYTE_FIELD_NUMBER: builtins.int
+    SHORT_FIELD_NUMBER: builtins.int
+    INTEGER_FIELD_NUMBER: builtins.int
+    LONG_FIELD_NUMBER: builtins.int
+    FLOAT_FIELD_NUMBER: builtins.int
+    DOUBLE_FIELD_NUMBER: builtins.int
     DECIMAL_FIELD_NUMBER: builtins.int
+    STRING_FIELD_NUMBER: builtins.int
+    CHAR_FIELD_NUMBER: builtins.int
+    VAR_CHAR_FIELD_NUMBER: builtins.int
+    DATE_FIELD_NUMBER: builtins.int
+    TIMESTAMP_FIELD_NUMBER: builtins.int
+    TIMESTAMP_NTZ_FIELD_NUMBER: builtins.int
+    CALENDAR_INTERVAL_FIELD_NUMBER: builtins.int
+    YEAR_MONTH_INTERVAL_FIELD_NUMBER: builtins.int
+    DAY_TIME_INTERVAL_FIELD_NUMBER: builtins.int
+    ARRAY_FIELD_NUMBER: builtins.int
     STRUCT_FIELD_NUMBER: builtins.int
-    LIST_FIELD_NUMBER: builtins.int
     MAP_FIELD_NUMBER: builtins.int
+    UUID_FIELD_NUMBER: builtins.int
+    FIXED_BINARY_FIELD_NUMBER: builtins.int
     USER_DEFINED_TYPE_REFERENCE_FIELD_NUMBER: builtins.int
     @property
-    def bool(self) -> global___DataType.Boolean: ...
+    def null(self) -> global___DataType.NULL: ...
     @property
-    def i8(self) -> global___DataType.I8: ...
+    def binary(self) -> global___DataType.Binary: ...
     @property
-    def i16(self) -> global___DataType.I16: ...
+    def boolean(self) -> global___DataType.Boolean: ...
     @property
-    def i32(self) -> global___DataType.I32: ...
+    def byte(self) -> global___DataType.Byte:
+        """Numeric types"""
     @property
-    def i64(self) -> global___DataType.I64: ...
+    def short(self) -> global___DataType.Short: ...
     @property
-    def fp32(self) -> global___DataType.FP32: ...
+    def integer(self) -> global___DataType.Integer: ...
     @property
-    def fp64(self) -> global___DataType.FP64: ...
+    def long(self) -> global___DataType.Long: ...
     @property
-    def string(self) -> global___DataType.String: ...
+    def float(self) -> global___DataType.Float: ...
     @property
-    def binary(self) -> global___DataType.Binary: ...
+    def double(self) -> global___DataType.Double: ...
     @property
-    def timestamp(self) -> global___DataType.Timestamp: ...
+    def decimal(self) -> global___DataType.Decimal: ...
     @property
-    def date(self) -> global___DataType.Date: ...
+    def string(self) -> global___DataType.String:
+        """String types"""
     @property
-    def time(self) -> global___DataType.Time: ...
+    def char(self) -> global___DataType.Char: ...
     @property
-    def interval_year(self) -> global___DataType.IntervalYear: ...
+    def var_char(self) -> global___DataType.VarChar: ...
     @property
-    def interval_day(self) -> global___DataType.IntervalDay: ...
+    def date(self) -> global___DataType.Date:
+        """Datetime types"""
     @property
-    def timestamp_tz(self) -> global___DataType.TimestampTZ: ...
+    def timestamp(self) -> global___DataType.Timestamp: ...
     @property
-    def uuid(self) -> global___DataType.UUID: ...
+    def timestamp_ntz(self) -> global___DataType.TimestampNTZ: ...
     @property
-    def fixed_char(self) -> global___DataType.FixedChar: ...
+    def calendar_interval(self) -> global___DataType.CalendarInterval:
+        """Interval types"""
     @property
-    def varchar(self) -> global___DataType.VarChar: ...
+    def year_month_interval(self) -> global___DataType.YearMonthInterval: ...
     @property
-    def fixed_binary(self) -> global___DataType.FixedBinary: ...
+    def day_time_interval(self) -> global___DataType.DayTimeInterval: ...
     @property
-    def decimal(self) -> global___DataType.Decimal: ...
+    def array(self) -> global___DataType.Array:
+        """Complex types"""
     @property
     def struct(self) -> global___DataType.Struct: ...
     @property
-    def list(self) -> global___DataType.List: ...
-    @property
     def map(self) -> global___DataType.Map: ...
+    @property
+    def uuid(self) -> global___DataType.UUID: ...
+    @property
+    def fixed_binary(self) -> global___DataType.FixedBinary: ...
     user_defined_type_reference: builtins.int
     def __init__(
         self,
         *,
-        bool: global___DataType.Boolean | None = ...,
-        i8: global___DataType.I8 | None = ...,
-        i16: global___DataType.I16 | None = ...,
-        i32: global___DataType.I32 | None = ...,
-        i64: global___DataType.I64 | None = ...,
-        fp32: global___DataType.FP32 | None = ...,
-        fp64: global___DataType.FP64 | None = ...,
-        string: global___DataType.String | None = ...,
+        null: global___DataType.NULL | None = ...,
         binary: global___DataType.Binary | None = ...,
-        timestamp: global___DataType.Timestamp | None = ...,
-        date: global___DataType.Date | None = ...,
-        time: global___DataType.Time | None = ...,
-        interval_year: global___DataType.IntervalYear | None = ...,
-        interval_day: global___DataType.IntervalDay | None = ...,
-        timestamp_tz: global___DataType.TimestampTZ | None = ...,
-        uuid: global___DataType.UUID | None = ...,
-        fixed_char: global___DataType.FixedChar | None = ...,
-        varchar: global___DataType.VarChar | None = ...,
-        fixed_binary: global___DataType.FixedBinary | None = ...,
+        boolean: global___DataType.Boolean | None = ...,
+        byte: global___DataType.Byte | None = ...,
+        short: global___DataType.Short | None = ...,
+        integer: global___DataType.Integer | None = ...,
+        long: global___DataType.Long | None = ...,
+        float: global___DataType.Float | None = ...,
+        double: global___DataType.Double | None = ...,
         decimal: global___DataType.Decimal | None = ...,
+        string: global___DataType.String | None = ...,
+        char: global___DataType.Char | None = ...,
+        var_char: global___DataType.VarChar | None = ...,
+        date: global___DataType.Date | None = ...,
+        timestamp: global___DataType.Timestamp | None = ...,
+        timestamp_ntz: global___DataType.TimestampNTZ | None = ...,
+        calendar_interval: global___DataType.CalendarInterval | None = ...,
+        year_month_interval: global___DataType.YearMonthInterval | None = ...,
+        day_time_interval: global___DataType.DayTimeInterval | None = ...,
+        array: global___DataType.Array | None = ...,
         struct: global___DataType.Struct | None = ...,
-        list: global___DataType.List | None = ...,
         map: global___DataType.Map | None = ...,
+        uuid: global___DataType.UUID | None = ...,
+        fixed_binary: global___DataType.FixedBinary | None = ...,
         user_defined_type_reference: builtins.int = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
+            "array",
+            b"array",
             "binary",
             b"binary",
-            "bool",
-            b"bool",
+            "boolean",
+            b"boolean",
+            "byte",
+            b"byte",
+            "calendar_interval",
+            b"calendar_interval",
+            "char",
+            b"char",
             "date",
             b"date",
+            "day_time_interval",
+            b"day_time_interval",
             "decimal",
             b"decimal",
+            "double",
+            b"double",
             "fixed_binary",
             b"fixed_binary",
-            "fixed_char",
-            b"fixed_char",
-            "fp32",
-            b"fp32",
-            "fp64",
-            b"fp64",
-            "i16",
-            b"i16",
-            "i32",
-            b"i32",
-            "i64",
-            b"i64",
-            "i8",
-            b"i8",
-            "interval_day",
-            b"interval_day",
-            "interval_year",
-            b"interval_year",
+            "float",
+            b"float",
+            "integer",
+            b"integer",
             "kind",
             b"kind",
-            "list",
-            b"list",
+            "long",
+            b"long",
             "map",
             b"map",
+            "null",
+            b"null",
+            "short",
+            b"short",
             "string",
             b"string",
             "struct",
             b"struct",
-            "time",
-            b"time",
             "timestamp",
             b"timestamp",
-            "timestamp_tz",
-            b"timestamp_tz",
+            "timestamp_ntz",
+            b"timestamp_ntz",
             "user_defined_type_reference",
             b"user_defined_type_reference",
             "uuid",
             b"uuid",
-            "varchar",
-            b"varchar",
+            "var_char",
+            b"var_char",
+            "year_month_interval",
+            b"year_month_interval",
         ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "array",
+            b"array",
             "binary",
             b"binary",
-            "bool",
-            b"bool",
+            "boolean",
+            b"boolean",
+            "byte",
+            b"byte",
+            "calendar_interval",
+            b"calendar_interval",
+            "char",
+            b"char",
             "date",
             b"date",
+            "day_time_interval",
+            b"day_time_interval",
             "decimal",
             b"decimal",
+            "double",
+            b"double",
             "fixed_binary",
             b"fixed_binary",
-            "fixed_char",
-            b"fixed_char",
-            "fp32",
-            b"fp32",
-            "fp64",
-            b"fp64",
-            "i16",
-            b"i16",
-            "i32",
-            b"i32",
-            "i64",
-            b"i64",
-            "i8",
-            b"i8",
-            "interval_day",
-            b"interval_day",
-            "interval_year",
-            b"interval_year",
+            "float",
+            b"float",
+            "integer",
+            b"integer",
             "kind",
             b"kind",
-            "list",
-            b"list",
+            "long",
+            b"long",
             "map",
             b"map",
+            "null",
+            b"null",
+            "short",
+            b"short",
             "string",
             b"string",
             "struct",
             b"struct",
-            "time",
-            b"time",
             "timestamp",
             b"timestamp",
-            "timestamp_tz",
-            b"timestamp_tz",
+            "timestamp_ntz",
+            b"timestamp_ntz",
             "user_defined_type_reference",
             b"user_defined_type_reference",
             "uuid",
             b"uuid",
-            "varchar",
-            b"varchar",
+            "var_char",
+            b"var_char",
+            "year_month_interval",
+            b"year_month_interval",
         ],
     ) -> None: ...
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["kind", b"kind"]
     ) -> typing_extensions.Literal[
-        "bool",
-        "i8",
-        "i16",
-        "i32",
-        "i64",
-        "fp32",
-        "fp64",
-        "string",
+        "null",
         "binary",
-        "timestamp",
-        "date",
-        "time",
-        "interval_year",
-        "interval_day",
-        "timestamp_tz",
-        "uuid",
-        "fixed_char",
-        "varchar",
-        "fixed_binary",
+        "boolean",
+        "byte",
+        "short",
+        "integer",
+        "long",
+        "float",
+        "double",
         "decimal",
+        "string",
+        "char",
+        "var_char",
+        "date",
+        "timestamp",
+        "timestamp_ntz",
+        "calendar_interval",
+        "year_month_interval",
+        "day_time_interval",
+        "array",
         "struct",
-        "list",
         "map",
+        "uuid",
+        "fixed_binary",
         "user_defined_type_reference",
     ] | None: ...
 
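
The container messages now adopt Spark's own vocabulary: `List` becomes `Array` with `element_type`/`contains_null`, and `Map` keys/values become `key_type`/`value_type`/`value_contains_null`. A minimal sketch of building nested types with the renamed fields, again assuming the regenerated module is importable:

```python
from pyspark.sql.connect.proto import types_pb2

# ARRAY<INT> with nullable elements, using the renamed fields
arr = types_pb2.DataType(
    array=types_pb2.DataType.Array(
        element_type=types_pb2.DataType(integer=types_pb2.DataType.Integer()),
        contains_null=True,
    )
)

# MAP<STRING, ARRAY<INT>> whose values may not be null
m = types_pb2.DataType(
    map=types_pb2.DataType.Map(
        key_type=types_pb2.DataType(string=types_pb2.DataType.String()),
        value_type=arr,
        value_contains_null=False,
    )
)

assert m.WhichOneof("kind") == "map"
assert m.map.value_type.array.contains_null
```
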
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index fa27b609941..ab454c53491 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -144,6 +144,73 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             schema,
         )
 
+        # test FloatType, DoubleType, DecimalType, StringType, BooleanType, NullType
+        query = """
+            SELECT * FROM VALUES
+            (float(1.0), double(1.0), 1.0, "1", true, NULL),
+            (float(2.0), double(2.0), 2.0, "2", false, NULL),
+            (float(3.0), double(3.0), NULL, "3", false, NULL)
+            AS tab(a, b, c, d, e, f)
+            """
+        self.assertEqual(
+            self.spark.sql(query).schema,
+            self.connect.sql(query).schema,
+        )
+
+        # test TimestampType, DateType
+        query = """
+            SELECT * FROM VALUES
+            (TIMESTAMP('2019-04-12 15:50:00'), DATE('2022-02-22')),
+            (TIMESTAMP('2019-04-12 15:50:00'), NULL),
+            (NULL, DATE('2022-02-22'))
+            AS tab(a, b)
+            """
+        self.assertEqual(
+            self.spark.sql(query).schema,
+            self.connect.sql(query).schema,
+        )
+
+        # test MapType
+        query = """
+            SELECT * FROM VALUES
+            (MAP('a', 'ab'), MAP('a', 'ab'), MAP(1, 2, 3, 4)),
+            (MAP('x', 'yz'), MAP('x', NULL), NULL),
+            (MAP('c', 'de'), NULL, MAP(-1, NULL, -3, -4))
+            AS tab(a, b, c)
+            """
+        self.assertEqual(
+            self.spark.sql(query).schema,
+            self.connect.sql(query).schema,
+        )
+
+        # test ArrayType
+        query = """
+            SELECT * FROM VALUES
+            (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3)),
+            (ARRAY('x', NULL), NULL, ARRAY(1, 3)),
+            (NULL, ARRAY(-1, -2, -3), ARRAY())
+            AS tab(a, b, c)
+            """
+        self.assertEqual(
+            self.spark.sql(query).schema,
+            self.connect.sql(query).schema,
+        )
+
+        # test StructType
+        query = """
+            SELECT STRUCT(a, b, c, d), STRUCT(e, f, g), STRUCT(STRUCT(a, b), STRUCT(h)) FROM VALUES
+            (float(1.0), double(1.0), 1.0, "1", true, NULL, ARRAY(1, NULL, 3), MAP(1, 2, 3, 4)),
+            (float(2.0), double(2.0), 2.0, "2", false, NULL, ARRAY(1, 3), MAP(1, NULL, 3, 4)),
+            (float(3.0), double(3.0), NULL, "3", false, NULL, ARRAY(NULL), NULL)
+            AS tab(a, b, c, d, e, f, g, h)
+            """
+        # compare __repr__() to ignore the metadata for now,
+        # since metadata is not yet supported in Connect
+        self.assertEqual(
+            self.spark.sql(query).schema.__repr__(),
+            self.connect.sql(query).schema.__repr__(),
+        )
+
     def test_simple_binary_expressions(self):
         """Test complex expression"""
         df = self.connect.read.table(self.tbl_name)

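The net behavior the new tests pin down is schema parity between a regular session and the Connect client across all of these types. Condensed, and assuming the same `spark`/`connect` fixtures the test class provides:

```python
query = "SELECT float(1.0) AS f, DATE('2022-02-22') AS d, ARRAY(1, NULL) AS a"
assert spark.sql(query).schema == connect.sql(query).schema
```
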
