Posted to commits@spark.apache.org by hv...@apache.org on 2023/03/06 02:19:18 UTC
[spark] branch branch-3.4 updated: [SPARK-41527][CONNECT][PYTHON] Implement `DataFrame.observe`
This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 7d8ae50f099 [SPARK-41527][CONNECT][PYTHON] Implement `DataFrame.observe`
7d8ae50f099 is described below
commit 7d8ae50f099366ccf9ecbac12fd4d24da85bf6dc
Author: Jiaan Geng <be...@163.com>
AuthorDate: Sun Mar 5 22:18:09 2023 -0400
[SPARK-41527][CONNECT][PYTHON] Implement `DataFrame.observe`
### What changes were proposed in this pull request?
- Add a `CollectMetrics` proto message for `DataFrame.observe` (see the usage sketch below)
- Implement `DataFrame.observe` for the Scala API
- Implement `DataFrame.observe` for the Python API
### Why are the changes needed?
For Spark Connect API coverage.
### Does this PR introduce _any_ user-facing change?
'No'. This only adds a new API.
### How was this patch tested?
New test cases.
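A minimal usage sketch of the new API (not part of the patch; the `sc://localhost` endpoint is a placeholder):

```python
from pyspark.sql import SparkSession
from pyspark.sql.connect import functions as CF
from pyspark.sql.observation import Observation

# Any Spark Connect endpoint works here; "sc://localhost" is illustrative.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

df = spark.range(100)

# Name-based form: the metrics are grouped under the given name.
df.observe("my_metric", CF.min("id"), CF.max("id"), CF.sum("id")).toPandas()

# Observation-based form: the Observation's name is used for the metrics.
df.observe(Observation("my_metric"), CF.min("id")).toPandas()
```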
Closes #39091 from beliefer/SPARK-41527.
Authored-by: Jiaan Geng <be...@163.com>
Signed-off-by: Herman van Hovell <he...@databricks.com>
(cherry picked from commit 0ce63f36384e135ebfb21b927f6cc69361919b47)
Signed-off-by: Herman van Hovell <he...@databricks.com>
---
.../src/main/protobuf/spark/connect/base.proto | 9 +
.../main/protobuf/spark/connect/relations.proto | 12 ++
.../org/apache/spark/sql/connect/dsl/package.scala | 40 +++-
.../sql/connect/planner/SparkConnectPlanner.scala | 12 +-
.../service/SparkConnectStreamHandler.scala | 25 ++-
.../connect/planner/SparkConnectProtoSuite.scala | 63 +++++-
.../connect/planner/SparkConnectServiceSuite.scala | 69 +++++++
python/pyspark/sql/connect/client.py | 47 ++++-
python/pyspark/sql/connect/dataframe.py | 32 ++-
python/pyspark/sql/connect/plan.py | 29 +++
python/pyspark/sql/connect/proto/base_pb2.py | 208 ++++++++++---------
python/pyspark/sql/connect/proto/base_pb2.pyi | 38 ++++
python/pyspark/sql/connect/proto/relations_pb2.py | 226 +++++++++++----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 49 +++++
.../sql/tests/connect/test_connect_basic.py | 53 ++++-
.../pyspark/sql/tests/connect/test_connect_plan.py | 64 +++++-
.../scala/org/apache/spark/sql/Observation.scala | 2 +-
17 files changed, 759 insertions(+), 219 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 066a63d58ba..09407a99119 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -21,6 +21,7 @@ package spark.connect;
import "google/protobuf/any.proto";
import "spark/connect/commands.proto";
+import "spark/connect/expressions.proto";
import "spark/connect/relations.proto";
import "spark/connect/types.proto";
@@ -236,6 +237,9 @@ message ExecutePlanResponse {
// batch of results and then represent the overall state of the query execution.
Metrics metrics = 4;
+ // The metrics observed during the execution of the query plan.
+ repeated ObservedMetrics observed_metrics = 6;
+
// A SQL command returns an opaque Relation that can be directly used as input for the next
// call.
message SqlCommandResult {
@@ -265,6 +269,11 @@ message ExecutePlanResponse {
string metric_type = 3;
}
}
+
+ message ObservedMetrics {
+ string name = 1;
+ repeated Expression.Literal values = 2;
+ }
}
// The key-value pair for the config request and response.
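For reference, a sketch (not part of the patch) of pulling the new `ObservedMetrics` message out of an `ExecutePlanResponse` with the generated Python stubs; field names follow the proto above:

```python
from typing import Dict, List

from pyspark.sql.connect.proto import base_pb2 as pb2
from pyspark.sql.connect.proto import expressions_pb2


def observed_metrics_of(
    response: pb2.ExecutePlanResponse,
) -> Dict[str, List["expressions_pb2.Expression.Literal"]]:
    # Each ObservedMetrics entry carries the observation name and one
    # Expression.Literal per metric expression.
    return {om.name: list(om.values) for om in response.observed_metrics}
```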
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 3e4a4daeb36..2230aada4fe 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -61,6 +61,7 @@ message Relation {
ToSchema to_schema = 26;
RepartitionByExpression repartition_by_expression = 27;
FrameMap frame_map = 28;
+ CollectMetrics collect_metrics = 29;
// NA functions
NAFill fill_na = 90;
@@ -781,3 +782,14 @@ message FrameMap {
CommonInlineUserDefinedFunction func = 2;
}
+// Collect arbitrary (named) metrics from a dataset.
+message CollectMetrics {
+ // (Required) The input relation.
+ Relation input = 1;
+
+ // (Required) Name of the metrics.
+ string name = 2;
+
+ // (Required) The metric sequence.
+ repeated Expression metrics = 3;
+}
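A sketch of how a client assembles this relation with the generated Python stubs (it mirrors the `plan.py` change later in this patch; the SQL child relation is just an illustrative input):

```python
from pyspark.sql.connect.proto import expressions_pb2, relations_pb2

# An arbitrary input relation; a SQL query is used here for brevity.
child = relations_pb2.Relation()
child.sql.query = "SELECT id FROM range(10)"

# min(id) expressed as an unresolved function over an unresolved attribute.
metric = expressions_pb2.Expression()
metric.unresolved_function.function_name = "min"
arg = metric.unresolved_function.arguments.add()
arg.unresolved_attribute.unparsed_identifier = "id"

rel = relations_pb2.Relation()
rel.collect_metrics.input.CopyFrom(child)
rel.collect_metrics.name = "my_metric"
rel.collect_metrics.metrics.append(metric)
```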
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 840f43abf49..7e60c5f9a28 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -24,7 +24,7 @@ import org.apache.spark.connect.proto._
import org.apache.spark.connect.proto.Expression.ExpressionString
import org.apache.spark.connect.proto.Join.JoinType
import org.apache.spark.connect.proto.SetOperation.SetOpType
-import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.{Observation, SaveMode}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.planner.{SaveModeConverter, TableSaveMethodConverter}
import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
@@ -141,6 +141,20 @@ package object dsl {
Expression.UnresolvedFunction.newBuilder().setFunctionName("min").addArguments(e))
.build()
+ def proto_max(e: Expression): Expression =
+ Expression
+ .newBuilder()
+ .setUnresolvedFunction(
+ Expression.UnresolvedFunction.newBuilder().setFunctionName("max").addArguments(e))
+ .build()
+
+ def proto_sum(e: Expression): Expression =
+ Expression
+ .newBuilder()
+ .setUnresolvedFunction(
+ Expression.UnresolvedFunction.newBuilder().setFunctionName("sum").addArguments(e))
+ .build()
+
def proto_explode(e: Expression): Expression =
Expression
.newBuilder()
@@ -1061,6 +1075,30 @@ package object dsl {
def randomSplit(weights: Array[Double]): Array[Relation] =
randomSplit(weights, Utils.random.nextLong)
+ def observe(name: String, expr: Expression, exprs: Expression*): Relation = {
+ Relation
+ .newBuilder()
+ .setCollectMetrics(
+ CollectMetrics
+ .newBuilder()
+ .setInput(logicalPlan)
+ .setName(name)
+ .addAllMetrics((expr +: exprs).asJava))
+ .build()
+ }
+
+ def observe(observation: Observation, expr: Expression, exprs: Expression*): Relation = {
+ Relation
+ .newBuilder()
+ .setCollectMetrics(
+ CollectMetrics
+ .newBuilder()
+ .setInput(logicalPlan)
+ .setName(observation.name)
+ .addAllMetrics((expr +: exprs).asJava))
+ .build()
+ }
+
private def createSetOperation(
left: Relation,
right: Relation,
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 76a4c7faaa2..60fb94e8098 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, UdfPacket}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
@@ -113,6 +113,8 @@ class SparkConnectPlanner(val session: SparkSession) {
transformRepartitionByExpression(rel.getRepartitionByExpression)
case proto.Relation.RelTypeCase.FRAME_MAP =>
transformFrameMap(rel.getFrameMap)
+ case proto.Relation.RelTypeCase.COLLECT_METRICS =>
+ transformCollectMetrics(rel.getCollectMetrics)
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
@@ -571,6 +573,14 @@ class SparkConnectPlanner(val session: SparkSession) {
numPartitionsOpt)
}
+ private def transformCollectMetrics(rel: proto.CollectMetrics): LogicalPlan = {
+ val metrics = rel.getMetricsList.asScala.toSeq.map { expr =>
+ Column(transformExpression(expr))
+ }
+
+ CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput))
+ }
+
private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
if (!rel.hasInput) {
throw InvalidPlanInput("Deduplicate needs a plan input")
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 41ca564e6d3..15d5a981ae8 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
+import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.service.SparkConnectStreamHandler.processAsArrowBatches
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
@@ -62,6 +63,10 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
processAsArrowBatches(request.getClientId, dataframe, responseObserver)
responseObserver.onNext(
SparkConnectStreamHandler.sendMetricsToResponse(request.getClientId, dataframe))
+ if (dataframe.queryExecution.observedMetrics.nonEmpty) {
+ responseObserver.onNext(
+ SparkConnectStreamHandler.sendObservedMetricsToResponse(request.getClientId, dataframe))
+ }
responseObserver.onCompleted()
}
@@ -206,6 +211,25 @@ object SparkConnectStreamHandler {
.setMetrics(MetricGenerator.buildMetrics(rows.queryExecution.executedPlan))
.build()
}
+
+ def sendObservedMetricsToResponse(
+ clientId: String,
+ dataframe: DataFrame): ExecutePlanResponse = {
+ val observedMetrics = dataframe.queryExecution.observedMetrics.map { case (name, row) =>
+ val cols = (0 until row.length).map(i => toConnectProtoValue(row(i)))
+ ExecutePlanResponse.ObservedMetrics
+ .newBuilder()
+ .setName(name)
+ .addAllValues(cols.asJava)
+ .build()
+ }
+ // Prepare a response with the observed metrics.
+ ExecutePlanResponse
+ .newBuilder()
+ .setClientId(clientId)
+ .addAllObservedMetrics(observedMetrics.asJava)
+ .build()
+ }
}
object MetricGenerator extends AdaptiveSparkPlanHelper {
@@ -242,5 +266,4 @@ object MetricGenerator extends AdaptiveSparkPlanHelper {
.build()
Seq(mo) ++ transformChildren(p)
}
-
}
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 7fb28f354ce..241b3dcb825 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.SparkClassNotFoundException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.Expression
import org.apache.spark.connect.proto.Join.JoinType
-import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Observation, Row, SaveMode}
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
@@ -940,6 +940,67 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
comparePlans(connectPlan1, sparkPlan1)
}
+ test("Test observe") {
+ val connectPlan0 =
+ connectTestRelation.observe(
+ "my_metric",
+ proto_min("id".protoAttr).as("min_val"),
+ proto_max("id".protoAttr).as("max_val"),
+ proto_sum("id".protoAttr))
+ val sparkPlan0 =
+ sparkTestRelation.observe(
+ "my_metric",
+ min(Column("id")).as("min_val"),
+ max(Column("id")).as("max_val"),
+ sum(Column("id")))
+ comparePlans(connectPlan0, sparkPlan0)
+
+ val connectPlan1 =
+ connectTestRelation.observe("my_metric", proto_min("id".protoAttr).as("min_val"))
+ val sparkPlan1 =
+ sparkTestRelation.observe("my_metric", min(Column("id")).as("min_val"))
+ comparePlans(connectPlan1, sparkPlan1)
+
+ checkError(
+ exception = intercept[AnalysisException] {
+ analyzePlan(
+ transform(connectTestRelation.observe("my_metric", "id".protoAttr.cast("string"))))
+ },
+ errorClass = "_LEGACY_ERROR_TEMP_2322",
+ parameters = Map("sqlExpr" -> "CAST(id AS STRING) AS id"))
+
+ val connectPlan2 =
+ connectTestRelation.observe(
+ Observation("my_metric"),
+ proto_min("id".protoAttr).as("min_val"),
+ proto_max("id".protoAttr).as("max_val"),
+ proto_sum("id".protoAttr))
+ val sparkPlan2 =
+ sparkTestRelation.observe(
+ Observation("my_metric"),
+ min(Column("id")).as("min_val"),
+ max(Column("id")).as("max_val"),
+ sum(Column("id")))
+ comparePlans(connectPlan2, sparkPlan2)
+
+ val connectPlan3 =
+ connectTestRelation.observe(
+ Observation("my_metric"),
+ proto_min("id".protoAttr).as("min_val"))
+ val sparkPlan3 =
+ sparkTestRelation.observe(Observation("my_metric"), min(Column("id")).as("min_val"))
+ comparePlans(connectPlan3, sparkPlan3)
+
+ checkError(
+ exception = intercept[AnalysisException] {
+ analyzePlan(
+ transform(
+ connectTestRelation.observe(Observation("my_metric"), "id".protoAttr.cast("string"))))
+ },
+ errorClass = "_LEGACY_ERROR_TEMP_2322",
+ parameters = Map("sqlExpr" -> "CAST(id AS STRING) AS id"))
+ }
+
test("Test RandomSplit") {
val splitRelations0 = connectTestRelation.randomSplit(Array[Double](1, 2, 3), 1)
val splits0 = sparkTestRelation.randomSplit(Array[Double](1, 2, 3), 1)
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 4d9bea88f5c..2885d0035bc 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.connect.planner
+import scala.collection.JavaConverters._
import scala.collection.mutable
import io.grpc.StatusRuntimeException
@@ -26,6 +27,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.dsl.MockRemoteSession
+import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._
import org.apache.spark.sql.connect.service.{SparkConnectAnalyzeHandler, SparkConnectService}
import org.apache.spark.sql.test.SharedSparkSession
@@ -293,4 +295,71 @@ class SparkConnectServiceSuite extends SharedSparkSession {
assert(response.getExplain.getExplainString.contains("Physical Plan"))
}
}
+
+ test("Test observe response") {
+ withTable("test") {
+ spark.sql("""
+ | CREATE TABLE test (col1 INT, col2 STRING)
+ | USING parquet
+ |""".stripMargin)
+
+ val instance = new SparkConnectService(false)
+
+ val connect = new MockRemoteSession()
+ val context = proto.UserContext
+ .newBuilder()
+ .setUserId("c1")
+ .build()
+ val collectMetrics = proto.Relation
+ .newBuilder()
+ .setCollectMetrics(
+ proto.CollectMetrics
+ .newBuilder()
+ .setInput(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)"))
+ .setName("my_metric")
+ .addAllMetrics(Seq(
+ proto_min("id".protoAttr).as("min_val"),
+ proto_max("id".protoAttr).as("max_val")).asJava))
+ .build()
+ val plan = proto.Plan
+ .newBuilder()
+ .setRoot(collectMetrics)
+ .build()
+ val request = proto.ExecutePlanRequest
+ .newBuilder()
+ .setPlan(plan)
+ .setUserContext(context)
+ .build()
+
+ // Execute plan.
+ @volatile var done = false
+ val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+ instance.executePlan(
+ request,
+ new StreamObserver[proto.ExecutePlanResponse] {
+ override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v
+
+ override def onError(throwable: Throwable): Unit = throw throwable
+
+ override def onCompleted(): Unit = done = true
+ })
+
+ // The current implementation is expected to be blocking. This is here to make sure it is.
+ assert(done)
+
+ assert(responses.size == 6)
+
+ // Make sure the last response is observed metrics only
+ val last = responses.last
+ assert(last.getObservedMetricsCount == 1 && !last.hasArrowBatch)
+
+ val observedMetricsList = last.getObservedMetricsList.asScala
+ val observedMetric = observedMetricsList.head
+ assert(observedMetric.getName == "my_metric")
+ assert(observedMetric.getValuesCount == 2)
+ val valuesList = observedMetric.getValuesList.asScala
+ assert(valuesList.head.hasLong && valuesList.head.getLong == 0)
+ assert(valuesList.last.hasLong && valuesList.last.getLong == 99)
+ }
+ }
}
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 3f70ca6ad15..8be0c997538 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -373,6 +373,23 @@ class PlanMetrics:
return self._metrics
+class PlanObservedMetrics:
+ def __init__(self, name: str, metrics: List[pb2.Expression.Literal]):
+ self._name = name
+ self._metrics = metrics
+
+ def __repr__(self) -> str:
+ return f"Plan observed({self._name}={self._metrics})"
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def metrics(self) -> List[pb2.Expression.Literal]:
+ return self._metrics
+
+
class AnalyzeResult:
def __init__(
self,
@@ -561,6 +578,17 @@ class SparkConnectClient(object):
for x in metrics.metrics
]
+ def _build_observed_metrics(
+ self, metrics: List["pb2.ExecutePlanResponse.ObservedMetrics"]
+ ) -> List[PlanObservedMetrics]:
+ return [
+ PlanObservedMetrics(
+ x.name,
+ [v for v in x.values],
+ )
+ for x in metrics
+ ]
+
def to_table(self, plan: pb2.Plan) -> "pa.Table":
"""
Return given plan as a PyArrow Table.
@@ -568,7 +596,7 @@ class SparkConnectClient(object):
logger.info(f"Executing plan {self._proto_to_string(plan)}")
req = self._execute_plan_request_with_metadata()
req.plan.CopyFrom(plan)
- table, _, _2 = self._execute_and_fetch(req)
+ table, _, _, _3 = self._execute_and_fetch(req)
assert table is not None
return table
@@ -579,7 +607,7 @@ class SparkConnectClient(object):
logger.info(f"Executing plan {self._proto_to_string(plan)}")
req = self._execute_plan_request_with_metadata()
req.plan.CopyFrom(plan)
- table, metrics, _ = self._execute_and_fetch(req)
+ table, metrics, observed_metrics, _ = self._execute_and_fetch(req)
assert table is not None
column_names = table.column_names
table = table.rename_columns([f"col_{i}" for i in range(len(column_names))])
@@ -587,6 +615,8 @@ class SparkConnectClient(object):
pdf.columns = column_names
if len(metrics) > 0:
pdf.attrs["metrics"] = metrics
+ if len(observed_metrics) > 0:
+ pdf.attrs["observed_metrics"] = observed_metrics
return pdf
def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
@@ -654,7 +684,7 @@ class SparkConnectClient(object):
if self._user_id:
req.user_context.user_id = self._user_id
req.plan.command.CopyFrom(command)
- data, _, properties = self._execute_and_fetch(req)
+ data, _, _, properties = self._execute_and_fetch(req)
if data is not None:
return (data.to_pandas(), properties)
else:
@@ -781,10 +811,11 @@ class SparkConnectClient(object):
def _execute_and_fetch(
self, req: pb2.ExecutePlanRequest
- ) -> Tuple[Optional["pa.Table"], List[PlanMetrics], Dict[str, Any]]:
+ ) -> Tuple[Optional["pa.Table"], List[PlanMetrics], List[PlanObservedMetrics], Dict[str, Any]]:
logger.info("ExecuteAndFetch")
m: Optional[pb2.ExecutePlanResponse.Metrics] = None
+ om: List[pb2.ExecutePlanResponse.ObservedMetrics] = []
batches: List[pa.RecordBatch] = []
properties = {}
try:
@@ -802,6 +833,9 @@ class SparkConnectClient(object):
if b.metrics is not None:
logger.debug("Received metric batch.")
m = b.metrics
+ if b.observed_metrics is not None:
+ logger.debug("Received observed metric batch.")
+ om.extend(b.observed_metrics)
if b.HasField("sql_command_result"):
properties["sql_command_result"] = b.sql_command_result.relation
if b.HasField("arrow_batch"):
@@ -817,12 +851,13 @@ class SparkConnectClient(object):
except grpc.RpcError as rpc_error:
self._handle_error(rpc_error)
metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None else []
+ observed_metrics: List[PlanObservedMetrics] = self._build_observed_metrics(om)
if len(batches) > 0:
table = pa.Table.from_batches(batches=batches)
- return table, metrics, properties
+ return table, metrics, observed_metrics, properties
else:
- return None, metrics, properties
+ return None, metrics, observed_metrics, properties
def _config_request_with_metadata(self) -> pb2.ConfigRequest:
req = pb2.ConfigRequest()
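With the client changes above, observed metrics ride along with the execution result and surface on the pandas DataFrame; a sketch (endpoint and data are illustrative):

```python
from pyspark.sql import SparkSession
from pyspark.sql.connect import functions as CF

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()  # placeholder endpoint
pdf = spark.range(100).observe("m", CF.max("id")).toPandas()

for om in pdf.attrs.get("observed_metrics", []):
    # om is a PlanObservedMetrics; its metrics are Expression.Literal protos,
    # so the value is read off the matching literal field (long for max(id)).
    print(om.name, [lit.long for lit in om.metrics])
```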
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 471dbf89582..fdc71d46632 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -42,6 +42,7 @@ from collections.abc import Iterable
from pyspark import _NoValue
from pyspark._globals import _NoValueType
+from pyspark.sql.observation import Observation
from pyspark.sql.types import Row, StructType
from pyspark.sql.dataframe import (
DataFrame as PySparkDataFrame,
@@ -783,6 +784,31 @@ class DataFrame:
randomSplit.__doc__ = PySparkDataFrame.randomSplit.__doc__
+ def observe(
+ self,
+ observation: Union["Observation", str],
+ *exprs: Column,
+ ) -> "DataFrame":
+ if len(exprs) == 0:
+ raise ValueError("'exprs' should not be empty")
+ if not all(isinstance(c, Column) for c in exprs):
+ raise ValueError("all 'exprs' should be Column")
+
+ if isinstance(observation, Observation):
+ return DataFrame.withPlan(
+ plan.CollectMetrics(self._plan, str(observation._name), list(exprs)),
+ self._session,
+ )
+ elif isinstance(observation, str):
+ return DataFrame.withPlan(
+ plan.CollectMetrics(self._plan, observation, list(exprs)),
+ self._session,
+ )
+ else:
+ raise ValueError("'observation' should be either `Observation` or `str`.")
+
+ observe.__doc__ = PySparkDataFrame.observe.__doc__
+
def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
print(self._show_string(n, truncate, vertical))
@@ -1523,9 +1549,6 @@ class DataFrame:
def withWatermark(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("withWatermark() is not implemented.")
- def observe(self, *args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("observe() is not implemented.")
-
def foreach(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("foreach() is not implemented.")
@@ -1730,6 +1753,9 @@ def _test() -> None:
# TODO(SPARK-41625): Support Structured Streaming
del pyspark.sql.connect.dataframe.DataFrame.isStreaming.__doc__
+ # TODO(SPARK-41888): Support StreamingQueryListener for DataFrame.observe
+ del pyspark.sql.connect.dataframe.DataFrame.observe.__doc__
+
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.dataframe tests")
.remote("local[4]")
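The argument validation added above is easy to see interactively; a sketch assuming an existing Spark Connect session named `spark`:

```python
df = spark.range(10)  # `spark`: an existing Spark Connect session

try:
    df.observe("m")  # no metric expressions
except ValueError as e:
    print(e)  # 'exprs' should not be empty

try:
    df.observe("m", "id")  # plain strings are rejected
except ValueError as e:
    print(e)  # all 'exprs' should be Column
```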
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 7e767885793..8e5a96b974d 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1050,6 +1050,35 @@ class Unpivot(LogicalPlan):
return plan
+class CollectMetrics(LogicalPlan):
+ """Logical plan object for a CollectMetrics operation."""
+
+ def __init__(
+ self,
+ child: Optional["LogicalPlan"],
+ name: str,
+ exprs: List["ColumnOrName"],
+ ) -> None:
+ super().__init__(child)
+ self._name = name
+ self._exprs = exprs
+
+ def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") -> proto.Expression:
+ if isinstance(col, Column):
+ return col.to_plan(session)
+ else:
+ return self.unresolved_attr(col)
+
+ def plan(self, session: "SparkConnectClient") -> proto.Relation:
+ assert self._child is not None
+
+ plan = proto.Relation()
+ plan.collect_metrics.input.CopyFrom(self._child.plan(session))
+ plan.collect_metrics.name = self._name
+ plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for x in self._exprs])
+ return plan
+
+
class NAFill(LogicalPlan):
def __init__(
self, child: Optional["LogicalPlan"], cols: Optional[List[str]], values: List[Any]
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 628f7ebdd46..9ece82ed535 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -31,12 +31,13 @@ _sym_db = _symbol_database.Default()
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
from pyspark.sql.connect.proto import commands_pb2 as spark_dot_connect_dot_commands__pb2
+from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2
from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2
from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x0 [...]
+ b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06 [...]
)
@@ -76,6 +77,7 @@ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY = (
_EXECUTEPLANRESPONSE_METRICS_METRICVALUE = _EXECUTEPLANRESPONSE_METRICS.nested_types_by_name[
"MetricValue"
]
+_EXECUTEPLANRESPONSE_OBSERVEDMETRICS = _EXECUTEPLANRESPONSE.nested_types_by_name["ObservedMetrics"]
_KEYVALUE = DESCRIPTOR.message_types_by_name["KeyValue"]
_CONFIGREQUEST = DESCRIPTOR.message_types_by_name["ConfigRequest"]
_CONFIGREQUEST_OPERATION = _CONFIGREQUEST.nested_types_by_name["Operation"]
@@ -376,6 +378,15 @@ ExecutePlanResponse = _reflection.GeneratedProtocolMessageType(
# @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.Metrics)
},
),
+ "ObservedMetrics": _reflection.GeneratedProtocolMessageType(
+ "ObservedMetrics",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _EXECUTEPLANRESPONSE_OBSERVEDMETRICS,
+ "__module__": "spark.connect.base_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.ObservedMetrics)
+ },
+ ),
"DESCRIPTOR": _EXECUTEPLANRESPONSE,
"__module__": "spark.connect.base_pb2"
# @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse)
@@ -388,6 +399,7 @@ _sym_db.RegisterMessage(ExecutePlanResponse.Metrics)
_sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject)
_sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntry)
_sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricValue)
+_sym_db.RegisterMessage(ExecutePlanResponse.ObservedMetrics)
KeyValue = _reflection.GeneratedProtocolMessageType(
"KeyValue",
@@ -581,100 +593,102 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None
_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001"
- _PLAN._serialized_start = 158
- _PLAN._serialized_end = 274
- _USERCONTEXT._serialized_start = 276
- _USERCONTEXT._serialized_end = 398
- _ANALYZEPLANREQUEST._serialized_start = 401
- _ANALYZEPLANREQUEST._serialized_end = 1843
- _ANALYZEPLANREQUEST_SCHEMA._serialized_start = 1172
- _ANALYZEPLANREQUEST_SCHEMA._serialized_end = 1221
- _ANALYZEPLANREQUEST_EXPLAIN._serialized_start = 1224
- _ANALYZEPLANREQUEST_EXPLAIN._serialized_end = 1539
- _ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE._serialized_start = 1367
- _ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE._serialized_end = 1539
- _ANALYZEPLANREQUEST_TREESTRING._serialized_start = 1541
- _ANALYZEPLANREQUEST_TREESTRING._serialized_end = 1594
- _ANALYZEPLANREQUEST_ISLOCAL._serialized_start = 1596
- _ANALYZEPLANREQUEST_ISLOCAL._serialized_end = 1646
- _ANALYZEPLANREQUEST_ISSTREAMING._serialized_start = 1648
- _ANALYZEPLANREQUEST_ISSTREAMING._serialized_end = 1702
- _ANALYZEPLANREQUEST_INPUTFILES._serialized_start = 1704
- _ANALYZEPLANREQUEST_INPUTFILES._serialized_end = 1757
- _ANALYZEPLANREQUEST_SPARKVERSION._serialized_start = 1759
- _ANALYZEPLANREQUEST_SPARKVERSION._serialized_end = 1773
- _ANALYZEPLANREQUEST_DDLPARSE._serialized_start = 1775
- _ANALYZEPLANREQUEST_DDLPARSE._serialized_end = 1816
- _ANALYZEPLANRESPONSE._serialized_start = 1846
- _ANALYZEPLANRESPONSE._serialized_end = 2916
- _ANALYZEPLANRESPONSE_SCHEMA._serialized_start = 2525
- _ANALYZEPLANRESPONSE_SCHEMA._serialized_end = 2582
- _ANALYZEPLANRESPONSE_EXPLAIN._serialized_start = 2584
- _ANALYZEPLANRESPONSE_EXPLAIN._serialized_end = 2632
- _ANALYZEPLANRESPONSE_TREESTRING._serialized_start = 2634
- _ANALYZEPLANRESPONSE_TREESTRING._serialized_end = 2679
- _ANALYZEPLANRESPONSE_ISLOCAL._serialized_start = 2681
- _ANALYZEPLANRESPONSE_ISLOCAL._serialized_end = 2717
- _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_start = 2719
- _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_end = 2767
- _ANALYZEPLANRESPONSE_INPUTFILES._serialized_start = 2769
- _ANALYZEPLANRESPONSE_INPUTFILES._serialized_end = 2803
- _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_start = 2805
- _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_end = 2845
- _ANALYZEPLANRESPONSE_DDLPARSE._serialized_start = 2847
- _ANALYZEPLANRESPONSE_DDLPARSE._serialized_end = 2906
- _EXECUTEPLANREQUEST._serialized_start = 2919
- _EXECUTEPLANREQUEST._serialized_end = 3126
- _EXECUTEPLANRESPONSE._serialized_start = 3129
- _EXECUTEPLANRESPONSE._serialized_end = 4160
- _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 3489
- _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 3560
- _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 3562
- _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 3623
- _EXECUTEPLANRESPONSE_METRICS._serialized_start = 3626
- _EXECUTEPLANRESPONSE_METRICS._serialized_end = 4143
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 3721
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 4053
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 3930
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 4053
- _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 4055
- _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 4143
- _KEYVALUE._serialized_start = 4162
- _KEYVALUE._serialized_end = 4227
- _CONFIGREQUEST._serialized_start = 4230
- _CONFIGREQUEST._serialized_end = 5256
- _CONFIGREQUEST_OPERATION._serialized_start = 4448
- _CONFIGREQUEST_OPERATION._serialized_end = 4946
- _CONFIGREQUEST_SET._serialized_start = 4948
- _CONFIGREQUEST_SET._serialized_end = 5000
- _CONFIGREQUEST_GET._serialized_start = 5002
- _CONFIGREQUEST_GET._serialized_end = 5027
- _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 5029
- _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 5092
- _CONFIGREQUEST_GETOPTION._serialized_start = 5094
- _CONFIGREQUEST_GETOPTION._serialized_end = 5125
- _CONFIGREQUEST_GETALL._serialized_start = 5127
- _CONFIGREQUEST_GETALL._serialized_end = 5175
- _CONFIGREQUEST_UNSET._serialized_start = 5177
- _CONFIGREQUEST_UNSET._serialized_end = 5204
- _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 5206
- _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 5240
- _CONFIGRESPONSE._serialized_start = 5258
- _CONFIGRESPONSE._serialized_end = 5378
- _ADDARTIFACTSREQUEST._serialized_start = 5381
- _ADDARTIFACTSREQUEST._serialized_end = 6196
- _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 5728
- _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 5781
- _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 5783
- _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 5894
- _ADDARTIFACTSREQUEST_BATCH._serialized_start = 5896
- _ADDARTIFACTSREQUEST_BATCH._serialized_end = 5989
- _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 5992
- _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 6185
- _ADDARTIFACTSRESPONSE._serialized_start = 6199
- _ADDARTIFACTSRESPONSE._serialized_end = 6387
- _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 6306
- _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 6387
- _SPARKCONNECTSERVICE._serialized_start = 6390
- _SPARKCONNECTSERVICE._serialized_end = 6755
+ _PLAN._serialized_start = 191
+ _PLAN._serialized_end = 307
+ _USERCONTEXT._serialized_start = 309
+ _USERCONTEXT._serialized_end = 431
+ _ANALYZEPLANREQUEST._serialized_start = 434
+ _ANALYZEPLANREQUEST._serialized_end = 1876
+ _ANALYZEPLANREQUEST_SCHEMA._serialized_start = 1205
+ _ANALYZEPLANREQUEST_SCHEMA._serialized_end = 1254
+ _ANALYZEPLANREQUEST_EXPLAIN._serialized_start = 1257
+ _ANALYZEPLANREQUEST_EXPLAIN._serialized_end = 1572
+ _ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE._serialized_start = 1400
+ _ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE._serialized_end = 1572
+ _ANALYZEPLANREQUEST_TREESTRING._serialized_start = 1574
+ _ANALYZEPLANREQUEST_TREESTRING._serialized_end = 1627
+ _ANALYZEPLANREQUEST_ISLOCAL._serialized_start = 1629
+ _ANALYZEPLANREQUEST_ISLOCAL._serialized_end = 1679
+ _ANALYZEPLANREQUEST_ISSTREAMING._serialized_start = 1681
+ _ANALYZEPLANREQUEST_ISSTREAMING._serialized_end = 1735
+ _ANALYZEPLANREQUEST_INPUTFILES._serialized_start = 1737
+ _ANALYZEPLANREQUEST_INPUTFILES._serialized_end = 1790
+ _ANALYZEPLANREQUEST_SPARKVERSION._serialized_start = 1792
+ _ANALYZEPLANREQUEST_SPARKVERSION._serialized_end = 1806
+ _ANALYZEPLANREQUEST_DDLPARSE._serialized_start = 1808
+ _ANALYZEPLANREQUEST_DDLPARSE._serialized_end = 1849
+ _ANALYZEPLANRESPONSE._serialized_start = 1879
+ _ANALYZEPLANRESPONSE._serialized_end = 2949
+ _ANALYZEPLANRESPONSE_SCHEMA._serialized_start = 2558
+ _ANALYZEPLANRESPONSE_SCHEMA._serialized_end = 2615
+ _ANALYZEPLANRESPONSE_EXPLAIN._serialized_start = 2617
+ _ANALYZEPLANRESPONSE_EXPLAIN._serialized_end = 2665
+ _ANALYZEPLANRESPONSE_TREESTRING._serialized_start = 2667
+ _ANALYZEPLANRESPONSE_TREESTRING._serialized_end = 2712
+ _ANALYZEPLANRESPONSE_ISLOCAL._serialized_start = 2714
+ _ANALYZEPLANRESPONSE_ISLOCAL._serialized_end = 2750
+ _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_start = 2752
+ _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_end = 2800
+ _ANALYZEPLANRESPONSE_INPUTFILES._serialized_start = 2802
+ _ANALYZEPLANRESPONSE_INPUTFILES._serialized_end = 2836
+ _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_start = 2838
+ _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_end = 2878
+ _ANALYZEPLANRESPONSE_DDLPARSE._serialized_start = 2880
+ _ANALYZEPLANRESPONSE_DDLPARSE._serialized_end = 2939
+ _EXECUTEPLANREQUEST._serialized_start = 2952
+ _EXECUTEPLANREQUEST._serialized_end = 3159
+ _EXECUTEPLANRESPONSE._serialized_start = 3162
+ _EXECUTEPLANRESPONSE._serialized_end = 4386
+ _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 3617
+ _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 3688
+ _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 3690
+ _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 3751
+ _EXECUTEPLANRESPONSE_METRICS._serialized_start = 3754
+ _EXECUTEPLANRESPONSE_METRICS._serialized_end = 4271
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 3849
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 4181
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 4058
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 4181
+ _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 4183
+ _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 4271
+ _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 4273
+ _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 4369
+ _KEYVALUE._serialized_start = 4388
+ _KEYVALUE._serialized_end = 4453
+ _CONFIGREQUEST._serialized_start = 4456
+ _CONFIGREQUEST._serialized_end = 5482
+ _CONFIGREQUEST_OPERATION._serialized_start = 4674
+ _CONFIGREQUEST_OPERATION._serialized_end = 5172
+ _CONFIGREQUEST_SET._serialized_start = 5174
+ _CONFIGREQUEST_SET._serialized_end = 5226
+ _CONFIGREQUEST_GET._serialized_start = 5228
+ _CONFIGREQUEST_GET._serialized_end = 5253
+ _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 5255
+ _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 5318
+ _CONFIGREQUEST_GETOPTION._serialized_start = 5320
+ _CONFIGREQUEST_GETOPTION._serialized_end = 5351
+ _CONFIGREQUEST_GETALL._serialized_start = 5353
+ _CONFIGREQUEST_GETALL._serialized_end = 5401
+ _CONFIGREQUEST_UNSET._serialized_start = 5403
+ _CONFIGREQUEST_UNSET._serialized_end = 5430
+ _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 5432
+ _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 5466
+ _CONFIGRESPONSE._serialized_start = 5484
+ _CONFIGRESPONSE._serialized_end = 5604
+ _ADDARTIFACTSREQUEST._serialized_start = 5607
+ _ADDARTIFACTSREQUEST._serialized_end = 6422
+ _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 5954
+ _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 6007
+ _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 6009
+ _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 6120
+ _ADDARTIFACTSREQUEST_BATCH._serialized_start = 6122
+ _ADDARTIFACTSREQUEST_BATCH._serialized_end = 6215
+ _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 6218
+ _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 6411
+ _ADDARTIFACTSRESPONSE._serialized_start = 6425
+ _ADDARTIFACTSRESPONSE._serialized_end = 6613
+ _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 6532
+ _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 6613
+ _SPARKCONNECTSERVICE._serialized_start = 6616
+ _SPARKCONNECTSERVICE._serialized_end = 6981
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 0e800947975..56e1d2e416f 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -41,6 +41,7 @@ import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import pyspark.sql.connect.proto.commands_pb2
+import pyspark.sql.connect.proto.expressions_pb2
import pyspark.sql.connect.proto.relations_pb2
import pyspark.sql.connect.proto.types_pb2
import sys
@@ -905,11 +906,37 @@ class ExecutePlanResponse(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["metrics", b"metrics"]
) -> None: ...
+ class ObservedMetrics(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ NAME_FIELD_NUMBER: builtins.int
+ VALUES_FIELD_NUMBER: builtins.int
+ name: builtins.str
+ @property
+ def values(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.expressions_pb2.Expression.Literal
+ ]: ...
+ def __init__(
+ self,
+ *,
+ name: builtins.str = ...,
+ values: collections.abc.Iterable[
+ pyspark.sql.connect.proto.expressions_pb2.Expression.Literal
+ ]
+ | None = ...,
+ ) -> None: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["name", b"name", "values", b"values"]
+ ) -> None: ...
+
CLIENT_ID_FIELD_NUMBER: builtins.int
ARROW_BATCH_FIELD_NUMBER: builtins.int
SQL_COMMAND_RESULT_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
METRICS_FIELD_NUMBER: builtins.int
+ OBSERVED_METRICS_FIELD_NUMBER: builtins.int
client_id: builtins.str
@property
def arrow_batch(self) -> global___ExecutePlanResponse.ArrowBatch: ...
@@ -924,6 +951,13 @@ class ExecutePlanResponse(google.protobuf.message.Message):
"""Metrics for the query execution. Typically, this field is only present in the last
batch of results and then represent the overall state of the query execution.
"""
+ @property
+ def observed_metrics(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ global___ExecutePlanResponse.ObservedMetrics
+ ]:
+ """The metrics observed during the execution of the query plan."""
def __init__(
self,
*,
@@ -932,6 +966,8 @@ class ExecutePlanResponse(google.protobuf.message.Message):
sql_command_result: global___ExecutePlanResponse.SqlCommandResult | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
metrics: global___ExecutePlanResponse.Metrics | None = ...,
+ observed_metrics: collections.abc.Iterable[global___ExecutePlanResponse.ObservedMetrics]
+ | None = ...,
) -> None: ...
def HasField(
self,
@@ -959,6 +995,8 @@ class ExecutePlanResponse(google.protobuf.message.Message):
b"extension",
"metrics",
b"metrics",
+ "observed_metrics",
+ b"observed_metrics",
"response_type",
b"response_type",
"sql_command_result",
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 3b4573177a2..f38cff1990d 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xfb\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -92,6 +92,7 @@ _UNPIVOT_VALUES = _UNPIVOT.nested_types_by_name["Values"]
_TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"]
_REPARTITIONBYEXPRESSION = DESCRIPTOR.message_types_by_name["RepartitionByExpression"]
_FRAMEMAP = DESCRIPTOR.message_types_by_name["FrameMap"]
+_COLLECTMETRICS = DESCRIPTOR.message_types_by_name["CollectMetrics"]
_JOIN_JOINTYPE = _JOIN.enum_types_by_name["JoinType"]
_SETOPERATION_SETOPTYPE = _SETOPERATION.enum_types_by_name["SetOpType"]
_AGGREGATE_GROUPTYPE = _AGGREGATE.enum_types_by_name["GroupType"]
@@ -636,6 +637,17 @@ FrameMap = _reflection.GeneratedProtocolMessageType(
)
_sym_db.RegisterMessage(FrameMap)
+CollectMetrics = _reflection.GeneratedProtocolMessageType(
+ "CollectMetrics",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _COLLECTMETRICS,
+ "__module__": "spark.connect.relations_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.CollectMetrics)
+ },
+)
+_sym_db.RegisterMessage(CollectMetrics)
+
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
@@ -647,109 +659,111 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._options = None
_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
_RELATION._serialized_start = 165
- _RELATION._serialized_end = 2518
- _UNKNOWN._serialized_start = 2520
- _UNKNOWN._serialized_end = 2529
- _RELATIONCOMMON._serialized_start = 2531
- _RELATIONCOMMON._serialized_end = 2622
- _SQL._serialized_start = 2625
- _SQL._serialized_end = 2759
- _SQL_ARGSENTRY._serialized_start = 2704
- _SQL_ARGSENTRY._serialized_end = 2759
- _READ._serialized_start = 2762
- _READ._serialized_end = 3226
- _READ_NAMEDTABLE._serialized_start = 2904
- _READ_NAMEDTABLE._serialized_end = 2965
- _READ_DATASOURCE._serialized_start = 2968
- _READ_DATASOURCE._serialized_end = 3213
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3133
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3191
- _PROJECT._serialized_start = 3228
- _PROJECT._serialized_end = 3345
- _FILTER._serialized_start = 3347
- _FILTER._serialized_end = 3459
- _JOIN._serialized_start = 3462
- _JOIN._serialized_end = 3933
- _JOIN_JOINTYPE._serialized_start = 3725
- _JOIN_JOINTYPE._serialized_end = 3933
- _SETOPERATION._serialized_start = 3936
- _SETOPERATION._serialized_end = 4415
- _SETOPERATION_SETOPTYPE._serialized_start = 4252
- _SETOPERATION_SETOPTYPE._serialized_end = 4366
- _LIMIT._serialized_start = 4417
- _LIMIT._serialized_end = 4493
- _OFFSET._serialized_start = 4495
- _OFFSET._serialized_end = 4574
- _TAIL._serialized_start = 4576
- _TAIL._serialized_end = 4651
- _AGGREGATE._serialized_start = 4654
- _AGGREGATE._serialized_end = 5236
- _AGGREGATE_PIVOT._serialized_start = 4993
- _AGGREGATE_PIVOT._serialized_end = 5104
- _AGGREGATE_GROUPTYPE._serialized_start = 5107
- _AGGREGATE_GROUPTYPE._serialized_end = 5236
- _SORT._serialized_start = 5239
- _SORT._serialized_end = 5399
- _DROP._serialized_start = 5402
- _DROP._serialized_end = 5543
- _DEDUPLICATE._serialized_start = 5546
- _DEDUPLICATE._serialized_end = 5717
- _LOCALRELATION._serialized_start = 5719
- _LOCALRELATION._serialized_end = 5808
- _SAMPLE._serialized_start = 5811
- _SAMPLE._serialized_end = 6084
- _RANGE._serialized_start = 6087
- _RANGE._serialized_end = 6232
- _SUBQUERYALIAS._serialized_start = 6234
- _SUBQUERYALIAS._serialized_end = 6348
- _REPARTITION._serialized_start = 6351
- _REPARTITION._serialized_end = 6493
- _SHOWSTRING._serialized_start = 6496
- _SHOWSTRING._serialized_end = 6638
- _STATSUMMARY._serialized_start = 6640
- _STATSUMMARY._serialized_end = 6732
- _STATDESCRIBE._serialized_start = 6734
- _STATDESCRIBE._serialized_end = 6815
- _STATCROSSTAB._serialized_start = 6817
- _STATCROSSTAB._serialized_end = 6918
- _STATCOV._serialized_start = 6920
- _STATCOV._serialized_end = 7016
- _STATCORR._serialized_start = 7019
- _STATCORR._serialized_end = 7156
- _STATAPPROXQUANTILE._serialized_start = 7159
- _STATAPPROXQUANTILE._serialized_end = 7323
- _STATFREQITEMS._serialized_start = 7325
- _STATFREQITEMS._serialized_end = 7450
- _STATSAMPLEBY._serialized_start = 7453
- _STATSAMPLEBY._serialized_end = 7762
- _STATSAMPLEBY_FRACTION._serialized_start = 7654
- _STATSAMPLEBY_FRACTION._serialized_end = 7753
- _NAFILL._serialized_start = 7765
- _NAFILL._serialized_end = 7899
- _NADROP._serialized_start = 7902
- _NADROP._serialized_end = 8036
- _NAREPLACE._serialized_start = 8039
- _NAREPLACE._serialized_end = 8335
- _NAREPLACE_REPLACEMENT._serialized_start = 8194
- _NAREPLACE_REPLACEMENT._serialized_end = 8335
- _TODF._serialized_start = 8337
- _TODF._serialized_end = 8425
- _WITHCOLUMNSRENAMED._serialized_start = 8428
- _WITHCOLUMNSRENAMED._serialized_end = 8667
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8600
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8667
- _WITHCOLUMNS._serialized_start = 8669
- _WITHCOLUMNS._serialized_end = 8788
- _HINT._serialized_start = 8791
- _HINT._serialized_end = 8923
- _UNPIVOT._serialized_start = 8926
- _UNPIVOT._serialized_end = 9253
- _UNPIVOT_VALUES._serialized_start = 9183
- _UNPIVOT_VALUES._serialized_end = 9242
- _TOSCHEMA._serialized_start = 9255
- _TOSCHEMA._serialized_end = 9361
- _REPARTITIONBYEXPRESSION._serialized_start = 9364
- _REPARTITIONBYEXPRESSION._serialized_end = 9567
- _FRAMEMAP._serialized_start = 9569
- _FRAMEMAP._serialized_end = 9694
+ _RELATION._serialized_end = 2592
+ _UNKNOWN._serialized_start = 2594
+ _UNKNOWN._serialized_end = 2603
+ _RELATIONCOMMON._serialized_start = 2605
+ _RELATIONCOMMON._serialized_end = 2696
+ _SQL._serialized_start = 2699
+ _SQL._serialized_end = 2833
+ _SQL_ARGSENTRY._serialized_start = 2778
+ _SQL_ARGSENTRY._serialized_end = 2833
+ _READ._serialized_start = 2836
+ _READ._serialized_end = 3300
+ _READ_NAMEDTABLE._serialized_start = 2978
+ _READ_NAMEDTABLE._serialized_end = 3039
+ _READ_DATASOURCE._serialized_start = 3042
+ _READ_DATASOURCE._serialized_end = 3287
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3207
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3265
+ _PROJECT._serialized_start = 3302
+ _PROJECT._serialized_end = 3419
+ _FILTER._serialized_start = 3421
+ _FILTER._serialized_end = 3533
+ _JOIN._serialized_start = 3536
+ _JOIN._serialized_end = 4007
+ _JOIN_JOINTYPE._serialized_start = 3799
+ _JOIN_JOINTYPE._serialized_end = 4007
+ _SETOPERATION._serialized_start = 4010
+ _SETOPERATION._serialized_end = 4489
+ _SETOPERATION_SETOPTYPE._serialized_start = 4326
+ _SETOPERATION_SETOPTYPE._serialized_end = 4440
+ _LIMIT._serialized_start = 4491
+ _LIMIT._serialized_end = 4567
+ _OFFSET._serialized_start = 4569
+ _OFFSET._serialized_end = 4648
+ _TAIL._serialized_start = 4650
+ _TAIL._serialized_end = 4725
+ _AGGREGATE._serialized_start = 4728
+ _AGGREGATE._serialized_end = 5310
+ _AGGREGATE_PIVOT._serialized_start = 5067
+ _AGGREGATE_PIVOT._serialized_end = 5178
+ _AGGREGATE_GROUPTYPE._serialized_start = 5181
+ _AGGREGATE_GROUPTYPE._serialized_end = 5310
+ _SORT._serialized_start = 5313
+ _SORT._serialized_end = 5473
+ _DROP._serialized_start = 5476
+ _DROP._serialized_end = 5617
+ _DEDUPLICATE._serialized_start = 5620
+ _DEDUPLICATE._serialized_end = 5791
+ _LOCALRELATION._serialized_start = 5793
+ _LOCALRELATION._serialized_end = 5882
+ _SAMPLE._serialized_start = 5885
+ _SAMPLE._serialized_end = 6158
+ _RANGE._serialized_start = 6161
+ _RANGE._serialized_end = 6306
+ _SUBQUERYALIAS._serialized_start = 6308
+ _SUBQUERYALIAS._serialized_end = 6422
+ _REPARTITION._serialized_start = 6425
+ _REPARTITION._serialized_end = 6567
+ _SHOWSTRING._serialized_start = 6570
+ _SHOWSTRING._serialized_end = 6712
+ _STATSUMMARY._serialized_start = 6714
+ _STATSUMMARY._serialized_end = 6806
+ _STATDESCRIBE._serialized_start = 6808
+ _STATDESCRIBE._serialized_end = 6889
+ _STATCROSSTAB._serialized_start = 6891
+ _STATCROSSTAB._serialized_end = 6992
+ _STATCOV._serialized_start = 6994
+ _STATCOV._serialized_end = 7090
+ _STATCORR._serialized_start = 7093
+ _STATCORR._serialized_end = 7230
+ _STATAPPROXQUANTILE._serialized_start = 7233
+ _STATAPPROXQUANTILE._serialized_end = 7397
+ _STATFREQITEMS._serialized_start = 7399
+ _STATFREQITEMS._serialized_end = 7524
+ _STATSAMPLEBY._serialized_start = 7527
+ _STATSAMPLEBY._serialized_end = 7836
+ _STATSAMPLEBY_FRACTION._serialized_start = 7728
+ _STATSAMPLEBY_FRACTION._serialized_end = 7827
+ _NAFILL._serialized_start = 7839
+ _NAFILL._serialized_end = 7973
+ _NADROP._serialized_start = 7976
+ _NADROP._serialized_end = 8110
+ _NAREPLACE._serialized_start = 8113
+ _NAREPLACE._serialized_end = 8409
+ _NAREPLACE_REPLACEMENT._serialized_start = 8268
+ _NAREPLACE_REPLACEMENT._serialized_end = 8409
+ _TODF._serialized_start = 8411
+ _TODF._serialized_end = 8499
+ _WITHCOLUMNSRENAMED._serialized_start = 8502
+ _WITHCOLUMNSRENAMED._serialized_end = 8741
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8674
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8741
+ _WITHCOLUMNS._serialized_start = 8743
+ _WITHCOLUMNS._serialized_end = 8862
+ _HINT._serialized_start = 8865
+ _HINT._serialized_end = 8997
+ _UNPIVOT._serialized_start = 9000
+ _UNPIVOT._serialized_end = 9327
+ _UNPIVOT_VALUES._serialized_start = 9257
+ _UNPIVOT_VALUES._serialized_end = 9316
+ _TOSCHEMA._serialized_start = 9329
+ _TOSCHEMA._serialized_end = 9435
+ _REPARTITIONBYEXPRESSION._serialized_start = 9438
+ _REPARTITIONBYEXPRESSION._serialized_end = 9641
+ _FRAMEMAP._serialized_start = 9643
+ _FRAMEMAP._serialized_end = 9768
+ _COLLECTMETRICS._serialized_start = 9771
+ _COLLECTMETRICS._serialized_end = 9907
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index b60fd5a1a61..e4d1653f0ba 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -90,6 +90,7 @@ class Relation(google.protobuf.message.Message):
TO_SCHEMA_FIELD_NUMBER: builtins.int
REPARTITION_BY_EXPRESSION_FIELD_NUMBER: builtins.int
FRAME_MAP_FIELD_NUMBER: builtins.int
+ COLLECT_METRICS_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -161,6 +162,8 @@ class Relation(google.protobuf.message.Message):
@property
def frame_map(self) -> global___FrameMap: ...
@property
+ def collect_metrics(self) -> global___CollectMetrics: ...
+ @property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -225,6 +228,7 @@ class Relation(google.protobuf.message.Message):
to_schema: global___ToSchema | None = ...,
repartition_by_expression: global___RepartitionByExpression | None = ...,
frame_map: global___FrameMap | None = ...,
+ collect_metrics: global___CollectMetrics | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -249,6 +253,8 @@ class Relation(google.protobuf.message.Message):
b"approx_quantile",
"catalog",
b"catalog",
+ "collect_metrics",
+ b"collect_metrics",
"common",
b"common",
"corr",
@@ -340,6 +346,8 @@ class Relation(google.protobuf.message.Message):
b"approx_quantile",
"catalog",
b"catalog",
+ "collect_metrics",
+ b"collect_metrics",
"common",
b"common",
"corr",
@@ -452,6 +460,7 @@ class Relation(google.protobuf.message.Message):
"to_schema",
"repartition_by_expression",
"frame_map",
+ "collect_metrics",
"fill_na",
"drop_na",
"replace",
@@ -2702,3 +2711,43 @@ class FrameMap(google.protobuf.message.Message):
) -> None: ...
global___FrameMap = FrameMap
+
+class CollectMetrics(google.protobuf.message.Message):
+ """Collect arbitrary (named) metrics from a dataset."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ INPUT_FIELD_NUMBER: builtins.int
+ NAME_FIELD_NUMBER: builtins.int
+ METRICS_FIELD_NUMBER: builtins.int
+ @property
+ def input(self) -> global___Relation:
+ """(Required) The input relation."""
+ name: builtins.str
+ """(Required) Name of the metrics."""
+ @property
+ def metrics(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]:
+ """(Required) The metric sequence."""
+ def __init__(
+ self,
+ *,
+ input: global___Relation | None = ...,
+ name: builtins.str = ...,
+ metrics: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
+ | None = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["input", b"input"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "input", b"input", "metrics", b"metrics", "name", b"name"
+ ],
+ ) -> None: ...
+
+global___CollectMetrics = CollectMetrics
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 70938459f79..68ac8b1dd7c 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -59,6 +59,7 @@ if should_test_connect:
import grpc
import pandas as pd
import numpy as np
+ from pyspark.sql.connect.proto import Expression as ProtoExpression
from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
from pyspark.sql.connect.client import ChannelBuilder
from pyspark.sql.connect.column import Column
@@ -1631,6 +1632,57 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.assert_eq(relations[i].toPandas(), datasets[i].toPandas())
i += 1
+ def test_observe(self):
+ # SPARK-41527: test DataFrame.observe()
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
+
+ observation_name = "my_metric"
+
+ self.assert_eq(
+ self.connect.read.table(self.tbl_name)
+ .filter("id > 3")
+ .observe(observation_name, CF.min("id"), CF.max("id"), CF.sum("id"))
+ .toPandas(),
+ self.spark.read.table(self.tbl_name)
+ .filter("id > 3")
+ .observe(observation_name, SF.min("id"), SF.max("id"), SF.sum("id"))
+ .toPandas(),
+ )
+
+ from pyspark.sql.observation import Observation
+
+ observation = Observation(observation_name)
+
+ cdf = (
+ self.connect.read.table(self.tbl_name)
+ .filter("id > 3")
+ .observe(observation, CF.min("id"), CF.max("id"), CF.sum("id"))
+ .toPandas()
+ )
+ df = (
+ self.spark.read.table(self.tbl_name)
+ .filter("id > 3")
+ .observe(observation, SF.min("id"), SF.max("id"), SF.sum("id"))
+ .toPandas()
+ )
+
+ self.assert_eq(cdf, df)
+
+ observed_metrics = cdf.attrs["observed_metrics"]
+ self.assert_eq(len(observed_metrics), 1)
+ self.assert_eq(observed_metrics[0].name, observation_name)
+ self.assert_eq(len(observed_metrics[0].metrics), 3)
+ for metric in observed_metrics[0].metrics:
+ self.assertIsInstance(metric, ProtoExpression.Literal)
+ values = list(map(lambda metric: metric.long, observed_metrics[0].metrics))
+ self.assert_eq(values, [4, 99, 4944])
+
+ with self.assertRaisesRegex(ValueError, "'exprs' should not be empty"):
+ self.connect.read.table(self.tbl_name).observe(observation_name)
+ with self.assertRaisesRegex(ValueError, "all 'exprs' should be Column"):
+ self.connect.read.table(self.tbl_name).observe(observation_name, CF.lit(1), "id")
+
def test_with_columns(self):
# SPARK-41256: test withColumn(s).
self.assert_eq(
@@ -2766,7 +2818,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
"cache",
"persist",
"withWatermark",
- "observe",
"foreach",
"foreachPartition",
"toLocalIterator",
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index 8c09b9cfaa5..115d189b742 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -31,7 +31,7 @@ if should_test_connect:
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import WriteOperation, Read
from pyspark.sql.connect.readwriter import DataFrameReader
- from pyspark.sql.connect.functions import col, lit
+ from pyspark.sql.connect.functions import col, lit, max, min, sum
from pyspark.sql.connect.types import pyspark_types_to_proto_types
from pyspark.sql.types import (
StringType,
@@ -294,6 +294,68 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
relations = df.filter(df.col_name > 3).randomSplit([1.0, 2.0, 3.0])
checkRelations(relations)
+ def test_observe(self):
+ # SPARK-41527: test DataFrame.observe()
+ df = self.connect.readTable(table_name=self.tbl_name)
+
+ plan = (
+ df.filter(df.col_name > 3)
+ .observe("my_metric", min("id"), max("id"), sum("id"))
+ ._plan.to_proto(self.connect)
+ )
+ self.assertEqual(plan.root.collect_metrics.name, "my_metric")
+ self.assertTrue(
+ all(isinstance(c, proto.Expression) for c in plan.root.collect_metrics.metrics)
+ )
+ self.assertEqual(
+ plan.root.collect_metrics.metrics[0].unresolved_function.function_name, "min"
+ )
+ self.assertTrue(
+ len(plan.root.collect_metrics.metrics[0].unresolved_function.arguments) == 1
+ )
+ self.assertTrue(
+ all(
+ isinstance(c, proto.Expression)
+ for c in plan.root.collect_metrics.metrics[0].unresolved_function.arguments
+ )
+ )
+ self.assertEqual(
+ plan.root.collect_metrics.metrics[0]
+ .unresolved_function.arguments[0]
+ .unresolved_attribute.unparsed_identifier,
+ "id",
+ )
+
+ from pyspark.sql.observation import Observation
+
+ plan = (
+ df.filter(df.col_name > 3)
+ .observe(Observation("my_metric"), min("id"), max("id"), sum("id"))
+ ._plan.to_proto(self.connect)
+ )
+ self.assertEqual(plan.root.collect_metrics.name, "my_metric")
+ self.assertTrue(
+ all(isinstance(c, proto.Expression) for c in plan.root.collect_metrics.metrics)
+ )
+ self.assertEqual(
+ plan.root.collect_metrics.metrics[0].unresolved_function.function_name, "min"
+ )
+ self.assertTrue(
+ len(plan.root.collect_metrics.metrics[0].unresolved_function.arguments) == 1
+ )
+ self.assertTrue(
+ all(
+ isinstance(c, proto.Expression)
+ for c in plan.root.collect_metrics.metrics[0].unresolved_function.arguments
+ )
+ )
+ self.assertEqual(
+ plan.root.collect_metrics.metrics[0]
+ .unresolved_function.arguments[0]
+ .unresolved_attribute.unparsed_identifier,
+ "id",
+ )
+
def test_summary(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.filter(df.col_name > 3).summary()._plan.to_proto(self.connect)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
index a2750dd81e9..df2c99dde5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.util.QueryExecutionListener
* @param name name of the metric
* @since 3.3.0
*/
-class Observation(name: String) {
+class Observation(val name: String) {
if (name.isEmpty) throw new IllegalArgumentException("Name must not be empty")