You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/29 06:36:45 UTC
[spark] branch master updated: [SPARK-41264][CONNECT][PYTHON] Make Literal support more datatypes
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4c35c5bb5e5 [SPARK-41264][CONNECT][PYTHON] Make Literal support more datatypes
4c35c5bb5e5 is described below
commit 4c35c5bb5e545acb2f46a80218f68e69c868b388
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Tue Nov 29 14:36:10 2022 +0800
[SPARK-41264][CONNECT][PYTHON] Make Literal support more datatypes
### What changes were proposed in this pull request?
1. On the server side, try to match https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala#L63-L101, and use `CreateArray`, `CreateStruct`, `CreateMap` for complex inputs;
2. On the client side, try to match https://github.com/apache/spark/blob/master/python/pyspark/sql/types.py#L1335-L1349,
but do not support `datetime.time`, since I could not find a corresponding SQL type for it.
### Why are the changes needed?
Try to support all data types in Literal.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
updated tests
Closes #38800 from zhengruifeng/connect_update_literal.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../main/protobuf/spark/connect/expressions.proto | 103 ++---
.../src/main/protobuf/spark/connect/types.proto | 16 -
.../org/apache/spark/sql/connect/dsl/package.scala | 6 +-
.../sql/connect/planner/SparkConnectPlanner.scala | 118 ++++--
.../messages/ConnectProtoMessagesSuite.scala | 6 +-
.../connect/planner/SparkConnectPlannerSuite.scala | 2 +-
.../connect/planner/SparkConnectServiceSuite.scala | 2 +-
python/pyspark/sql/connect/column.py | 112 ++---
python/pyspark/sql/connect/plan.py | 4 +-
.../pyspark/sql/connect/proto/expressions_pb2.py | 151 +++----
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 462 ++++++++-------------
python/pyspark/sql/connect/proto/types_pb2.py | 126 +++---
python/pyspark/sql/connect/proto/types_pb2.pyi | 63 ---
.../connect/test_connect_column_expressions.py | 73 +++-
.../sql/tests/connect/test_connect_plan_only.py | 4 +-
15 files changed, 539 insertions(+), 709 deletions(-)
diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
index 595758f0443..2a1159c1d04 100644
--- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
@@ -40,37 +40,34 @@ message Expression {
message Literal {
oneof literal_type {
- bool boolean = 1;
- int32 i8 = 2;
- int32 i16 = 3;
- int32 i32 = 5;
- int64 i64 = 7;
- float fp32 = 10;
- double fp64 = 11;
- string string = 12;
- bytes binary = 13;
- // Timestamp in units of microseconds since the UNIX epoch.
- int64 timestamp = 14;
+ bool null = 1;
+ bytes binary = 2;
+ bool boolean = 3;
+
+ int32 byte = 4;
+ int32 short = 5;
+ int32 integer = 6;
+ int64 long = 7;
+ float float = 10;
+ double double = 11;
+ Decimal decimal = 12;
+
+ string string = 13;
+
// Date in units of days since the UNIX epoch.
int32 date = 16;
- // Time in units of microseconds past midnight
- int64 time = 17;
- IntervalYearToMonth interval_year_to_month = 19;
- IntervalDayToSecond interval_day_to_second = 20;
- string fixed_char = 21;
- VarChar var_char = 22;
- bytes fixed_binary = 23;
- Decimal decimal = 24;
- Struct struct = 25;
- Map map = 26;
// Timestamp in units of microseconds since the UNIX epoch.
- int64 timestamp_tz = 27;
- bytes uuid = 28;
- DataType null = 29; // a typed null literal
- List list = 30;
- DataType.Array empty_array = 31;
- DataType.Map empty_map = 32;
- UserDefined user_defined = 33;
+ int64 timestamp = 17;
+ // Timestamp in units of microseconds since the UNIX epoch (without timezone information).
+ int64 timestamp_ntz = 18;
+
+ CalendarInterval calendar_interval = 19;
+ int32 year_month_interval = 20;
+ int64 day_time_interval = 21;
+
+ Array array = 22;
+ Struct struct = 23;
+ Map map = 24;
}
// whether the literal type should be treated as a nullable type. Applies to
@@ -83,40 +80,20 @@ message Expression {
// directly declare the type variation).
uint32 type_variation_reference = 51;
- message VarChar {
- string value = 1;
- uint32 length = 2;
- }
-
message Decimal {
- // little-endian twos-complement integer representation of complete value
- // (ignoring precision) Always 16 bytes in length
- bytes value = 1;
+ // the string representation.
+ string value = 1;
// The maximum number of digits allowed in the value.
// the maximum precision is 38.
- int32 precision = 2;
+ optional int32 precision = 2;
// declared scale of decimal literal
- int32 scale = 3;
+ optional int32 scale = 3;
}
- message Map {
- message KeyValue {
- Literal key = 1;
- Literal value = 2;
- }
-
- repeated KeyValue key_values = 1;
- }
-
- message IntervalYearToMonth {
- int32 years = 1;
- int32 months = 2;
- }
-
- message IntervalDayToSecond {
- int32 days = 1;
- int32 seconds = 2;
- int32 microseconds = 3;
+ message CalendarInterval {
+ int32 months = 1;
+ int32 days = 2;
+ int64 microseconds = 3;
}
message Struct {
@@ -124,18 +101,18 @@ message Expression {
repeated Literal fields = 1;
}
- message List {
+ message Array {
// A homogeneously typed list of literals
repeated Literal values = 1;
}
- message UserDefined {
- // points to a type_anchor defined in this plan
- uint32 type_reference = 1;
+ message Map {
+ repeated Pair pairs = 1;
- // the value of the literal, serialized using some type-specific
- // protobuf message
- google.protobuf.Any value = 2;
+ message Pair {
+ Literal key = 1;
+ Literal value = 2;
+ }
}
}
diff --git a/connector/connect/src/main/protobuf/spark/connect/types.proto b/connector/connect/src/main/protobuf/spark/connect/types.proto
index 56dbf28665e..785a191955f 100644
--- a/connector/connect/src/main/protobuf/spark/connect/types.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/types.proto
@@ -61,13 +61,6 @@ message DataType {
Array array = 20;
Struct struct = 21;
Map map = 22;
-
-
- UUID uuid = 25;
-
- FixedBinary fixed_binary = 26;
-
- uint32 user_defined_type_reference = 31;
}
message Boolean {
@@ -138,10 +131,6 @@ message DataType {
uint32 type_variation_reference = 3;
}
- message UUID {
- uint32 type_variation_reference = 1;
- }
-
// Start compound types.
message Char {
int32 length = 1;
@@ -153,11 +142,6 @@ message DataType {
uint32 type_variation_reference = 2;
}
- message FixedBinary {
- int32 length = 1;
- uint32 type_variation_reference = 2;
- }
-
message Decimal {
optional int32 scale = 1;
optional int32 precision = 2;
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index dd1c7f0574b..16bf4c52627 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -144,7 +144,7 @@ package object dsl {
implicit def intToLiteral(i: Int): Expression =
Expression
.newBuilder()
- .setLiteral(Expression.Literal.newBuilder().setI32(i))
+ .setLiteral(Expression.Literal.newBuilder().setInteger(i))
.build()
}
@@ -233,8 +233,8 @@ package object dsl {
private def convertValue(value: Any) = {
value match {
case b: Boolean => Expression.Literal.newBuilder().setBoolean(b).build()
- case l: Long => Expression.Literal.newBuilder().setI64(l).build()
- case d: Double => Expression.Literal.newBuilder().setFp64(d).build()
+ case l: Long => Expression.Literal.newBuilder().setLong(l).build()
+ case d: Double => Expression.Literal.newBuilder().setDouble(d).build()
case s: String => Expression.Literal.newBuilder().setString(s).build()
case o => throw new Exception(s"Unsupported value type: $o")
}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b2b6e6ffc54..c9a3a313ed3 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.{Column, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, NamedExpression, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.command.CreateViewCommand
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.util.Utils
final case class InvalidPlanInput(
@@ -174,17 +175,17 @@ class SparkConnectPlanner(session: SparkSession) {
} else {
dataset.na.fill(value = value.getBoolean).logicalPlan
}
- case proto.Expression.Literal.LiteralTypeCase.I64 =>
+ case proto.Expression.Literal.LiteralTypeCase.LONG =>
if (cols.nonEmpty) {
- dataset.na.fill(value = value.getI64, cols = cols).logicalPlan
+ dataset.na.fill(value = value.getLong, cols = cols).logicalPlan
} else {
- dataset.na.fill(value = value.getI64).logicalPlan
+ dataset.na.fill(value = value.getLong).logicalPlan
}
- case proto.Expression.Literal.LiteralTypeCase.FP64 =>
+ case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
if (cols.nonEmpty) {
- dataset.na.fill(value = value.getFp64, cols = cols).logicalPlan
+ dataset.na.fill(value = value.getDouble, cols = cols).logicalPlan
} else {
- dataset.na.fill(value = value.getFp64).logicalPlan
+ dataset.na.fill(value = value.getDouble).logicalPlan
}
case proto.Expression.Literal.LiteralTypeCase.STRING =>
if (cols.nonEmpty) {
@@ -200,10 +201,10 @@ class SparkConnectPlanner(session: SparkSession) {
value.getLiteralTypeCase match {
case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
valueMap.update(col, value.getBoolean)
- case proto.Expression.Literal.LiteralTypeCase.I64 =>
- valueMap.update(col, value.getI64)
- case proto.Expression.Literal.LiteralTypeCase.FP64 =>
- valueMap.update(col, value.getFp64)
+ case proto.Expression.Literal.LiteralTypeCase.LONG =>
+ valueMap.update(col, value.getLong)
+ case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
+ valueMap.update(col, value.getDouble)
case proto.Expression.Literal.LiteralTypeCase.STRING =>
valueMap.update(col, value.getString)
case other => throw InvalidPlanInput(s"Unsupported value type: $other")
@@ -371,34 +372,89 @@ class SparkConnectPlanner(session: SparkSession) {
/**
* Transforms the protocol buffers literals into the appropriate Catalyst literal expression.
- *
- * TODO(SPARK-40533): Missing support for Instant, BigDecimal, LocalDate, LocalTimestamp,
- * Duration, Period.
- * @param lit
* @return
* Expression
*/
private def transformLiteral(lit: proto.Expression.Literal): Expression = {
lit.getLiteralTypeCase match {
- case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => expressions.Literal(lit.getBoolean)
- case proto.Expression.Literal.LiteralTypeCase.I8 => expressions.Literal(lit.getI8, ByteType)
- case proto.Expression.Literal.LiteralTypeCase.I16 =>
- expressions.Literal(lit.getI16, ShortType)
- case proto.Expression.Literal.LiteralTypeCase.I32 => expressions.Literal(lit.getI32)
- case proto.Expression.Literal.LiteralTypeCase.I64 => expressions.Literal(lit.getI64)
- case proto.Expression.Literal.LiteralTypeCase.FP32 =>
- expressions.Literal(lit.getFp32, FloatType)
- case proto.Expression.Literal.LiteralTypeCase.FP64 =>
- expressions.Literal(lit.getFp64, DoubleType)
- case proto.Expression.Literal.LiteralTypeCase.STRING => expressions.Literal(lit.getString)
+ case proto.Expression.Literal.LiteralTypeCase.NULL =>
+ expressions.Literal(null, NullType)
+
case proto.Expression.Literal.LiteralTypeCase.BINARY =>
- expressions.Literal(lit.getBinary, BinaryType)
- // Microseconds since unix epoch.
- case proto.Expression.Literal.LiteralTypeCase.TIME =>
- expressions.Literal(lit.getTime, TimestampType)
- // Days since UNIX epoch.
+ expressions.Literal(lit.getBinary.toByteArray, BinaryType)
+
+ case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
+ expressions.Literal(lit.getBoolean, BooleanType)
+
+ case proto.Expression.Literal.LiteralTypeCase.BYTE =>
+ expressions.Literal(lit.getByte, ByteType)
+
+ case proto.Expression.Literal.LiteralTypeCase.SHORT =>
+ expressions.Literal(lit.getShort, ShortType)
+
+ case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
+ expressions.Literal(lit.getInteger, IntegerType)
+
+ case proto.Expression.Literal.LiteralTypeCase.LONG =>
+ expressions.Literal(lit.getLong, LongType)
+
+ case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
+ expressions.Literal(lit.getFloat, FloatType)
+
+ case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
+ expressions.Literal(lit.getDouble, DoubleType)
+
+ case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
+ val decimal = Decimal.apply(lit.getDecimal.getValue)
+ var precision = decimal.precision
+ if (lit.getDecimal.hasPrecision) {
+ precision = math.max(precision, lit.getDecimal.getPrecision)
+ }
+ var scale = decimal.scale
+ if (lit.getDecimal.hasScale) {
+ scale = math.max(scale, lit.getDecimal.getScale)
+ }
+ expressions.Literal(decimal, DecimalType(math.max(precision, scale), scale))
+
+ case proto.Expression.Literal.LiteralTypeCase.STRING =>
+ expressions.Literal(UTF8String.fromString(lit.getString), StringType)
+
case proto.Expression.Literal.LiteralTypeCase.DATE =>
expressions.Literal(lit.getDate, DateType)
+
+ case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
+ expressions.Literal(lit.getTimestamp, TimestampType)
+
+ case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
+ expressions.Literal(lit.getTimestampNtz, TimestampNTZType)
+
+ case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
+ val interval = new CalendarInterval(
+ lit.getCalendarInterval.getMonths,
+ lit.getCalendarInterval.getDays,
+ lit.getCalendarInterval.getMicroseconds)
+ expressions.Literal(interval, CalendarIntervalType)
+
+ case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
+ expressions.Literal(lit.getYearMonthInterval, YearMonthIntervalType())
+
+ case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
+ expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType())
+
+ case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
+ val literals = lit.getArray.getValuesList.asScala.toArray.map(transformLiteral)
+ CreateArray(literals)
+
+ case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
+ val literals = lit.getStruct.getFieldsList.asScala.toArray.map(transformLiteral)
+ CreateStruct(literals)
+
+ case proto.Expression.Literal.LiteralTypeCase.MAP =>
+ val literals = lit.getMap.getPairsList.asScala.toArray.flatMap { pair =>
+ transformLiteral(pair.getKey) :: transformLiteral(pair.getValue) :: Nil
+ }
+ CreateMap(literals)
+
case _ =>
throw InvalidPlanInput(
s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
index 31b572afa21..08f12aa6d08 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
@@ -27,7 +27,7 @@ class ConnectProtoMessagesSuite extends SparkFunSuite {
// Create the extension value.
val lit = proto.Expression
.newBuilder()
- .setLiteral(proto.Expression.Literal.newBuilder().setI32(32).build())
+ .setLiteral(proto.Expression.Literal.newBuilder().setInteger(32).build())
// Pack the extension into Any.
val aval = com.google.protobuf.Any.pack(lit.build())
// Add Any to the repeated field list.
@@ -45,7 +45,7 @@ class ConnectProtoMessagesSuite extends SparkFunSuite {
assert(ext.is(classOf[proto.Expression]))
val extLit = ext.unpack(classOf[proto.Expression])
assert(extLit.hasLiteral)
- assert(extLit.getLiteral.hasI32)
- assert(extLit.getLiteral.getI32 == 32)
+ assert(extLit.getLiteral.hasInteger)
+ assert(extLit.getLiteral.getInteger == 32)
}
}
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 072af389513..bd5fdc29416 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -275,7 +275,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
.setInput(readRel)
.addExpressions(
proto.Expression.newBuilder
- .setLiteral(proto.Expression.Literal.newBuilder.setI32(32))
+ .setLiteral(proto.Expression.Literal.newBuilder.setInteger(32))
.build())
.build()
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index e5cd84fb504..9724c876b1b 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -214,7 +214,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
.addParts("abs")
.addArguments(proto.Expression
.newBuilder()
- .setLiteral(proto.Expression.Literal.newBuilder().setI32(-1)))))
+ .setLiteral(proto.Expression.Literal.newBuilder().setInteger(-1)))))
.setInput(
proto.Relation
.newBuilder()
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 69f9fa72db6..663af074925 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -14,13 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import uuid
-from typing import cast, get_args, TYPE_CHECKING, Callable, Any
+from typing import get_args, TYPE_CHECKING, Callable, Any
import json
import decimal
import datetime
+from pyspark.sql.types import TimestampType, DayTimeIntervalType, DateType
+
import pyspark.sql.connect.proto as proto
if TYPE_CHECKING:
@@ -167,66 +168,65 @@ class LiteralExpression(Expression):
TODO(SPARK-40533) This method always assumes the largest type and can thus
create weird interpretations of the literal."""
- value_type = type(self._value)
- exp = proto.Expression()
- if value_type is int:
- exp.literal.i64 = cast(int, self._value)
- elif value_type is bool:
- exp.literal.boolean = cast(bool, self._value)
- elif value_type is str:
- exp.literal.string = cast(str, self._value)
- elif value_type is float:
- exp.literal.fp64 = cast(float, self._value)
- elif value_type is decimal.Decimal:
- d_v = cast(decimal.Decimal, self._value)
- v_tuple = d_v.as_tuple()
- exp.literal.decimal.scale = abs(v_tuple.exponent)
- exp.literal.decimal.precision = len(v_tuple.digits) - abs(v_tuple.exponent)
- # Two complement yeah...
- raise ValueError("Python Decimal not supported.")
- elif value_type is bytes:
- exp.literal.binary = self._value
- elif value_type is datetime.datetime:
- # Microseconds since epoch.
- dt = cast(datetime.datetime, self._value)
- v = dt - datetime.datetime(1970, 1, 1, 0, 0, 0, 0)
- exp.literal.timestamp = int(v / datetime.timedelta(microseconds=1))
- elif value_type is datetime.time:
- # Nanoseconds of the day.
- tv = cast(datetime.time, self._value)
- offset = (tv.second + tv.minute * 60 + tv.hour * 3600) * 1000 + tv.microsecond
- exp.literal.time = int(offset * 1000)
- elif value_type is datetime.date:
- # Days since epoch.
- days_since_epoch = (cast(datetime.date, self._value) - datetime.date(1970, 1, 1)).days
- exp.literal.date = days_since_epoch
- elif value_type is uuid.UUID:
- raise ValueError("Python UUID type not supported.")
- elif value_type is list:
- lv = cast(list, self._value)
- for k in lv:
- if type(k) is LiteralExpression:
- exp.literal.list.values.append(k.to_plan(session).literal)
+ expr = proto.Expression()
+ if self._value is None:
+ expr.literal.null = True
+ elif isinstance(self._value, (bytes, bytearray)):
+ expr.literal.binary = bytes(self._value)
+ elif isinstance(self._value, bool):
+ expr.literal.boolean = bool(self._value)
+ elif isinstance(self._value, int):
+ expr.literal.long = int(self._value)
+ elif isinstance(self._value, float):
+ expr.literal.double = float(self._value)
+ elif isinstance(self._value, str):
+ expr.literal.string = str(self._value)
+ elif isinstance(self._value, decimal.Decimal):
+ expr.literal.decimal.value = str(self._value)
+ expr.literal.decimal.precision = int(decimal.getcontext().prec)
+ elif isinstance(self._value, datetime.datetime):
+ expr.literal.timestamp = TimestampType().toInternal(self._value)
+ elif isinstance(self._value, datetime.date):
+ expr.literal.date = DateType().toInternal(self._value)
+ elif isinstance(self._value, datetime.timedelta):
+ interval = DayTimeIntervalType().toInternal(self._value)
+ assert interval is not None
+ expr.literal.day_time_interval = int(interval)
+ elif isinstance(self._value, list):
+ expr.literal.array.SetInParent()
+ for item in list(self._value):
+ if isinstance(item, LiteralExpression):
+ expr.literal.array.values.append(item.to_plan(session).literal)
else:
- exp.literal.list.values.append(LiteralExpression(k).to_plan(session).literal)
- elif value_type is dict:
- mv = cast(dict, self._value)
- for k in mv:
- kv = proto.Expression.Literal.Map.KeyValue()
- if type(k) is LiteralExpression:
- kv.key.CopyFrom(k.to_plan(session).literal)
+ expr.literal.array.values.append(
+ LiteralExpression(item).to_plan(session).literal
+ )
+ elif isinstance(self._value, tuple):
+ expr.literal.struct.SetInParent()
+ for item in list(self._value):
+ if isinstance(item, LiteralExpression):
+ expr.literal.struct.fields.append(item.to_plan(session).literal)
else:
- kv.key.CopyFrom(LiteralExpression(k).to_plan(session).literal)
-
- if type(mv[k]) is LiteralExpression:
- kv.value.CopyFrom(mv[k].to_plan(session).literal)
+ expr.literal.struct.fields.append(
+ LiteralExpression(item).to_plan(session).literal
+ )
+ elif isinstance(self._value, dict):
+ expr.literal.map.SetInParent()
+ for key, value in dict(self._value).items():
+ pair = proto.Expression.Literal.Map.Pair()
+ if isinstance(key, LiteralExpression):
+ pair.key.CopyFrom(key.to_plan(session).literal)
else:
- kv.value.CopyFrom(LiteralExpression(mv[k]).to_plan(session).literal)
- exp.literal.map.key_values.append(kv)
+ pair.key.CopyFrom(LiteralExpression(key).to_plan(session).literal)
+ if isinstance(value, LiteralExpression):
+ pair.value.CopyFrom(value.to_plan(session).literal)
+ else:
+ pair.value.CopyFrom(LiteralExpression(value).to_plan(session).literal)
+ expr.literal.map.pairs.append(pair)
else:
raise ValueError(f"Could not convert literal for type {type(self._value)}")
- return exp
+ return expr
def __str__(self) -> str:
return f"Literal({self._value})"
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 0f611654ee5..19889cb9eb8 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -963,9 +963,9 @@ class NAFill(LogicalPlan):
if isinstance(v, bool):
value.boolean = v
elif isinstance(v, int):
- value.i64 = v
+ value.long = v
elif isinstance(v, float):
- value.fp64 = v
+ value.double = v
else:
value.string = v
return value
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 5d8441ff65d..afa783742d2 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,25 +34,18 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\x9c\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFu [...]
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xa8\x12\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFu [...]
)
_EXPRESSION = DESCRIPTOR.message_types_by_name["Expression"]
_EXPRESSION_LITERAL = _EXPRESSION.nested_types_by_name["Literal"]
-_EXPRESSION_LITERAL_VARCHAR = _EXPRESSION_LITERAL.nested_types_by_name["VarChar"]
_EXPRESSION_LITERAL_DECIMAL = _EXPRESSION_LITERAL.nested_types_by_name["Decimal"]
-_EXPRESSION_LITERAL_MAP = _EXPRESSION_LITERAL.nested_types_by_name["Map"]
-_EXPRESSION_LITERAL_MAP_KEYVALUE = _EXPRESSION_LITERAL_MAP.nested_types_by_name["KeyValue"]
-_EXPRESSION_LITERAL_INTERVALYEARTOMONTH = _EXPRESSION_LITERAL.nested_types_by_name[
- "IntervalYearToMonth"
-]
-_EXPRESSION_LITERAL_INTERVALDAYTOSECOND = _EXPRESSION_LITERAL.nested_types_by_name[
- "IntervalDayToSecond"
-]
+_EXPRESSION_LITERAL_CALENDARINTERVAL = _EXPRESSION_LITERAL.nested_types_by_name["CalendarInterval"]
_EXPRESSION_LITERAL_STRUCT = _EXPRESSION_LITERAL.nested_types_by_name["Struct"]
-_EXPRESSION_LITERAL_LIST = _EXPRESSION_LITERAL.nested_types_by_name["List"]
-_EXPRESSION_LITERAL_USERDEFINED = _EXPRESSION_LITERAL.nested_types_by_name["UserDefined"]
+_EXPRESSION_LITERAL_ARRAY = _EXPRESSION_LITERAL.nested_types_by_name["Array"]
+_EXPRESSION_LITERAL_MAP = _EXPRESSION_LITERAL.nested_types_by_name["Map"]
+_EXPRESSION_LITERAL_MAP_PAIR = _EXPRESSION_LITERAL_MAP.nested_types_by_name["Pair"]
_EXPRESSION_UNRESOLVEDATTRIBUTE = _EXPRESSION.nested_types_by_name["UnresolvedAttribute"]
_EXPRESSION_UNRESOLVEDFUNCTION = _EXPRESSION.nested_types_by_name["UnresolvedFunction"]
_EXPRESSION_EXPRESSIONSTRING = _EXPRESSION.nested_types_by_name["ExpressionString"]
@@ -66,15 +59,6 @@ Expression = _reflection.GeneratedProtocolMessageType(
"Literal",
(_message.Message,),
{
- "VarChar": _reflection.GeneratedProtocolMessageType(
- "VarChar",
- (_message.Message,),
- {
- "DESCRIPTOR": _EXPRESSION_LITERAL_VARCHAR,
- "__module__": "spark.connect.expressions_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.VarChar)
- },
- ),
"Decimal": _reflection.GeneratedProtocolMessageType(
"Decimal",
(_message.Message,),
@@ -84,40 +68,13 @@ Expression = _reflection.GeneratedProtocolMessageType(
# @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Decimal)
},
),
- "Map": _reflection.GeneratedProtocolMessageType(
- "Map",
+ "CalendarInterval": _reflection.GeneratedProtocolMessageType(
+ "CalendarInterval",
(_message.Message,),
{
- "KeyValue": _reflection.GeneratedProtocolMessageType(
- "KeyValue",
- (_message.Message,),
- {
- "DESCRIPTOR": _EXPRESSION_LITERAL_MAP_KEYVALUE,
- "__module__": "spark.connect.expressions_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Map.KeyValue)
- },
- ),
- "DESCRIPTOR": _EXPRESSION_LITERAL_MAP,
+ "DESCRIPTOR": _EXPRESSION_LITERAL_CALENDARINTERVAL,
"__module__": "spark.connect.expressions_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Map)
- },
- ),
- "IntervalYearToMonth": _reflection.GeneratedProtocolMessageType(
- "IntervalYearToMonth",
- (_message.Message,),
- {
- "DESCRIPTOR": _EXPRESSION_LITERAL_INTERVALYEARTOMONTH,
- "__module__": "spark.connect.expressions_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.IntervalYearToMonth)
- },
- ),
- "IntervalDayToSecond": _reflection.GeneratedProtocolMessageType(
- "IntervalDayToSecond",
- (_message.Message,),
- {
- "DESCRIPTOR": _EXPRESSION_LITERAL_INTERVALDAYTOSECOND,
- "__module__": "spark.connect.expressions_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.IntervalDayToSecond)
+ # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.CalendarInterval)
},
),
"Struct": _reflection.GeneratedProtocolMessageType(
@@ -129,22 +86,31 @@ Expression = _reflection.GeneratedProtocolMessageType(
# @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Struct)
},
),
- "List": _reflection.GeneratedProtocolMessageType(
- "List",
+ "Array": _reflection.GeneratedProtocolMessageType(
+ "Array",
(_message.Message,),
{
- "DESCRIPTOR": _EXPRESSION_LITERAL_LIST,
+ "DESCRIPTOR": _EXPRESSION_LITERAL_ARRAY,
"__module__": "spark.connect.expressions_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.List)
+ # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Array)
},
),
- "UserDefined": _reflection.GeneratedProtocolMessageType(
- "UserDefined",
+ "Map": _reflection.GeneratedProtocolMessageType(
+ "Map",
(_message.Message,),
{
- "DESCRIPTOR": _EXPRESSION_LITERAL_USERDEFINED,
+ "Pair": _reflection.GeneratedProtocolMessageType(
+ "Pair",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _EXPRESSION_LITERAL_MAP_PAIR,
+ "__module__": "spark.connect.expressions_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Map.Pair)
+ },
+ ),
+ "DESCRIPTOR": _EXPRESSION_LITERAL_MAP,
"__module__": "spark.connect.expressions_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.UserDefined)
+ # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Map)
},
),
"DESCRIPTOR": _EXPRESSION_LITERAL,
@@ -204,15 +170,12 @@ Expression = _reflection.GeneratedProtocolMessageType(
)
_sym_db.RegisterMessage(Expression)
_sym_db.RegisterMessage(Expression.Literal)
-_sym_db.RegisterMessage(Expression.Literal.VarChar)
_sym_db.RegisterMessage(Expression.Literal.Decimal)
-_sym_db.RegisterMessage(Expression.Literal.Map)
-_sym_db.RegisterMessage(Expression.Literal.Map.KeyValue)
-_sym_db.RegisterMessage(Expression.Literal.IntervalYearToMonth)
-_sym_db.RegisterMessage(Expression.Literal.IntervalDayToSecond)
+_sym_db.RegisterMessage(Expression.Literal.CalendarInterval)
_sym_db.RegisterMessage(Expression.Literal.Struct)
-_sym_db.RegisterMessage(Expression.Literal.List)
-_sym_db.RegisterMessage(Expression.Literal.UserDefined)
+_sym_db.RegisterMessage(Expression.Literal.Array)
+_sym_db.RegisterMessage(Expression.Literal.Map)
+_sym_db.RegisterMessage(Expression.Literal.Map.Pair)
_sym_db.RegisterMessage(Expression.UnresolvedAttribute)
_sym_db.RegisterMessage(Expression.UnresolvedFunction)
_sym_db.RegisterMessage(Expression.ExpressionString)
@@ -224,35 +187,29 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_EXPRESSION._serialized_start = 105
- _EXPRESSION._serialized_end = 3077
+ _EXPRESSION._serialized_end = 2449
_EXPRESSION_LITERAL._serialized_start = 613
- _EXPRESSION_LITERAL._serialized_end = 2699
- _EXPRESSION_LITERAL_VARCHAR._serialized_start = 1926
- _EXPRESSION_LITERAL_VARCHAR._serialized_end = 1981
- _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1983
- _EXPRESSION_LITERAL_DECIMAL._serialized_end = 2066
- _EXPRESSION_LITERAL_MAP._serialized_start = 2069
- _EXPRESSION_LITERAL_MAP._serialized_end = 2275
- _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_start = 2155
- _EXPRESSION_LITERAL_MAP_KEYVALUE._serialized_end = 2275
- _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_start = 2277
- _EXPRESSION_LITERAL_INTERVALYEARTOMONTH._serialized_end = 2344
- _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_start = 2346
- _EXPRESSION_LITERAL_INTERVALDAYTOSECOND._serialized_end = 2449
- _EXPRESSION_LITERAL_STRUCT._serialized_start = 2451
- _EXPRESSION_LITERAL_STRUCT._serialized_end = 2518
- _EXPRESSION_LITERAL_LIST._serialized_start = 2520
- _EXPRESSION_LITERAL_LIST._serialized_end = 2585
- _EXPRESSION_LITERAL_USERDEFINED._serialized_start = 2587
- _EXPRESSION_LITERAL_USERDEFINED._serialized_end = 2683
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2701
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2771
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2773
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2872
- _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2874
- _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2924
- _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2926
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2942
- _EXPRESSION_ALIAS._serialized_start = 2944
- _EXPRESSION_ALIAS._serialized_end = 3064
+ _EXPRESSION_LITERAL._serialized_end = 2071
+ _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1509
+ _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1626
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1628
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1726
+ _EXPRESSION_LITERAL_STRUCT._serialized_start = 1728
+ _EXPRESSION_LITERAL_STRUCT._serialized_end = 1795
+ _EXPRESSION_LITERAL_ARRAY._serialized_start = 1797
+ _EXPRESSION_LITERAL_ARRAY._serialized_end = 1863
+ _EXPRESSION_LITERAL_MAP._serialized_start = 1866
+ _EXPRESSION_LITERAL_MAP._serialized_end = 2055
+ _EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 1939
+ _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2055
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2073
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2143
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2145
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2244
+ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2246
+ _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2296
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2298
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2314
+ _EXPRESSION_ALIAS._serialized_start = 2316
+ _EXPRESSION_ALIAS._serialized_end = 2436
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 23ef99c6c48..f1c599964bf 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -35,12 +35,11 @@ limitations under the License.
"""
import builtins
import collections.abc
-import google.protobuf.any_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
-import pyspark.sql.connect.proto.types_pb2
import sys
+import typing
if sys.version_info >= (3, 8):
import typing as typing_extensions
@@ -59,33 +58,14 @@ class Expression(google.protobuf.message.Message):
class Literal(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
- class VarChar(google.protobuf.message.Message):
- DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
- VALUE_FIELD_NUMBER: builtins.int
- LENGTH_FIELD_NUMBER: builtins.int
- value: builtins.str
- length: builtins.int
- def __init__(
- self,
- *,
- value: builtins.str = ...,
- length: builtins.int = ...,
- ) -> None: ...
- def ClearField(
- self, field_name: typing_extensions.Literal["length", b"length", "value", b"value"]
- ) -> None: ...
-
class Decimal(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
VALUE_FIELD_NUMBER: builtins.int
PRECISION_FIELD_NUMBER: builtins.int
SCALE_FIELD_NUMBER: builtins.int
- value: builtins.bytes
- """little-endian twos-complement integer representation of complete value
- (ignoring precision) Always 16 bytes in length
- """
+ value: builtins.str
+ """the string representation."""
precision: builtins.int
"""The maximum number of digits allowed in the value.
the maximum precision is 38.
@@ -95,96 +75,67 @@ class Expression(google.protobuf.message.Message):
def __init__(
self,
*,
- value: builtins.bytes = ...,
- precision: builtins.int = ...,
- scale: builtins.int = ...,
+ value: builtins.str = ...,
+ precision: builtins.int | None = ...,
+ scale: builtins.int | None = ...,
) -> None: ...
- def ClearField(
+ def HasField(
self,
field_name: typing_extensions.Literal[
- "precision", b"precision", "scale", b"scale", "value", b"value"
+ "_precision",
+ b"_precision",
+ "_scale",
+ b"_scale",
+ "precision",
+ b"precision",
+ "scale",
+ b"scale",
],
- ) -> None: ...
-
- class Map(google.protobuf.message.Message):
- DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
- class KeyValue(google.protobuf.message.Message):
- DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
- KEY_FIELD_NUMBER: builtins.int
- VALUE_FIELD_NUMBER: builtins.int
- @property
- def key(self) -> global___Expression.Literal: ...
- @property
- def value(self) -> global___Expression.Literal: ...
- def __init__(
- self,
- *,
- key: global___Expression.Literal | None = ...,
- value: global___Expression.Literal | None = ...,
- ) -> None: ...
- def HasField(
- self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
- ) -> builtins.bool: ...
- def ClearField(
- self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
- ) -> None: ...
-
- KEY_VALUES_FIELD_NUMBER: builtins.int
- @property
- def key_values(
- self,
- ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
- global___Expression.Literal.Map.KeyValue
- ]: ...
- def __init__(
- self,
- *,
- key_values: collections.abc.Iterable[global___Expression.Literal.Map.KeyValue]
- | None = ...,
- ) -> None: ...
+ ) -> builtins.bool: ...
def ClearField(
- self, field_name: typing_extensions.Literal["key_values", b"key_values"]
- ) -> None: ...
-
- class IntervalYearToMonth(google.protobuf.message.Message):
- DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
- YEARS_FIELD_NUMBER: builtins.int
- MONTHS_FIELD_NUMBER: builtins.int
- years: builtins.int
- months: builtins.int
- def __init__(
self,
- *,
- years: builtins.int = ...,
- months: builtins.int = ...,
- ) -> None: ...
- def ClearField(
- self, field_name: typing_extensions.Literal["months", b"months", "years", b"years"]
+ field_name: typing_extensions.Literal[
+ "_precision",
+ b"_precision",
+ "_scale",
+ b"_scale",
+ "precision",
+ b"precision",
+ "scale",
+ b"scale",
+ "value",
+ b"value",
+ ],
) -> None: ...
-
- class IntervalDayToSecond(google.protobuf.message.Message):
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_precision", b"_precision"]
+ ) -> typing_extensions.Literal["precision"] | None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_scale", b"_scale"]
+ ) -> typing_extensions.Literal["scale"] | None: ...
+
+ class CalendarInterval(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
+ MONTHS_FIELD_NUMBER: builtins.int
DAYS_FIELD_NUMBER: builtins.int
- SECONDS_FIELD_NUMBER: builtins.int
MICROSECONDS_FIELD_NUMBER: builtins.int
+ months: builtins.int
days: builtins.int
- seconds: builtins.int
microseconds: builtins.int
def __init__(
self,
*,
+ months: builtins.int = ...,
days: builtins.int = ...,
- seconds: builtins.int = ...,
microseconds: builtins.int = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "days", b"days", "microseconds", b"microseconds", "seconds", b"seconds"
+ "days", b"days", "microseconds", b"microseconds", "months", b"months"
],
) -> None: ...
@@ -208,7 +159,7 @@ class Expression(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["fields", b"fields"]
) -> None: ...
- class List(google.protobuf.message.Message):
+ class Array(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
VALUES_FIELD_NUMBER: builtins.int
@@ -228,106 +179,97 @@ class Expression(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["values", b"values"]
) -> None: ...
- class UserDefined(google.protobuf.message.Message):
+ class Map(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
- TYPE_REFERENCE_FIELD_NUMBER: builtins.int
- VALUE_FIELD_NUMBER: builtins.int
- type_reference: builtins.int
- """points to a type_anchor defined in this plan"""
+ class Pair(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ KEY_FIELD_NUMBER: builtins.int
+ VALUE_FIELD_NUMBER: builtins.int
+ @property
+ def key(self) -> global___Expression.Literal: ...
+ @property
+ def value(self) -> global___Expression.Literal: ...
+ def __init__(
+ self,
+ *,
+ key: global___Expression.Literal | None = ...,
+ value: global___Expression.Literal | None = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
+ ) -> None: ...
+
+ PAIRS_FIELD_NUMBER: builtins.int
@property
- def value(self) -> google.protobuf.any_pb2.Any:
- """the value of the literal, serialized using some type-specific
- protobuf message
- """
+ def pairs(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ global___Expression.Literal.Map.Pair
+ ]: ...
def __init__(
self,
*,
- type_reference: builtins.int = ...,
- value: google.protobuf.any_pb2.Any | None = ...,
+ pairs: collections.abc.Iterable[global___Expression.Literal.Map.Pair] | None = ...,
) -> None: ...
- def HasField(
- self, field_name: typing_extensions.Literal["value", b"value"]
- ) -> builtins.bool: ...
def ClearField(
- self,
- field_name: typing_extensions.Literal[
- "type_reference", b"type_reference", "value", b"value"
- ],
+ self, field_name: typing_extensions.Literal["pairs", b"pairs"]
) -> None: ...
+ NULL_FIELD_NUMBER: builtins.int
+ BINARY_FIELD_NUMBER: builtins.int
BOOLEAN_FIELD_NUMBER: builtins.int
- I8_FIELD_NUMBER: builtins.int
- I16_FIELD_NUMBER: builtins.int
- I32_FIELD_NUMBER: builtins.int
- I64_FIELD_NUMBER: builtins.int
- FP32_FIELD_NUMBER: builtins.int
- FP64_FIELD_NUMBER: builtins.int
+ BYTE_FIELD_NUMBER: builtins.int
+ SHORT_FIELD_NUMBER: builtins.int
+ INTEGER_FIELD_NUMBER: builtins.int
+ LONG_FIELD_NUMBER: builtins.int
+ FLOAT_FIELD_NUMBER: builtins.int
+ DOUBLE_FIELD_NUMBER: builtins.int
+ DECIMAL_FIELD_NUMBER: builtins.int
STRING_FIELD_NUMBER: builtins.int
- BINARY_FIELD_NUMBER: builtins.int
- TIMESTAMP_FIELD_NUMBER: builtins.int
DATE_FIELD_NUMBER: builtins.int
- TIME_FIELD_NUMBER: builtins.int
- INTERVAL_YEAR_TO_MONTH_FIELD_NUMBER: builtins.int
- INTERVAL_DAY_TO_SECOND_FIELD_NUMBER: builtins.int
- FIXED_CHAR_FIELD_NUMBER: builtins.int
- VAR_CHAR_FIELD_NUMBER: builtins.int
- FIXED_BINARY_FIELD_NUMBER: builtins.int
- DECIMAL_FIELD_NUMBER: builtins.int
+ TIMESTAMP_FIELD_NUMBER: builtins.int
+ TIMESTAMP_NTZ_FIELD_NUMBER: builtins.int
+ CALENDAR_INTERVAL_FIELD_NUMBER: builtins.int
+ YEAR_MONTH_INTERVAL_FIELD_NUMBER: builtins.int
+ DAY_TIME_INTERVAL_FIELD_NUMBER: builtins.int
+ ARRAY_FIELD_NUMBER: builtins.int
STRUCT_FIELD_NUMBER: builtins.int
MAP_FIELD_NUMBER: builtins.int
- TIMESTAMP_TZ_FIELD_NUMBER: builtins.int
- UUID_FIELD_NUMBER: builtins.int
- NULL_FIELD_NUMBER: builtins.int
- LIST_FIELD_NUMBER: builtins.int
- EMPTY_ARRAY_FIELD_NUMBER: builtins.int
- EMPTY_MAP_FIELD_NUMBER: builtins.int
- USER_DEFINED_FIELD_NUMBER: builtins.int
NULLABLE_FIELD_NUMBER: builtins.int
TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
+ null: builtins.bool
+ binary: builtins.bytes
boolean: builtins.bool
- i8: builtins.int
- i16: builtins.int
- i32: builtins.int
- i64: builtins.int
- fp32: builtins.float
- fp64: builtins.float
+ byte: builtins.int
+ short: builtins.int
+ integer: builtins.int
+ long: builtins.int
+ float: builtins.float
+ double: builtins.float
+ @property
+ def decimal(self) -> global___Expression.Literal.Decimal: ...
string: builtins.str
- binary: builtins.bytes
- timestamp: builtins.int
- """Timestamp in units of microseconds since the UNIX epoch."""
date: builtins.int
"""Date in units of days since the UNIX epoch."""
- time: builtins.int
- """Time in units of microseconds past midnight"""
- @property
- def interval_year_to_month(self) -> global___Expression.Literal.IntervalYearToMonth: ...
- @property
- def interval_day_to_second(self) -> global___Expression.Literal.IntervalDayToSecond: ...
- fixed_char: builtins.str
+ timestamp: builtins.int
+ """Timestamp in units of microseconds since the UNIX epoch."""
+ timestamp_ntz: builtins.int
+ """Timestamp in units of microseconds since the UNIX epoch (without timezone information)."""
@property
- def var_char(self) -> global___Expression.Literal.VarChar: ...
- fixed_binary: builtins.bytes
+ def calendar_interval(self) -> global___Expression.Literal.CalendarInterval: ...
+ year_month_interval: builtins.int
+ day_time_interval: builtins.int
@property
- def decimal(self) -> global___Expression.Literal.Decimal: ...
+ def array(self) -> global___Expression.Literal.Array: ...
@property
def struct(self) -> global___Expression.Literal.Struct: ...
@property
def map(self) -> global___Expression.Literal.Map: ...
- timestamp_tz: builtins.int
- """Timestamp in units of microseconds since the UNIX epoch."""
- uuid: builtins.bytes
- @property
- def null(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
- """a typed null literal"""
- @property
- def list(self) -> global___Expression.Literal.List: ...
- @property
- def empty_array(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Array: ...
- @property
- def empty_map(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Map: ...
- @property
- def user_defined(self) -> global___Expression.Literal.UserDefined: ...
nullable: builtins.bool
"""whether the literal type should be treated as a nullable type. Applies to
all members of union other than the Typed null (which should directly
@@ -341,192 +283,150 @@ class Expression(google.protobuf.message.Message):
def __init__(
self,
*,
+ null: builtins.bool = ...,
+ binary: builtins.bytes = ...,
boolean: builtins.bool = ...,
- i8: builtins.int = ...,
- i16: builtins.int = ...,
- i32: builtins.int = ...,
- i64: builtins.int = ...,
- fp32: builtins.float = ...,
- fp64: builtins.float = ...,
+ byte: builtins.int = ...,
+ short: builtins.int = ...,
+ integer: builtins.int = ...,
+ long: builtins.int = ...,
+ float: builtins.float = ...,
+ double: builtins.float = ...,
+ decimal: global___Expression.Literal.Decimal | None = ...,
string: builtins.str = ...,
- binary: builtins.bytes = ...,
- timestamp: builtins.int = ...,
date: builtins.int = ...,
- time: builtins.int = ...,
- interval_year_to_month: global___Expression.Literal.IntervalYearToMonth | None = ...,
- interval_day_to_second: global___Expression.Literal.IntervalDayToSecond | None = ...,
- fixed_char: builtins.str = ...,
- var_char: global___Expression.Literal.VarChar | None = ...,
- fixed_binary: builtins.bytes = ...,
- decimal: global___Expression.Literal.Decimal | None = ...,
+ timestamp: builtins.int = ...,
+ timestamp_ntz: builtins.int = ...,
+ calendar_interval: global___Expression.Literal.CalendarInterval | None = ...,
+ year_month_interval: builtins.int = ...,
+ day_time_interval: builtins.int = ...,
+ array: global___Expression.Literal.Array | None = ...,
struct: global___Expression.Literal.Struct | None = ...,
map: global___Expression.Literal.Map | None = ...,
- timestamp_tz: builtins.int = ...,
- uuid: builtins.bytes = ...,
- null: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
- list: global___Expression.Literal.List | None = ...,
- empty_array: pyspark.sql.connect.proto.types_pb2.DataType.Array | None = ...,
- empty_map: pyspark.sql.connect.proto.types_pb2.DataType.Map | None = ...,
- user_defined: global___Expression.Literal.UserDefined | None = ...,
nullable: builtins.bool = ...,
type_variation_reference: builtins.int = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
+ "array",
+ b"array",
"binary",
b"binary",
"boolean",
b"boolean",
+ "byte",
+ b"byte",
+ "calendar_interval",
+ b"calendar_interval",
"date",
b"date",
+ "day_time_interval",
+ b"day_time_interval",
"decimal",
b"decimal",
- "empty_array",
- b"empty_array",
- "empty_map",
- b"empty_map",
- "fixed_binary",
- b"fixed_binary",
- "fixed_char",
- b"fixed_char",
- "fp32",
- b"fp32",
- "fp64",
- b"fp64",
- "i16",
- b"i16",
- "i32",
- b"i32",
- "i64",
- b"i64",
- "i8",
- b"i8",
- "interval_day_to_second",
- b"interval_day_to_second",
- "interval_year_to_month",
- b"interval_year_to_month",
- "list",
- b"list",
+ "double",
+ b"double",
+ "float",
+ b"float",
+ "integer",
+ b"integer",
"literal_type",
b"literal_type",
+ "long",
+ b"long",
"map",
b"map",
"null",
b"null",
+ "short",
+ b"short",
"string",
b"string",
"struct",
b"struct",
- "time",
- b"time",
"timestamp",
b"timestamp",
- "timestamp_tz",
- b"timestamp_tz",
- "user_defined",
- b"user_defined",
- "uuid",
- b"uuid",
- "var_char",
- b"var_char",
+ "timestamp_ntz",
+ b"timestamp_ntz",
+ "year_month_interval",
+ b"year_month_interval",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
+ "array",
+ b"array",
"binary",
b"binary",
"boolean",
b"boolean",
+ "byte",
+ b"byte",
+ "calendar_interval",
+ b"calendar_interval",
"date",
b"date",
+ "day_time_interval",
+ b"day_time_interval",
"decimal",
b"decimal",
- "empty_array",
- b"empty_array",
- "empty_map",
- b"empty_map",
- "fixed_binary",
- b"fixed_binary",
- "fixed_char",
- b"fixed_char",
- "fp32",
- b"fp32",
- "fp64",
- b"fp64",
- "i16",
- b"i16",
- "i32",
- b"i32",
- "i64",
- b"i64",
- "i8",
- b"i8",
- "interval_day_to_second",
- b"interval_day_to_second",
- "interval_year_to_month",
- b"interval_year_to_month",
- "list",
- b"list",
+ "double",
+ b"double",
+ "float",
+ b"float",
+ "integer",
+ b"integer",
"literal_type",
b"literal_type",
+ "long",
+ b"long",
"map",
b"map",
"null",
b"null",
"nullable",
b"nullable",
+ "short",
+ b"short",
"string",
b"string",
"struct",
b"struct",
- "time",
- b"time",
"timestamp",
b"timestamp",
- "timestamp_tz",
- b"timestamp_tz",
+ "timestamp_ntz",
+ b"timestamp_ntz",
"type_variation_reference",
b"type_variation_reference",
- "user_defined",
- b"user_defined",
- "uuid",
- b"uuid",
- "var_char",
- b"var_char",
+ "year_month_interval",
+ b"year_month_interval",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["literal_type", b"literal_type"]
) -> typing_extensions.Literal[
+ "null",
+ "binary",
"boolean",
- "i8",
- "i16",
- "i32",
- "i64",
- "fp32",
- "fp64",
+ "byte",
+ "short",
+ "integer",
+ "long",
+ "float",
+ "double",
+ "decimal",
"string",
- "binary",
- "timestamp",
"date",
- "time",
- "interval_year_to_month",
- "interval_day_to_second",
- "fixed_char",
- "var_char",
- "fixed_binary",
- "decimal",
+ "timestamp",
+ "timestamp_ntz",
+ "calendar_interval",
+ "year_month_interval",
+ "day_time_interval",
+ "array",
"struct",
"map",
- "timestamp_tz",
- "uuid",
- "null",
- "list",
- "empty_array",
- "empty_map",
- "user_defined",
] | None: ...
class UnresolvedAttribute(google.protobuf.message.Message):
diff --git a/python/pyspark/sql/connect/proto/types_pb2.py b/python/pyspark/sql/connect/proto/types_pb2.py
index dd6567d96a2..2fcb56acb4d 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.py
+++ b/python/pyspark/sql/connect/proto/types_pb2.py
@@ -30,7 +30,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xce \n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01(\x0 [...]
+ b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xec\x1d\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01( [...]
)
@@ -51,10 +51,8 @@ _DATATYPE_TIMESTAMPNTZ = _DATATYPE.nested_types_by_name["TimestampNTZ"]
_DATATYPE_CALENDARINTERVAL = _DATATYPE.nested_types_by_name["CalendarInterval"]
_DATATYPE_YEARMONTHINTERVAL = _DATATYPE.nested_types_by_name["YearMonthInterval"]
_DATATYPE_DAYTIMEINTERVAL = _DATATYPE.nested_types_by_name["DayTimeInterval"]
-_DATATYPE_UUID = _DATATYPE.nested_types_by_name["UUID"]
_DATATYPE_CHAR = _DATATYPE.nested_types_by_name["Char"]
_DATATYPE_VARCHAR = _DATATYPE.nested_types_by_name["VarChar"]
-_DATATYPE_FIXEDBINARY = _DATATYPE.nested_types_by_name["FixedBinary"]
_DATATYPE_DECIMAL = _DATATYPE.nested_types_by_name["Decimal"]
_DATATYPE_STRUCTFIELD = _DATATYPE.nested_types_by_name["StructField"]
_DATATYPE_STRUCTFIELD_METADATAENTRY = _DATATYPE_STRUCTFIELD.nested_types_by_name["MetadataEntry"]
@@ -209,15 +207,6 @@ DataType = _reflection.GeneratedProtocolMessageType(
# @@protoc_insertion_point(class_scope:spark.connect.DataType.DayTimeInterval)
},
),
- "UUID": _reflection.GeneratedProtocolMessageType(
- "UUID",
- (_message.Message,),
- {
- "DESCRIPTOR": _DATATYPE_UUID,
- "__module__": "spark.connect.types_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.DataType.UUID)
- },
- ),
"Char": _reflection.GeneratedProtocolMessageType(
"Char",
(_message.Message,),
@@ -236,15 +225,6 @@ DataType = _reflection.GeneratedProtocolMessageType(
# @@protoc_insertion_point(class_scope:spark.connect.DataType.VarChar)
},
),
- "FixedBinary": _reflection.GeneratedProtocolMessageType(
- "FixedBinary",
- (_message.Message,),
- {
- "DESCRIPTOR": _DATATYPE_FIXEDBINARY,
- "__module__": "spark.connect.types_pb2"
- # @@protoc_insertion_point(class_scope:spark.connect.DataType.FixedBinary)
- },
- ),
"Decimal": _reflection.GeneratedProtocolMessageType(
"Decimal",
(_message.Message,),
@@ -321,10 +301,8 @@ _sym_db.RegisterMessage(DataType.TimestampNTZ)
_sym_db.RegisterMessage(DataType.CalendarInterval)
_sym_db.RegisterMessage(DataType.YearMonthInterval)
_sym_db.RegisterMessage(DataType.DayTimeInterval)
-_sym_db.RegisterMessage(DataType.UUID)
_sym_db.RegisterMessage(DataType.Char)
_sym_db.RegisterMessage(DataType.VarChar)
-_sym_db.RegisterMessage(DataType.FixedBinary)
_sym_db.RegisterMessage(DataType.Decimal)
_sym_db.RegisterMessage(DataType.StructField)
_sym_db.RegisterMessage(DataType.StructField.MetadataEntry)
@@ -339,57 +317,53 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_DATATYPE_STRUCTFIELD_METADATAENTRY._options = None
_DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_options = b"8\001"
_DATATYPE._serialized_start = 45
- _DATATYPE._serialized_end = 4219
- _DATATYPE_BOOLEAN._serialized_start = 1612
- _DATATYPE_BOOLEAN._serialized_end = 1679
- _DATATYPE_BYTE._serialized_start = 1681
- _DATATYPE_BYTE._serialized_end = 1745
- _DATATYPE_SHORT._serialized_start = 1747
- _DATATYPE_SHORT._serialized_end = 1812
- _DATATYPE_INTEGER._serialized_start = 1814
- _DATATYPE_INTEGER._serialized_end = 1881
- _DATATYPE_LONG._serialized_start = 1883
- _DATATYPE_LONG._serialized_end = 1947
- _DATATYPE_FLOAT._serialized_start = 1949
- _DATATYPE_FLOAT._serialized_end = 2014
- _DATATYPE_DOUBLE._serialized_start = 2016
- _DATATYPE_DOUBLE._serialized_end = 2082
- _DATATYPE_STRING._serialized_start = 2084
- _DATATYPE_STRING._serialized_end = 2150
- _DATATYPE_BINARY._serialized_start = 2152
- _DATATYPE_BINARY._serialized_end = 2218
- _DATATYPE_NULL._serialized_start = 2220
- _DATATYPE_NULL._serialized_end = 2284
- _DATATYPE_TIMESTAMP._serialized_start = 2286
- _DATATYPE_TIMESTAMP._serialized_end = 2355
- _DATATYPE_DATE._serialized_start = 2357
- _DATATYPE_DATE._serialized_end = 2421
- _DATATYPE_TIMESTAMPNTZ._serialized_start = 2423
- _DATATYPE_TIMESTAMPNTZ._serialized_end = 2495
- _DATATYPE_CALENDARINTERVAL._serialized_start = 2497
- _DATATYPE_CALENDARINTERVAL._serialized_end = 2573
- _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2576
- _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2755
- _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2758
- _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2935
- _DATATYPE_UUID._serialized_start = 2937
- _DATATYPE_UUID._serialized_end = 3001
- _DATATYPE_CHAR._serialized_start = 3003
- _DATATYPE_CHAR._serialized_end = 3091
- _DATATYPE_VARCHAR._serialized_start = 3093
- _DATATYPE_VARCHAR._serialized_end = 3184
- _DATATYPE_FIXEDBINARY._serialized_start = 3186
- _DATATYPE_FIXEDBINARY._serialized_end = 3281
- _DATATYPE_DECIMAL._serialized_start = 3284
- _DATATYPE_DECIMAL._serialized_end = 3437
- _DATATYPE_STRUCTFIELD._serialized_start = 3440
- _DATATYPE_STRUCTFIELD._serialized_end = 3695
- _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_start = 3636
- _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_end = 3695
- _DATATYPE_STRUCT._serialized_start = 3697
- _DATATYPE_STRUCT._serialized_end = 3824
- _DATATYPE_ARRAY._serialized_start = 3827
- _DATATYPE_ARRAY._serialized_end = 3989
- _DATATYPE_MAP._serialized_start = 3992
- _DATATYPE_MAP._serialized_end = 4211
+ _DATATYPE._serialized_end = 3865
+ _DATATYPE_BOOLEAN._serialized_start = 1421
+ _DATATYPE_BOOLEAN._serialized_end = 1488
+ _DATATYPE_BYTE._serialized_start = 1490
+ _DATATYPE_BYTE._serialized_end = 1554
+ _DATATYPE_SHORT._serialized_start = 1556
+ _DATATYPE_SHORT._serialized_end = 1621
+ _DATATYPE_INTEGER._serialized_start = 1623
+ _DATATYPE_INTEGER._serialized_end = 1690
+ _DATATYPE_LONG._serialized_start = 1692
+ _DATATYPE_LONG._serialized_end = 1756
+ _DATATYPE_FLOAT._serialized_start = 1758
+ _DATATYPE_FLOAT._serialized_end = 1823
+ _DATATYPE_DOUBLE._serialized_start = 1825
+ _DATATYPE_DOUBLE._serialized_end = 1891
+ _DATATYPE_STRING._serialized_start = 1893
+ _DATATYPE_STRING._serialized_end = 1959
+ _DATATYPE_BINARY._serialized_start = 1961
+ _DATATYPE_BINARY._serialized_end = 2027
+ _DATATYPE_NULL._serialized_start = 2029
+ _DATATYPE_NULL._serialized_end = 2093
+ _DATATYPE_TIMESTAMP._serialized_start = 2095
+ _DATATYPE_TIMESTAMP._serialized_end = 2164
+ _DATATYPE_DATE._serialized_start = 2166
+ _DATATYPE_DATE._serialized_end = 2230
+ _DATATYPE_TIMESTAMPNTZ._serialized_start = 2232
+ _DATATYPE_TIMESTAMPNTZ._serialized_end = 2304
+ _DATATYPE_CALENDARINTERVAL._serialized_start = 2306
+ _DATATYPE_CALENDARINTERVAL._serialized_end = 2382
+ _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2385
+ _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2564
+ _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2567
+ _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2744
+ _DATATYPE_CHAR._serialized_start = 2746
+ _DATATYPE_CHAR._serialized_end = 2834
+ _DATATYPE_VARCHAR._serialized_start = 2836
+ _DATATYPE_VARCHAR._serialized_end = 2927
+ _DATATYPE_DECIMAL._serialized_start = 2930
+ _DATATYPE_DECIMAL._serialized_end = 3083
+ _DATATYPE_STRUCTFIELD._serialized_start = 3086
+ _DATATYPE_STRUCTFIELD._serialized_end = 3341
+ _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_start = 3282
+ _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_end = 3341
+ _DATATYPE_STRUCT._serialized_start = 3343
+ _DATATYPE_STRUCT._serialized_end = 3470
+ _DATATYPE_ARRAY._serialized_start = 3473
+ _DATATYPE_ARRAY._serialized_end = 3635
+ _DATATYPE_MAP._serialized_start = 3638
+ _DATATYPE_MAP._serialized_end = 3857
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/types_pb2.pyi b/python/pyspark/sql/connect/proto/types_pb2.pyi
index 647f625659b..72736301f88 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/types_pb2.pyi
@@ -399,23 +399,6 @@ class DataType(google.protobuf.message.Message):
self, oneof_group: typing_extensions.Literal["_start_field", b"_start_field"]
) -> typing_extensions.Literal["start_field"] | None: ...
- class UUID(google.protobuf.message.Message):
- DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
- TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
- type_variation_reference: builtins.int
- def __init__(
- self,
- *,
- type_variation_reference: builtins.int = ...,
- ) -> None: ...
- def ClearField(
- self,
- field_name: typing_extensions.Literal[
- "type_variation_reference", b"type_variation_reference"
- ],
- ) -> None: ...
-
class Char(google.protobuf.message.Message):
"""Start compound types."""
@@ -458,26 +441,6 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...
- class FixedBinary(google.protobuf.message.Message):
- DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
- LENGTH_FIELD_NUMBER: builtins.int
- TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
- length: builtins.int
- type_variation_reference: builtins.int
- def __init__(
- self,
- *,
- length: builtins.int = ...,
- type_variation_reference: builtins.int = ...,
- ) -> None: ...
- def ClearField(
- self,
- field_name: typing_extensions.Literal[
- "length", b"length", "type_variation_reference", b"type_variation_reference"
- ],
- ) -> None: ...
-
class Decimal(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -708,9 +671,6 @@ class DataType(google.protobuf.message.Message):
ARRAY_FIELD_NUMBER: builtins.int
STRUCT_FIELD_NUMBER: builtins.int
MAP_FIELD_NUMBER: builtins.int
- UUID_FIELD_NUMBER: builtins.int
- FIXED_BINARY_FIELD_NUMBER: builtins.int
- USER_DEFINED_TYPE_REFERENCE_FIELD_NUMBER: builtins.int
@property
def null(self) -> global___DataType.NULL: ...
@property
@@ -760,11 +720,6 @@ class DataType(google.protobuf.message.Message):
def struct(self) -> global___DataType.Struct: ...
@property
def map(self) -> global___DataType.Map: ...
- @property
- def uuid(self) -> global___DataType.UUID: ...
- @property
- def fixed_binary(self) -> global___DataType.FixedBinary: ...
- user_defined_type_reference: builtins.int
def __init__(
self,
*,
@@ -790,9 +745,6 @@ class DataType(google.protobuf.message.Message):
array: global___DataType.Array | None = ...,
struct: global___DataType.Struct | None = ...,
map: global___DataType.Map | None = ...,
- uuid: global___DataType.UUID | None = ...,
- fixed_binary: global___DataType.FixedBinary | None = ...,
- user_defined_type_reference: builtins.int = ...,
) -> None: ...
def HasField(
self,
@@ -817,8 +769,6 @@ class DataType(google.protobuf.message.Message):
b"decimal",
"double",
b"double",
- "fixed_binary",
- b"fixed_binary",
"float",
b"float",
"integer",
@@ -841,10 +791,6 @@ class DataType(google.protobuf.message.Message):
b"timestamp",
"timestamp_ntz",
b"timestamp_ntz",
- "user_defined_type_reference",
- b"user_defined_type_reference",
- "uuid",
- b"uuid",
"var_char",
b"var_char",
"year_month_interval",
@@ -874,8 +820,6 @@ class DataType(google.protobuf.message.Message):
b"decimal",
"double",
b"double",
- "fixed_binary",
- b"fixed_binary",
"float",
b"float",
"integer",
@@ -898,10 +842,6 @@ class DataType(google.protobuf.message.Message):
b"timestamp",
"timestamp_ntz",
b"timestamp_ntz",
- "user_defined_type_reference",
- b"user_defined_type_reference",
- "uuid",
- b"uuid",
"var_char",
b"var_char",
"year_month_interval",
@@ -933,9 +873,6 @@ class DataType(google.protobuf.message.Message):
"array",
"struct",
"map",
- "uuid",
- "fixed_binary",
- "user_defined_type_reference",
] | None: ...
global___DataType = DataType
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index 03966bd28df..9e2ff8290d6 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -51,6 +51,11 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
self.assertEqual(cp1, cp2)
self.assertEqual(cp2, cp3)
+ def test_null_literal(self):
+ null_lit = fun.lit(None)
+ null_lit_p = null_lit.to_plan(None)
+ self.assertEqual(null_lit_p.literal.null, True)
+
def test_binary_literal(self):
val = b"binary\0\0asas"
bin_lit = fun.lit(val)
@@ -61,15 +66,15 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
val = {"this": "is", 12: [12, 32, 43]}
map_lit = fun.lit(val)
map_lit_p = map_lit.to_plan(None)
- self.assertEqual(2, len(map_lit_p.literal.map.key_values))
- self.assertEqual("this", map_lit_p.literal.map.key_values[0].key.string)
- self.assertEqual(12, map_lit_p.literal.map.key_values[1].key.i64)
+ self.assertEqual(2, len(map_lit_p.literal.map.pairs))
+ self.assertEqual("this", map_lit_p.literal.map.pairs[0].key.string)
+ self.assertEqual(12, map_lit_p.literal.map.pairs[1].key.long)
val = {"this": fun.lit("is"), 12: [12, 32, 43]}
map_lit = fun.lit(val)
map_lit_p = map_lit.to_plan(None)
- self.assertEqual(2, len(map_lit_p.literal.map.key_values))
- self.assertEqual("is", map_lit_p.literal.map.key_values[0].value.string)
+ self.assertEqual(2, len(map_lit_p.literal.map.pairs))
+ self.assertEqual("is", map_lit_p.literal.map.pairs[0].value.string)
def test_uuid_literal(self):
val = uuid.uuid4()
@@ -84,22 +89,29 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
self.assertIsNotNone(fun.lit(10).to_plan(None))
plan = fun.lit(10).to_plan(None)
- self.assertIs(plan.literal.i64, 10)
+ self.assertIs(plan.literal.long, 10)
def test_numeric_literal_types(self):
int_lit = fun.lit(10)
float_lit = fun.lit(10.1)
decimal_lit = fun.lit(decimal.Decimal(99))
- # Decimal is not supported yet.
- with self.assertRaises(ValueError):
- self.assertIsNotNone(decimal_lit.to_plan(None))
-
self.assertIsNotNone(int_lit.to_plan(None))
self.assertIsNotNone(float_lit.to_plan(None))
+ self.assertIsNotNone(decimal_lit.to_plan(None))
+
+ def test_float_nan_inf(self):
+ na_lit = fun.lit(float("nan"))
+ self.assertIsNotNone(na_lit.to_plan(None))
+
+ inf_lit = fun.lit(float("inf"))
+ self.assertIsNotNone(inf_lit.to_plan(None))
+
+ inf_lit = fun.lit(float("-inf"))
+ self.assertIsNotNone(inf_lit.to_plan(None))
def test_datetime_literal_types(self):
- """Test the different timestamp, date, and time types."""
+ """Test the different timestamp, date, and timedelta types."""
datetime_lit = fun.lit(datetime.datetime.now())
p = datetime_lit.to_plan(None)
@@ -107,10 +119,12 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
self.assertGreater(p.literal.timestamp, 10000000000000)
date_lit = fun.lit(datetime.date.today())
- time_lit = fun.lit(datetime.time())
+ time_delta = fun.lit(datetime.timedelta(days=1, seconds=2, microseconds=3))
self.assertIsNotNone(date_lit.to_plan(None))
- self.assertIsNotNone(time_lit.to_plan(None))
+ self.assertIsNotNone(time_delta.to_plan(None))
+ # (24 * 3600 + 2) * 1000000 + 3
+ self.assertEqual(86402000003, time_delta.to_plan(None).literal.day_time_interval)
def test_list_to_literal(self):
"""Test conversion of lists to literals"""
@@ -134,6 +148,37 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
lit_list_plan = fun.lit([fun.lit(10), fun.lit("str")]).to_plan(None)
self.assertIsNotNone(lit_list_plan)
+ def test_tuple_to_literal(self):
+ """Test conversion of tuples to struct literals"""
+ t0 = ()
+ t1 = (1.0,)
+ t2 = (1, "xyz")
+ t3 = (1, "abc", (3.5, True, None))
+
+ p0 = fun.lit(t0).to_plan(None)
+ self.assertIsNotNone(p0)
+ self.assertTrue(p0.literal.HasField("struct"))
+
+ p1 = fun.lit(t1).to_plan(None)
+ self.assertIsNotNone(p1)
+ self.assertTrue(p1.literal.HasField("struct"))
+ self.assertEqual(p1.literal.struct.fields[0].double, 1.0)
+
+ p2 = fun.lit(t2).to_plan(None)
+ self.assertIsNotNone(p2)
+ self.assertTrue(p2.literal.HasField("struct"))
+ self.assertEqual(p2.literal.struct.fields[0].long, 1)
+ self.assertEqual(p2.literal.struct.fields[1].string, "xyz")
+
+ p3 = fun.lit(t3).to_plan(None)
+ self.assertIsNotNone(p3)
+ self.assertTrue(p3.literal.HasField("struct"))
+ self.assertEqual(p3.literal.struct.fields[0].long, 1)
+ self.assertEqual(p3.literal.struct.fields[1].string, "abc")
+ self.assertEqual(p3.literal.struct.fields[2].struct.fields[0].double, 3.5)
+ self.assertEqual(p3.literal.struct.fields[2].struct.fields[1].boolean, True)
+ self.assertEqual(p3.literal.struct.fields[2].struct.fields[2].null, True)
+
def test_column_alias(self) -> None:
# SPARK-40809: Support for Column Aliases
col0 = fun.col("a").alias("martin")
@@ -162,7 +207,7 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
lit_fun = expr_plan.unresolved_function.arguments[1]
self.assertIsInstance(lit_fun, ProtoExpression)
self.assertIsInstance(lit_fun.literal, ProtoExpression.Literal)
- self.assertEqual(lit_fun.literal.i64, 10)
+ self.assertEqual(lit_fun.literal.long, 10)
mod_fun = expr_plan.unresolved_function.arguments[0]
self.assertIsInstance(mod_fun, ProtoExpression)
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 367c514d0a8..3cae4cefa05 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -76,7 +76,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
plan = df.fillna(value=1)._plan.to_proto(self.connect)
self.assertEqual(len(plan.root.fill_na.values), 1)
- self.assertEqual(plan.root.fill_na.values[0].i64, 1)
+ self.assertEqual(plan.root.fill_na.values[0].long, 1)
self.assertEqual(plan.root.fill_na.cols, [])
plan = df.na.fill(value="xyz")._plan.to_proto(self.connect)
@@ -98,7 +98,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
plan = df.fillna({"col_a": 1.5, "col_b": "abc"})._plan.to_proto(self.connect)
self.assertEqual(len(plan.root.fill_na.values), 2)
- self.assertEqual(plan.root.fill_na.values[0].fp64, 1.5)
+ self.assertEqual(plan.root.fill_na.values[0].double, 1.5)
self.assertEqual(plan.root.fill_na.values[1].string, "abc")
self.assertEqual(plan.root.fill_na.cols, ["col_a", "col_b"])
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org