Posted to commits@spark.apache.org by hv...@apache.org on 2023/03/06 02:19:18 UTC

[spark] branch branch-3.4 updated: [SPARK-41527][CONNECT][PYTHON] Implement `DataFrame.observe`

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 7d8ae50f099 [SPARK-41527][CONNECT][PYTHON] Implement `DataFrame.observe`
7d8ae50f099 is described below

commit 7d8ae50f099366ccf9ecbac12fd4d24da85bf6dc
Author: Jiaan Geng <be...@163.com>
AuthorDate: Sun Mar 5 22:18:09 2023 -0400

    [SPARK-41527][CONNECT][PYTHON] Implement `DataFrame.observe`
    
    ### What changes were proposed in this pull request?
    Implement `DataFrame.observe` with a new proto message.
    
    Implement `DataFrame.observe` for the Scala API.
    Implement `DataFrame.observe` for the Python API.
    
    ### Why are the changes needed?
    For Connect API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'. This adds a new API.
    
    ### How was this patch tested?
    New test cases.
    
    Closes #39091 from beliefer/SPARK-41527.
    
    Authored-by: Jiaan Geng <be...@163.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
    (cherry picked from commit 0ce63f36384e135ebfb21b927f6cc69361919b47)
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../src/main/protobuf/spark/connect/base.proto     |   9 +
 .../main/protobuf/spark/connect/relations.proto    |  12 ++
 .../org/apache/spark/sql/connect/dsl/package.scala |  40 +++-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  12 +-
 .../service/SparkConnectStreamHandler.scala        |  25 ++-
 .../connect/planner/SparkConnectProtoSuite.scala   |  63 +++++-
 .../connect/planner/SparkConnectServiceSuite.scala |  69 +++++++
 python/pyspark/sql/connect/client.py               |  47 ++++-
 python/pyspark/sql/connect/dataframe.py            |  32 ++-
 python/pyspark/sql/connect/plan.py                 |  29 +++
 python/pyspark/sql/connect/proto/base_pb2.py       | 208 ++++++++++---------
 python/pyspark/sql/connect/proto/base_pb2.pyi      |  38 ++++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 226 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  49 +++++
 .../sql/tests/connect/test_connect_basic.py        |  53 ++++-
 .../pyspark/sql/tests/connect/test_connect_plan.py |  64 +++++-
 .../scala/org/apache/spark/sql/Observation.scala   |   2 +-
 17 files changed, 759 insertions(+), 219 deletions(-)
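
A minimal usage sketch of the API this commit adds (illustrative only, not part of the change itself; it assumes an active Spark Connect session bound to `spark`):

    from pyspark.sql import Observation
    from pyspark.sql.functions import count, lit, max, min

    df = spark.range(10)

    # Name-based form: metric values come back with the ExecutePlanResponse
    # under the given name (see ObservedMetrics in base.proto below).
    observed = df.observe("my_metric", count(lit(1)).alias("cnt"), max(df.id).alias("max_id"))
    observed.collect()

    # Observation-based form mirrors the existing PySpark signature.
    obs = Observation("other_metric")
    df.observe(obs, min(df.id).alias("min_id")).collect()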

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 066a63d58ba..09407a99119 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -21,6 +21,7 @@ package spark.connect;
 
 import "google/protobuf/any.proto";
 import "spark/connect/commands.proto";
+import "spark/connect/expressions.proto";
 import "spark/connect/relations.proto";
 import "spark/connect/types.proto";
 
@@ -236,6 +237,9 @@ message ExecutePlanResponse {
   // batch of results and then represent the overall state of the query execution.
   Metrics metrics = 4;
 
+  // The metrics observed during the execution of the query plan.
+  repeated ObservedMetrics observed_metrics = 6;
+
   // A SQL command returns an opaque Relation that can be directly used as input for the next
   // call.
   message SqlCommandResult {
@@ -265,6 +269,11 @@ message ExecutePlanResponse {
       string metric_type = 3;
     }
   }
+
+  message ObservedMetrics {
+    string name = 1;
+    repeated Expression.Literal values = 2;
+  }
 }
 
 // The key-value pair for the config request and response.
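
On the client, a response carrying these metrics can be unpacked roughly like this (a sketch only; the real handling is the `_execute_and_fetch` change in python/pyspark/sql/connect/client.py further down, and `responses` here is an assumed iterator of ExecutePlanResponse messages):

    for response in responses:
        for om in response.observed_metrics:
            # om.values holds Expression.Literal messages, one per metric expression.
            print(om.name, list(om.values))
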
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 3e4a4daeb36..2230aada4fe 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -61,6 +61,7 @@ message Relation {
     ToSchema to_schema = 26;
     RepartitionByExpression repartition_by_expression = 27;
     FrameMap frame_map = 28;
+    CollectMetrics collect_metrics = 29;
 
     // NA functions
     NAFill fill_na = 90;
@@ -781,3 +782,14 @@ message FrameMap {
   CommonInlineUserDefinedFunction func = 2;
 }
 
+// Collect arbitrary (named) metrics from a dataset.
+message CollectMetrics {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Required) Name of the metrics.
+  string name = 2;
+
+  // (Required) The metric sequence.
+  repeated Expression metrics = 3;
+}
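
A rough sketch of how a client fills in this message (the canonical construction is the `CollectMetrics` plan added to python/pyspark/sql/connect/plan.py below; `child_relation`, `min_expr`, and `max_expr` are assumed to already exist):

    from pyspark.sql.connect.proto import relations_pb2

    rel = relations_pb2.Relation()
    rel.collect_metrics.input.CopyFrom(child_relation)        # input Relation, assumed
    rel.collect_metrics.name = "my_metric"
    rel.collect_metrics.metrics.extend([min_expr, max_expr])  # proto Expression messages, assumed
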
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 840f43abf49..7e60c5f9a28 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -24,7 +24,7 @@ import org.apache.spark.connect.proto._
 import org.apache.spark.connect.proto.Expression.ExpressionString
 import org.apache.spark.connect.proto.Join.JoinType
 import org.apache.spark.connect.proto.SetOperation.SetOpType
-import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.{Observation, SaveMode}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.connect.planner.{SaveModeConverter, TableSaveMethodConverter}
 import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
@@ -141,6 +141,20 @@ package object dsl {
           Expression.UnresolvedFunction.newBuilder().setFunctionName("min").addArguments(e))
         .build()
 
+    def proto_max(e: Expression): Expression =
+      Expression
+        .newBuilder()
+        .setUnresolvedFunction(
+          Expression.UnresolvedFunction.newBuilder().setFunctionName("max").addArguments(e))
+        .build()
+
+    def proto_sum(e: Expression): Expression =
+      Expression
+        .newBuilder()
+        .setUnresolvedFunction(
+          Expression.UnresolvedFunction.newBuilder().setFunctionName("sum").addArguments(e))
+        .build()
+
     def proto_explode(e: Expression): Expression =
       Expression
         .newBuilder()
@@ -1061,6 +1075,30 @@ package object dsl {
       def randomSplit(weights: Array[Double]): Array[Relation] =
         randomSplit(weights, Utils.random.nextLong)
 
+      def observe(name: String, expr: Expression, exprs: Expression*): Relation = {
+        Relation
+          .newBuilder()
+          .setCollectMetrics(
+            CollectMetrics
+              .newBuilder()
+              .setInput(logicalPlan)
+              .setName(name)
+              .addAllMetrics((expr +: exprs).asJava))
+          .build()
+      }
+
+      def observe(observation: Observation, expr: Expression, exprs: Expression*): Relation = {
+        Relation
+          .newBuilder()
+          .setCollectMetrics(
+            CollectMetrics
+              .newBuilder()
+              .setInput(logicalPlan)
+              .setName(observation.name)
+              .addAllMetrics((expr +: exprs).asJava))
+          .build()
+      }
+
       private def createSetOperation(
           left: Relation,
           right: Relation,
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 76a4c7faaa2..60fb94e8098 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, UdfPacket}
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
@@ -113,6 +113,8 @@ class SparkConnectPlanner(val session: SparkSession) {
         transformRepartitionByExpression(rel.getRepartitionByExpression)
       case proto.Relation.RelTypeCase.FRAME_MAP =>
         transformFrameMap(rel.getFrameMap)
+      case proto.Relation.RelTypeCase.COLLECT_METRICS =>
+        transformCollectMetrics(rel.getCollectMetrics)
       case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
         throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
 
@@ -571,6 +573,14 @@ class SparkConnectPlanner(val session: SparkSession) {
       numPartitionsOpt)
   }
 
+  private def transformCollectMetrics(rel: proto.CollectMetrics): LogicalPlan = {
+    val metrics = rel.getMetricsList.asScala.toSeq.map { expr =>
+      Column(transformExpression(expr))
+    }
+
+    CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput))
+  }
+
   private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
     if (!rel.hasInput) {
       throw InvalidPlanInput("Deduplicate needs a plan input")
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 41ca564e6d3..15d5a981ae8 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
+import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.service.SparkConnectStreamHandler.processAsArrowBatches
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
@@ -62,6 +63,10 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
     processAsArrowBatches(request.getClientId, dataframe, responseObserver)
     responseObserver.onNext(
       SparkConnectStreamHandler.sendMetricsToResponse(request.getClientId, dataframe))
+    if (dataframe.queryExecution.observedMetrics.nonEmpty) {
+      responseObserver.onNext(
+        SparkConnectStreamHandler.sendObservedMetricsToResponse(request.getClientId, dataframe))
+    }
     responseObserver.onCompleted()
   }
 
@@ -206,6 +211,25 @@ object SparkConnectStreamHandler {
       .setMetrics(MetricGenerator.buildMetrics(rows.queryExecution.executedPlan))
       .build()
   }
+
+  def sendObservedMetricsToResponse(
+      clientId: String,
+      dataframe: DataFrame): ExecutePlanResponse = {
+    val observedMetrics = dataframe.queryExecution.observedMetrics.map { case (name, row) =>
+      val cols = (0 until row.length).map(i => toConnectProtoValue(row(i)))
+      ExecutePlanResponse.ObservedMetrics
+        .newBuilder()
+        .setName(name)
+        .addAllValues(cols.asJava)
+        .build()
+    }
+    // Prepare a response with the observed metrics.
+    ExecutePlanResponse
+      .newBuilder()
+      .setClientId(clientId)
+      .addAllObservedMetrics(observedMetrics.asJava)
+      .build()
+  }
 }
 
 object MetricGenerator extends AdaptiveSparkPlanHelper {
@@ -242,5 +266,4 @@ object MetricGenerator extends AdaptiveSparkPlanHelper {
       .build()
     Seq(mo) ++ transformChildren(p)
   }
-
 }
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 7fb28f354ce..241b3dcb825 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.SparkClassNotFoundException
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Expression
 import org.apache.spark.connect.proto.Join.JoinType
-import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Observation, Row, SaveMode}
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, UnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
@@ -940,6 +940,67 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan1, sparkPlan1)
   }
 
+  test("Test observe") {
+    val connectPlan0 =
+      connectTestRelation.observe(
+        "my_metric",
+        proto_min("id".protoAttr).as("min_val"),
+        proto_max("id".protoAttr).as("max_val"),
+        proto_sum("id".protoAttr))
+    val sparkPlan0 =
+      sparkTestRelation.observe(
+        "my_metric",
+        min(Column("id")).as("min_val"),
+        max(Column("id")).as("max_val"),
+        sum(Column("id")))
+    comparePlans(connectPlan0, sparkPlan0)
+
+    val connectPlan1 =
+      connectTestRelation.observe("my_metric", proto_min("id".protoAttr).as("min_val"))
+    val sparkPlan1 =
+      sparkTestRelation.observe("my_metric", min(Column("id")).as("min_val"))
+    comparePlans(connectPlan1, sparkPlan1)
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        analyzePlan(
+          transform(connectTestRelation.observe("my_metric", "id".protoAttr.cast("string"))))
+      },
+      errorClass = "_LEGACY_ERROR_TEMP_2322",
+      parameters = Map("sqlExpr" -> "CAST(id AS STRING) AS id"))
+
+    val connectPlan2 =
+      connectTestRelation.observe(
+        Observation("my_metric"),
+        proto_min("id".protoAttr).as("min_val"),
+        proto_max("id".protoAttr).as("max_val"),
+        proto_sum("id".protoAttr))
+    val sparkPlan2 =
+      sparkTestRelation.observe(
+        Observation("my_metric"),
+        min(Column("id")).as("min_val"),
+        max(Column("id")).as("max_val"),
+        sum(Column("id")))
+    comparePlans(connectPlan2, sparkPlan2)
+
+    val connectPlan3 =
+      connectTestRelation.observe(
+        Observation("my_metric"),
+        proto_min("id".protoAttr).as("min_val"))
+    val sparkPlan3 =
+      sparkTestRelation.observe(Observation("my_metric"), min(Column("id")).as("min_val"))
+    comparePlans(connectPlan3, sparkPlan3)
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        analyzePlan(
+          transform(
+            connectTestRelation.observe(Observation("my_metric"), "id".protoAttr.cast("string"))))
+      },
+      errorClass = "_LEGACY_ERROR_TEMP_2322",
+      parameters = Map("sqlExpr" -> "CAST(id AS STRING) AS id"))
+  }
+
   test("Test RandomSplit") {
     val splitRelations0 = connectTestRelation.randomSplit(Array[Double](1, 2, 3), 1)
     val splits0 = sparkTestRelation.randomSplit(Array[Double](1, 2, 3), 1)
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 4d9bea88f5c..2885d0035bc 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.connect.planner
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 import io.grpc.StatusRuntimeException
@@ -26,6 +27,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader
 
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.connect.dsl.MockRemoteSession
+import org.apache.spark.sql.connect.dsl.expressions._
 import org.apache.spark.sql.connect.dsl.plans._
 import org.apache.spark.sql.connect.service.{SparkConnectAnalyzeHandler, SparkConnectService}
 import org.apache.spark.sql.test.SharedSparkSession
@@ -293,4 +295,71 @@ class SparkConnectServiceSuite extends SharedSparkSession {
       assert(response.getExplain.getExplainString.contains("Physical Plan"))
     }
   }
+
+  test("Test observe response") {
+    withTable("test") {
+      spark.sql("""
+                  | CREATE TABLE test (col1 INT, col2 STRING)
+                  | USING parquet
+                  |""".stripMargin)
+
+      val instance = new SparkConnectService(false)
+
+      val connect = new MockRemoteSession()
+      val context = proto.UserContext
+        .newBuilder()
+        .setUserId("c1")
+        .build()
+      val collectMetrics = proto.Relation
+        .newBuilder()
+        .setCollectMetrics(
+          proto.CollectMetrics
+            .newBuilder()
+            .setInput(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)"))
+            .setName("my_metric")
+            .addAllMetrics(Seq(
+              proto_min("id".protoAttr).as("min_val"),
+              proto_max("id".protoAttr).as("max_val")).asJava))
+        .build()
+      val plan = proto.Plan
+        .newBuilder()
+        .setRoot(collectMetrics)
+        .build()
+      val request = proto.ExecutePlanRequest
+        .newBuilder()
+        .setPlan(plan)
+        .setUserContext(context)
+        .build()
+
+      // Execute plan.
+      @volatile var done = false
+      val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+      instance.executePlan(
+        request,
+        new StreamObserver[proto.ExecutePlanResponse] {
+          override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v
+
+          override def onError(throwable: Throwable): Unit = throw throwable
+
+          override def onCompleted(): Unit = done = true
+        })
+
+      // The current implementation is expected to be blocking. This is here to make sure it is.
+      assert(done)
+
+      assert(responses.size == 6)
+
+      // Make sure the last response is observed metrics only
+      val last = responses.last
+      assert(last.getObservedMetricsCount == 1 && !last.hasArrowBatch)
+
+      val observedMetricsList = last.getObservedMetricsList.asScala
+      val observedMetric = observedMetricsList.head
+      assert(observedMetric.getName == "my_metric")
+      assert(observedMetric.getValuesCount == 2)
+      val valuesList = observedMetric.getValuesList.asScala
+      assert(valuesList.head.hasLong && valuesList.head.getLong == 0)
+      assert(valuesList.last.hasLong && valuesList.last.getLong == 99)
+    }
+  }
 }
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 3f70ca6ad15..8be0c997538 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -373,6 +373,23 @@ class PlanMetrics:
         return self._metrics
 
 
+class PlanObservedMetrics:
+    def __init__(self, name: str, metrics: List[pb2.Expression.Literal]):
+        self._name = name
+        self._metrics = metrics
+
+    def __repr__(self) -> str:
+        return f"Plan observed({self._name}={self._metrics})"
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def metrics(self) -> List[pb2.Expression.Literal]:
+        return self._metrics
+
+
 class AnalyzeResult:
     def __init__(
         self,
@@ -561,6 +578,17 @@ class SparkConnectClient(object):
             for x in metrics.metrics
         ]
 
+    def _build_observed_metrics(
+        self, metrics: List["pb2.ExecutePlanResponse.ObservedMetrics"]
+    ) -> List[PlanObservedMetrics]:
+        return [
+            PlanObservedMetrics(
+                x.name,
+                [v for v in x.values],
+            )
+            for x in metrics
+        ]
+
     def to_table(self, plan: pb2.Plan) -> "pa.Table":
         """
         Return given plan as a PyArrow Table.
@@ -568,7 +596,7 @@ class SparkConnectClient(object):
         logger.info(f"Executing plan {self._proto_to_string(plan)}")
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
-        table, _, _2 = self._execute_and_fetch(req)
+        table, _, _, _3 = self._execute_and_fetch(req)
         assert table is not None
         return table
 
@@ -579,7 +607,7 @@ class SparkConnectClient(object):
         logger.info(f"Executing plan {self._proto_to_string(plan)}")
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
-        table, metrics, _ = self._execute_and_fetch(req)
+        table, metrics, observed_metrics, _ = self._execute_and_fetch(req)
         assert table is not None
         column_names = table.column_names
         table = table.rename_columns([f"col_{i}" for i in range(len(column_names))])
@@ -587,6 +615,8 @@ class SparkConnectClient(object):
         pdf.columns = column_names
         if len(metrics) > 0:
             pdf.attrs["metrics"] = metrics
+        if len(observed_metrics) > 0:
+            pdf.attrs["observed_metrics"] = observed_metrics
         return pdf
 
     def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
@@ -654,7 +684,7 @@ class SparkConnectClient(object):
         if self._user_id:
             req.user_context.user_id = self._user_id
         req.plan.command.CopyFrom(command)
-        data, _, properties = self._execute_and_fetch(req)
+        data, _, _, properties = self._execute_and_fetch(req)
         if data is not None:
             return (data.to_pandas(), properties)
         else:
@@ -781,10 +811,11 @@ class SparkConnectClient(object):
 
     def _execute_and_fetch(
         self, req: pb2.ExecutePlanRequest
-    ) -> Tuple[Optional["pa.Table"], List[PlanMetrics], Dict[str, Any]]:
+    ) -> Tuple[Optional["pa.Table"], List[PlanMetrics], List[PlanObservedMetrics], Dict[str, Any]]:
         logger.info("ExecuteAndFetch")
 
         m: Optional[pb2.ExecutePlanResponse.Metrics] = None
+        om: List[pb2.ExecutePlanResponse.ObservedMetrics] = []
         batches: List[pa.RecordBatch] = []
         properties = {}
         try:
@@ -802,6 +833,9 @@ class SparkConnectClient(object):
                         if b.metrics is not None:
                             logger.debug("Received metric batch.")
                             m = b.metrics
+                        if b.observed_metrics is not None:
+                            logger.debug("Received observed metric batch.")
+                            om.extend(b.observed_metrics)
                         if b.HasField("sql_command_result"):
                             properties["sql_command_result"] = b.sql_command_result.relation
                         if b.HasField("arrow_batch"):
@@ -817,12 +851,13 @@ class SparkConnectClient(object):
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
         metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None else []
+        observed_metrics: List[PlanObservedMetrics] = self._build_observed_metrics(om)
 
         if len(batches) > 0:
             table = pa.Table.from_batches(batches=batches)
-            return table, metrics, properties
+            return table, metrics, observed_metrics, properties
         else:
-            return None, metrics, properties
+            return None, metrics, observed_metrics, properties
 
     def _config_request_with_metadata(self) -> pb2.ConfigRequest:
         req = pb2.ConfigRequest()
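
With these client changes, observed metrics ride along on the pandas conversion, e.g. (sketch, reusing `observed` from the earlier usage sketch):

    pdf = observed.toPandas()
    # Present only when the executed plan contained a CollectMetrics relation.
    print(pdf.attrs.get("observed_metrics", []))  # list of PlanObservedMetrics
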
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 471dbf89582..fdc71d46632 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -42,6 +42,7 @@ from collections.abc import Iterable
 
 from pyspark import _NoValue
 from pyspark._globals import _NoValueType
+from pyspark.sql.observation import Observation
 from pyspark.sql.types import Row, StructType
 from pyspark.sql.dataframe import (
     DataFrame as PySparkDataFrame,
@@ -783,6 +784,31 @@ class DataFrame:
 
     randomSplit.__doc__ = PySparkDataFrame.randomSplit.__doc__
 
+    def observe(
+        self,
+        observation: Union["Observation", str],
+        *exprs: Column,
+    ) -> "DataFrame":
+        if len(exprs) == 0:
+            raise ValueError("'exprs' should not be empty")
+        if not all(isinstance(c, Column) for c in exprs):
+            raise ValueError("all 'exprs' should be Column")
+
+        if isinstance(observation, Observation):
+            return DataFrame.withPlan(
+                plan.CollectMetrics(self._plan, str(observation._name), list(exprs)),
+                self._session,
+            )
+        elif isinstance(observation, str):
+            return DataFrame.withPlan(
+                plan.CollectMetrics(self._plan, observation, list(exprs)),
+                self._session,
+            )
+        else:
+            raise ValueError("'observation' should be either `Observation` or `str`.")
+
+    observe.__doc__ = PySparkDataFrame.observe.__doc__
+
     def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
         print(self._show_string(n, truncate, vertical))
 
@@ -1523,9 +1549,6 @@ class DataFrame:
     def withWatermark(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("withWatermark() is not implemented.")
 
-    def observe(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("observe() is not implemented.")
-
     def foreach(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("foreach() is not implemented.")
 
@@ -1730,6 +1753,9 @@ def _test() -> None:
     # TODO(SPARK-41625): Support Structured Streaming
     del pyspark.sql.connect.dataframe.DataFrame.isStreaming.__doc__
 
+    # TODO(SPARK-41888): Support StreamingQueryListener for DataFrame.observe
+    del pyspark.sql.connect.dataframe.DataFrame.observe.__doc__
+
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.dataframe tests")
         .remote("local[4]")
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 7e767885793..8e5a96b974d 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1050,6 +1050,35 @@ class Unpivot(LogicalPlan):
         return plan
 
 
+class CollectMetrics(LogicalPlan):
+    """Logical plan object for a CollectMetrics operation."""
+
+    def __init__(
+        self,
+        child: Optional["LogicalPlan"],
+        name: str,
+        exprs: List["ColumnOrName"],
+    ) -> None:
+        super().__init__(child)
+        self._name = name
+        self._exprs = exprs
+
+    def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") -> proto.Expression:
+        if isinstance(col, Column):
+            return col.to_plan(session)
+        else:
+            return self.unresolved_attr(col)
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+
+        plan = proto.Relation()
+        plan.collect_metrics.input.CopyFrom(self._child.plan(session))
+        plan.collect_metrics.name = self._name
+        plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for x in self._exprs])
+        return plan
+
+
 class NAFill(LogicalPlan):
     def __init__(
         self, child: Optional["LogicalPlan"], cols: Optional[List[str]], values: List[Any]
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 628f7ebdd46..9ece82ed535 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -31,12 +31,13 @@ _sym_db = _symbol_database.Default()
 
 from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
 from pyspark.sql.connect.proto import commands_pb2 as spark_dot_connect_dot_commands__pb2
+from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2
 from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2
 from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x0 [...]
+    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06 [...]
 )
 
 
@@ -76,6 +77,7 @@ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY = (
 _EXECUTEPLANRESPONSE_METRICS_METRICVALUE = _EXECUTEPLANRESPONSE_METRICS.nested_types_by_name[
     "MetricValue"
 ]
+_EXECUTEPLANRESPONSE_OBSERVEDMETRICS = _EXECUTEPLANRESPONSE.nested_types_by_name["ObservedMetrics"]
 _KEYVALUE = DESCRIPTOR.message_types_by_name["KeyValue"]
 _CONFIGREQUEST = DESCRIPTOR.message_types_by_name["ConfigRequest"]
 _CONFIGREQUEST_OPERATION = _CONFIGREQUEST.nested_types_by_name["Operation"]
@@ -376,6 +378,15 @@ ExecutePlanResponse = _reflection.GeneratedProtocolMessageType(
                 # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.Metrics)
             },
         ),
+        "ObservedMetrics": _reflection.GeneratedProtocolMessageType(
+            "ObservedMetrics",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _EXECUTEPLANRESPONSE_OBSERVEDMETRICS,
+                "__module__": "spark.connect.base_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.ObservedMetrics)
+            },
+        ),
         "DESCRIPTOR": _EXECUTEPLANRESPONSE,
         "__module__": "spark.connect.base_pb2"
         # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse)
@@ -388,6 +399,7 @@ _sym_db.RegisterMessage(ExecutePlanResponse.Metrics)
 _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject)
 _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntry)
 _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricValue)
+_sym_db.RegisterMessage(ExecutePlanResponse.ObservedMetrics)
 
 KeyValue = _reflection.GeneratedProtocolMessageType(
     "KeyValue",
@@ -581,100 +593,102 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None
     _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001"
-    _PLAN._serialized_start = 158
-    _PLAN._serialized_end = 274
-    _USERCONTEXT._serialized_start = 276
-    _USERCONTEXT._serialized_end = 398
-    _ANALYZEPLANREQUEST._serialized_start = 401
-    _ANALYZEPLANREQUEST._serialized_end = 1843
-    _ANALYZEPLANREQUEST_SCHEMA._serialized_start = 1172
-    _ANALYZEPLANREQUEST_SCHEMA._serialized_end = 1221
-    _ANALYZEPLANREQUEST_EXPLAIN._serialized_start = 1224
-    _ANALYZEPLANREQUEST_EXPLAIN._serialized_end = 1539
-    _ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE._serialized_start = 1367
-    _ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE._serialized_end = 1539
-    _ANALYZEPLANREQUEST_TREESTRING._serialized_start = 1541
-    _ANALYZEPLANREQUEST_TREESTRING._serialized_end = 1594
-    _ANALYZEPLANREQUEST_ISLOCAL._serialized_start = 1596
-    _ANALYZEPLANREQUEST_ISLOCAL._serialized_end = 1646
-    _ANALYZEPLANREQUEST_ISSTREAMING._serialized_start = 1648
-    _ANALYZEPLANREQUEST_ISSTREAMING._serialized_end = 1702
-    _ANALYZEPLANREQUEST_INPUTFILES._serialized_start = 1704
-    _ANALYZEPLANREQUEST_INPUTFILES._serialized_end = 1757
-    _ANALYZEPLANREQUEST_SPARKVERSION._serialized_start = 1759
-    _ANALYZEPLANREQUEST_SPARKVERSION._serialized_end = 1773
-    _ANALYZEPLANREQUEST_DDLPARSE._serialized_start = 1775
-    _ANALYZEPLANREQUEST_DDLPARSE._serialized_end = 1816
-    _ANALYZEPLANRESPONSE._serialized_start = 1846
-    _ANALYZEPLANRESPONSE._serialized_end = 2916
-    _ANALYZEPLANRESPONSE_SCHEMA._serialized_start = 2525
-    _ANALYZEPLANRESPONSE_SCHEMA._serialized_end = 2582
-    _ANALYZEPLANRESPONSE_EXPLAIN._serialized_start = 2584
-    _ANALYZEPLANRESPONSE_EXPLAIN._serialized_end = 2632
-    _ANALYZEPLANRESPONSE_TREESTRING._serialized_start = 2634
-    _ANALYZEPLANRESPONSE_TREESTRING._serialized_end = 2679
-    _ANALYZEPLANRESPONSE_ISLOCAL._serialized_start = 2681
-    _ANALYZEPLANRESPONSE_ISLOCAL._serialized_end = 2717
-    _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_start = 2719
-    _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_end = 2767
-    _ANALYZEPLANRESPONSE_INPUTFILES._serialized_start = 2769
-    _ANALYZEPLANRESPONSE_INPUTFILES._serialized_end = 2803
-    _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_start = 2805
-    _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_end = 2845
-    _ANALYZEPLANRESPONSE_DDLPARSE._serialized_start = 2847
-    _ANALYZEPLANRESPONSE_DDLPARSE._serialized_end = 2906
-    _EXECUTEPLANREQUEST._serialized_start = 2919
-    _EXECUTEPLANREQUEST._serialized_end = 3126
-    _EXECUTEPLANRESPONSE._serialized_start = 3129
-    _EXECUTEPLANRESPONSE._serialized_end = 4160
-    _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 3489
-    _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 3560
-    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 3562
-    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 3623
-    _EXECUTEPLANRESPONSE_METRICS._serialized_start = 3626
-    _EXECUTEPLANRESPONSE_METRICS._serialized_end = 4143
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 3721
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 4053
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 3930
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 4053
-    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 4055
-    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 4143
-    _KEYVALUE._serialized_start = 4162
-    _KEYVALUE._serialized_end = 4227
-    _CONFIGREQUEST._serialized_start = 4230
-    _CONFIGREQUEST._serialized_end = 5256
-    _CONFIGREQUEST_OPERATION._serialized_start = 4448
-    _CONFIGREQUEST_OPERATION._serialized_end = 4946
-    _CONFIGREQUEST_SET._serialized_start = 4948
-    _CONFIGREQUEST_SET._serialized_end = 5000
-    _CONFIGREQUEST_GET._serialized_start = 5002
-    _CONFIGREQUEST_GET._serialized_end = 5027
-    _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 5029
-    _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 5092
-    _CONFIGREQUEST_GETOPTION._serialized_start = 5094
-    _CONFIGREQUEST_GETOPTION._serialized_end = 5125
-    _CONFIGREQUEST_GETALL._serialized_start = 5127
-    _CONFIGREQUEST_GETALL._serialized_end = 5175
-    _CONFIGREQUEST_UNSET._serialized_start = 5177
-    _CONFIGREQUEST_UNSET._serialized_end = 5204
-    _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 5206
-    _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 5240
-    _CONFIGRESPONSE._serialized_start = 5258
-    _CONFIGRESPONSE._serialized_end = 5378
-    _ADDARTIFACTSREQUEST._serialized_start = 5381
-    _ADDARTIFACTSREQUEST._serialized_end = 6196
-    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 5728
-    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 5781
-    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 5783
-    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 5894
-    _ADDARTIFACTSREQUEST_BATCH._serialized_start = 5896
-    _ADDARTIFACTSREQUEST_BATCH._serialized_end = 5989
-    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 5992
-    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 6185
-    _ADDARTIFACTSRESPONSE._serialized_start = 6199
-    _ADDARTIFACTSRESPONSE._serialized_end = 6387
-    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 6306
-    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 6387
-    _SPARKCONNECTSERVICE._serialized_start = 6390
-    _SPARKCONNECTSERVICE._serialized_end = 6755
+    _PLAN._serialized_start = 191
+    _PLAN._serialized_end = 307
+    _USERCONTEXT._serialized_start = 309
+    _USERCONTEXT._serialized_end = 431
+    _ANALYZEPLANREQUEST._serialized_start = 434
+    _ANALYZEPLANREQUEST._serialized_end = 1876
+    _ANALYZEPLANREQUEST_SCHEMA._serialized_start = 1205
+    _ANALYZEPLANREQUEST_SCHEMA._serialized_end = 1254
+    _ANALYZEPLANREQUEST_EXPLAIN._serialized_start = 1257
+    _ANALYZEPLANREQUEST_EXPLAIN._serialized_end = 1572
+    _ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE._serialized_start = 1400
+    _ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE._serialized_end = 1572
+    _ANALYZEPLANREQUEST_TREESTRING._serialized_start = 1574
+    _ANALYZEPLANREQUEST_TREESTRING._serialized_end = 1627
+    _ANALYZEPLANREQUEST_ISLOCAL._serialized_start = 1629
+    _ANALYZEPLANREQUEST_ISLOCAL._serialized_end = 1679
+    _ANALYZEPLANREQUEST_ISSTREAMING._serialized_start = 1681
+    _ANALYZEPLANREQUEST_ISSTREAMING._serialized_end = 1735
+    _ANALYZEPLANREQUEST_INPUTFILES._serialized_start = 1737
+    _ANALYZEPLANREQUEST_INPUTFILES._serialized_end = 1790
+    _ANALYZEPLANREQUEST_SPARKVERSION._serialized_start = 1792
+    _ANALYZEPLANREQUEST_SPARKVERSION._serialized_end = 1806
+    _ANALYZEPLANREQUEST_DDLPARSE._serialized_start = 1808
+    _ANALYZEPLANREQUEST_DDLPARSE._serialized_end = 1849
+    _ANALYZEPLANRESPONSE._serialized_start = 1879
+    _ANALYZEPLANRESPONSE._serialized_end = 2949
+    _ANALYZEPLANRESPONSE_SCHEMA._serialized_start = 2558
+    _ANALYZEPLANRESPONSE_SCHEMA._serialized_end = 2615
+    _ANALYZEPLANRESPONSE_EXPLAIN._serialized_start = 2617
+    _ANALYZEPLANRESPONSE_EXPLAIN._serialized_end = 2665
+    _ANALYZEPLANRESPONSE_TREESTRING._serialized_start = 2667
+    _ANALYZEPLANRESPONSE_TREESTRING._serialized_end = 2712
+    _ANALYZEPLANRESPONSE_ISLOCAL._serialized_start = 2714
+    _ANALYZEPLANRESPONSE_ISLOCAL._serialized_end = 2750
+    _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_start = 2752
+    _ANALYZEPLANRESPONSE_ISSTREAMING._serialized_end = 2800
+    _ANALYZEPLANRESPONSE_INPUTFILES._serialized_start = 2802
+    _ANALYZEPLANRESPONSE_INPUTFILES._serialized_end = 2836
+    _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_start = 2838
+    _ANALYZEPLANRESPONSE_SPARKVERSION._serialized_end = 2878
+    _ANALYZEPLANRESPONSE_DDLPARSE._serialized_start = 2880
+    _ANALYZEPLANRESPONSE_DDLPARSE._serialized_end = 2939
+    _EXECUTEPLANREQUEST._serialized_start = 2952
+    _EXECUTEPLANREQUEST._serialized_end = 3159
+    _EXECUTEPLANRESPONSE._serialized_start = 3162
+    _EXECUTEPLANRESPONSE._serialized_end = 4386
+    _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 3617
+    _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 3688
+    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 3690
+    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 3751
+    _EXECUTEPLANRESPONSE_METRICS._serialized_start = 3754
+    _EXECUTEPLANRESPONSE_METRICS._serialized_end = 4271
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 3849
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 4181
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 4058
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 4181
+    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 4183
+    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 4271
+    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 4273
+    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 4369
+    _KEYVALUE._serialized_start = 4388
+    _KEYVALUE._serialized_end = 4453
+    _CONFIGREQUEST._serialized_start = 4456
+    _CONFIGREQUEST._serialized_end = 5482
+    _CONFIGREQUEST_OPERATION._serialized_start = 4674
+    _CONFIGREQUEST_OPERATION._serialized_end = 5172
+    _CONFIGREQUEST_SET._serialized_start = 5174
+    _CONFIGREQUEST_SET._serialized_end = 5226
+    _CONFIGREQUEST_GET._serialized_start = 5228
+    _CONFIGREQUEST_GET._serialized_end = 5253
+    _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 5255
+    _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 5318
+    _CONFIGREQUEST_GETOPTION._serialized_start = 5320
+    _CONFIGREQUEST_GETOPTION._serialized_end = 5351
+    _CONFIGREQUEST_GETALL._serialized_start = 5353
+    _CONFIGREQUEST_GETALL._serialized_end = 5401
+    _CONFIGREQUEST_UNSET._serialized_start = 5403
+    _CONFIGREQUEST_UNSET._serialized_end = 5430
+    _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 5432
+    _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 5466
+    _CONFIGRESPONSE._serialized_start = 5484
+    _CONFIGRESPONSE._serialized_end = 5604
+    _ADDARTIFACTSREQUEST._serialized_start = 5607
+    _ADDARTIFACTSREQUEST._serialized_end = 6422
+    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 5954
+    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 6007
+    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 6009
+    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 6120
+    _ADDARTIFACTSREQUEST_BATCH._serialized_start = 6122
+    _ADDARTIFACTSREQUEST_BATCH._serialized_end = 6215
+    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 6218
+    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 6411
+    _ADDARTIFACTSRESPONSE._serialized_start = 6425
+    _ADDARTIFACTSRESPONSE._serialized_end = 6613
+    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 6532
+    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 6613
+    _SPARKCONNECTSERVICE._serialized_start = 6616
+    _SPARKCONNECTSERVICE._serialized_end = 6981
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 0e800947975..56e1d2e416f 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -41,6 +41,7 @@ import google.protobuf.internal.containers
 import google.protobuf.internal.enum_type_wrapper
 import google.protobuf.message
 import pyspark.sql.connect.proto.commands_pb2
+import pyspark.sql.connect.proto.expressions_pb2
 import pyspark.sql.connect.proto.relations_pb2
 import pyspark.sql.connect.proto.types_pb2
 import sys
@@ -905,11 +906,37 @@ class ExecutePlanResponse(google.protobuf.message.Message):
             self, field_name: typing_extensions.Literal["metrics", b"metrics"]
         ) -> None: ...
 
+    class ObservedMetrics(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        NAME_FIELD_NUMBER: builtins.int
+        VALUES_FIELD_NUMBER: builtins.int
+        name: builtins.str
+        @property
+        def values(
+            self,
+        ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+            pyspark.sql.connect.proto.expressions_pb2.Expression.Literal
+        ]: ...
+        def __init__(
+            self,
+            *,
+            name: builtins.str = ...,
+            values: collections.abc.Iterable[
+                pyspark.sql.connect.proto.expressions_pb2.Expression.Literal
+            ]
+            | None = ...,
+        ) -> None: ...
+        def ClearField(
+            self, field_name: typing_extensions.Literal["name", b"name", "values", b"values"]
+        ) -> None: ...
+
     CLIENT_ID_FIELD_NUMBER: builtins.int
     ARROW_BATCH_FIELD_NUMBER: builtins.int
     SQL_COMMAND_RESULT_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     METRICS_FIELD_NUMBER: builtins.int
+    OBSERVED_METRICS_FIELD_NUMBER: builtins.int
     client_id: builtins.str
     @property
     def arrow_batch(self) -> global___ExecutePlanResponse.ArrowBatch: ...
@@ -924,6 +951,13 @@ class ExecutePlanResponse(google.protobuf.message.Message):
         """Metrics for the query execution. Typically, this field is only present in the last
         batch of results and then represent the overall state of the query execution.
         """
+    @property
+    def observed_metrics(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        global___ExecutePlanResponse.ObservedMetrics
+    ]:
+        """The metrics observed during the execution of the query plan."""
     def __init__(
         self,
         *,
@@ -932,6 +966,8 @@ class ExecutePlanResponse(google.protobuf.message.Message):
         sql_command_result: global___ExecutePlanResponse.SqlCommandResult | None = ...,
         extension: google.protobuf.any_pb2.Any | None = ...,
         metrics: global___ExecutePlanResponse.Metrics | None = ...,
+        observed_metrics: collections.abc.Iterable[global___ExecutePlanResponse.ObservedMetrics]
+        | None = ...,
     ) -> None: ...
     def HasField(
         self,
@@ -959,6 +995,8 @@ class ExecutePlanResponse(google.protobuf.message.Message):
             b"extension",
             "metrics",
             b"metrics",
+            "observed_metrics",
+            b"observed_metrics",
             "response_type",
             b"response_type",
             "sql_command_result",
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 3b4573177a2..f38cff1990d 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xfb\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -92,6 +92,7 @@ _UNPIVOT_VALUES = _UNPIVOT.nested_types_by_name["Values"]
 _TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"]
 _REPARTITIONBYEXPRESSION = DESCRIPTOR.message_types_by_name["RepartitionByExpression"]
 _FRAMEMAP = DESCRIPTOR.message_types_by_name["FrameMap"]
+_COLLECTMETRICS = DESCRIPTOR.message_types_by_name["CollectMetrics"]
 _JOIN_JOINTYPE = _JOIN.enum_types_by_name["JoinType"]
 _SETOPERATION_SETOPTYPE = _SETOPERATION.enum_types_by_name["SetOpType"]
 _AGGREGATE_GROUPTYPE = _AGGREGATE.enum_types_by_name["GroupType"]
@@ -636,6 +637,17 @@ FrameMap = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(FrameMap)
 
+CollectMetrics = _reflection.GeneratedProtocolMessageType(
+    "CollectMetrics",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _COLLECTMETRICS,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.CollectMetrics)
+    },
+)
+_sym_db.RegisterMessage(CollectMetrics)
+
 if _descriptor._USE_C_DESCRIPTORS == False:
 
     DESCRIPTOR._options = None
@@ -647,109 +659,111 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._options = None
     _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 165
-    _RELATION._serialized_end = 2518
-    _UNKNOWN._serialized_start = 2520
-    _UNKNOWN._serialized_end = 2529
-    _RELATIONCOMMON._serialized_start = 2531
-    _RELATIONCOMMON._serialized_end = 2622
-    _SQL._serialized_start = 2625
-    _SQL._serialized_end = 2759
-    _SQL_ARGSENTRY._serialized_start = 2704
-    _SQL_ARGSENTRY._serialized_end = 2759
-    _READ._serialized_start = 2762
-    _READ._serialized_end = 3226
-    _READ_NAMEDTABLE._serialized_start = 2904
-    _READ_NAMEDTABLE._serialized_end = 2965
-    _READ_DATASOURCE._serialized_start = 2968
-    _READ_DATASOURCE._serialized_end = 3213
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3133
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3191
-    _PROJECT._serialized_start = 3228
-    _PROJECT._serialized_end = 3345
-    _FILTER._serialized_start = 3347
-    _FILTER._serialized_end = 3459
-    _JOIN._serialized_start = 3462
-    _JOIN._serialized_end = 3933
-    _JOIN_JOINTYPE._serialized_start = 3725
-    _JOIN_JOINTYPE._serialized_end = 3933
-    _SETOPERATION._serialized_start = 3936
-    _SETOPERATION._serialized_end = 4415
-    _SETOPERATION_SETOPTYPE._serialized_start = 4252
-    _SETOPERATION_SETOPTYPE._serialized_end = 4366
-    _LIMIT._serialized_start = 4417
-    _LIMIT._serialized_end = 4493
-    _OFFSET._serialized_start = 4495
-    _OFFSET._serialized_end = 4574
-    _TAIL._serialized_start = 4576
-    _TAIL._serialized_end = 4651
-    _AGGREGATE._serialized_start = 4654
-    _AGGREGATE._serialized_end = 5236
-    _AGGREGATE_PIVOT._serialized_start = 4993
-    _AGGREGATE_PIVOT._serialized_end = 5104
-    _AGGREGATE_GROUPTYPE._serialized_start = 5107
-    _AGGREGATE_GROUPTYPE._serialized_end = 5236
-    _SORT._serialized_start = 5239
-    _SORT._serialized_end = 5399
-    _DROP._serialized_start = 5402
-    _DROP._serialized_end = 5543
-    _DEDUPLICATE._serialized_start = 5546
-    _DEDUPLICATE._serialized_end = 5717
-    _LOCALRELATION._serialized_start = 5719
-    _LOCALRELATION._serialized_end = 5808
-    _SAMPLE._serialized_start = 5811
-    _SAMPLE._serialized_end = 6084
-    _RANGE._serialized_start = 6087
-    _RANGE._serialized_end = 6232
-    _SUBQUERYALIAS._serialized_start = 6234
-    _SUBQUERYALIAS._serialized_end = 6348
-    _REPARTITION._serialized_start = 6351
-    _REPARTITION._serialized_end = 6493
-    _SHOWSTRING._serialized_start = 6496
-    _SHOWSTRING._serialized_end = 6638
-    _STATSUMMARY._serialized_start = 6640
-    _STATSUMMARY._serialized_end = 6732
-    _STATDESCRIBE._serialized_start = 6734
-    _STATDESCRIBE._serialized_end = 6815
-    _STATCROSSTAB._serialized_start = 6817
-    _STATCROSSTAB._serialized_end = 6918
-    _STATCOV._serialized_start = 6920
-    _STATCOV._serialized_end = 7016
-    _STATCORR._serialized_start = 7019
-    _STATCORR._serialized_end = 7156
-    _STATAPPROXQUANTILE._serialized_start = 7159
-    _STATAPPROXQUANTILE._serialized_end = 7323
-    _STATFREQITEMS._serialized_start = 7325
-    _STATFREQITEMS._serialized_end = 7450
-    _STATSAMPLEBY._serialized_start = 7453
-    _STATSAMPLEBY._serialized_end = 7762
-    _STATSAMPLEBY_FRACTION._serialized_start = 7654
-    _STATSAMPLEBY_FRACTION._serialized_end = 7753
-    _NAFILL._serialized_start = 7765
-    _NAFILL._serialized_end = 7899
-    _NADROP._serialized_start = 7902
-    _NADROP._serialized_end = 8036
-    _NAREPLACE._serialized_start = 8039
-    _NAREPLACE._serialized_end = 8335
-    _NAREPLACE_REPLACEMENT._serialized_start = 8194
-    _NAREPLACE_REPLACEMENT._serialized_end = 8335
-    _TODF._serialized_start = 8337
-    _TODF._serialized_end = 8425
-    _WITHCOLUMNSRENAMED._serialized_start = 8428
-    _WITHCOLUMNSRENAMED._serialized_end = 8667
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8600
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8667
-    _WITHCOLUMNS._serialized_start = 8669
-    _WITHCOLUMNS._serialized_end = 8788
-    _HINT._serialized_start = 8791
-    _HINT._serialized_end = 8923
-    _UNPIVOT._serialized_start = 8926
-    _UNPIVOT._serialized_end = 9253
-    _UNPIVOT_VALUES._serialized_start = 9183
-    _UNPIVOT_VALUES._serialized_end = 9242
-    _TOSCHEMA._serialized_start = 9255
-    _TOSCHEMA._serialized_end = 9361
-    _REPARTITIONBYEXPRESSION._serialized_start = 9364
-    _REPARTITIONBYEXPRESSION._serialized_end = 9567
-    _FRAMEMAP._serialized_start = 9569
-    _FRAMEMAP._serialized_end = 9694
+    _RELATION._serialized_end = 2592
+    _UNKNOWN._serialized_start = 2594
+    _UNKNOWN._serialized_end = 2603
+    _RELATIONCOMMON._serialized_start = 2605
+    _RELATIONCOMMON._serialized_end = 2696
+    _SQL._serialized_start = 2699
+    _SQL._serialized_end = 2833
+    _SQL_ARGSENTRY._serialized_start = 2778
+    _SQL_ARGSENTRY._serialized_end = 2833
+    _READ._serialized_start = 2836
+    _READ._serialized_end = 3300
+    _READ_NAMEDTABLE._serialized_start = 2978
+    _READ_NAMEDTABLE._serialized_end = 3039
+    _READ_DATASOURCE._serialized_start = 3042
+    _READ_DATASOURCE._serialized_end = 3287
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3207
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3265
+    _PROJECT._serialized_start = 3302
+    _PROJECT._serialized_end = 3419
+    _FILTER._serialized_start = 3421
+    _FILTER._serialized_end = 3533
+    _JOIN._serialized_start = 3536
+    _JOIN._serialized_end = 4007
+    _JOIN_JOINTYPE._serialized_start = 3799
+    _JOIN_JOINTYPE._serialized_end = 4007
+    _SETOPERATION._serialized_start = 4010
+    _SETOPERATION._serialized_end = 4489
+    _SETOPERATION_SETOPTYPE._serialized_start = 4326
+    _SETOPERATION_SETOPTYPE._serialized_end = 4440
+    _LIMIT._serialized_start = 4491
+    _LIMIT._serialized_end = 4567
+    _OFFSET._serialized_start = 4569
+    _OFFSET._serialized_end = 4648
+    _TAIL._serialized_start = 4650
+    _TAIL._serialized_end = 4725
+    _AGGREGATE._serialized_start = 4728
+    _AGGREGATE._serialized_end = 5310
+    _AGGREGATE_PIVOT._serialized_start = 5067
+    _AGGREGATE_PIVOT._serialized_end = 5178
+    _AGGREGATE_GROUPTYPE._serialized_start = 5181
+    _AGGREGATE_GROUPTYPE._serialized_end = 5310
+    _SORT._serialized_start = 5313
+    _SORT._serialized_end = 5473
+    _DROP._serialized_start = 5476
+    _DROP._serialized_end = 5617
+    _DEDUPLICATE._serialized_start = 5620
+    _DEDUPLICATE._serialized_end = 5791
+    _LOCALRELATION._serialized_start = 5793
+    _LOCALRELATION._serialized_end = 5882
+    _SAMPLE._serialized_start = 5885
+    _SAMPLE._serialized_end = 6158
+    _RANGE._serialized_start = 6161
+    _RANGE._serialized_end = 6306
+    _SUBQUERYALIAS._serialized_start = 6308
+    _SUBQUERYALIAS._serialized_end = 6422
+    _REPARTITION._serialized_start = 6425
+    _REPARTITION._serialized_end = 6567
+    _SHOWSTRING._serialized_start = 6570
+    _SHOWSTRING._serialized_end = 6712
+    _STATSUMMARY._serialized_start = 6714
+    _STATSUMMARY._serialized_end = 6806
+    _STATDESCRIBE._serialized_start = 6808
+    _STATDESCRIBE._serialized_end = 6889
+    _STATCROSSTAB._serialized_start = 6891
+    _STATCROSSTAB._serialized_end = 6992
+    _STATCOV._serialized_start = 6994
+    _STATCOV._serialized_end = 7090
+    _STATCORR._serialized_start = 7093
+    _STATCORR._serialized_end = 7230
+    _STATAPPROXQUANTILE._serialized_start = 7233
+    _STATAPPROXQUANTILE._serialized_end = 7397
+    _STATFREQITEMS._serialized_start = 7399
+    _STATFREQITEMS._serialized_end = 7524
+    _STATSAMPLEBY._serialized_start = 7527
+    _STATSAMPLEBY._serialized_end = 7836
+    _STATSAMPLEBY_FRACTION._serialized_start = 7728
+    _STATSAMPLEBY_FRACTION._serialized_end = 7827
+    _NAFILL._serialized_start = 7839
+    _NAFILL._serialized_end = 7973
+    _NADROP._serialized_start = 7976
+    _NADROP._serialized_end = 8110
+    _NAREPLACE._serialized_start = 8113
+    _NAREPLACE._serialized_end = 8409
+    _NAREPLACE_REPLACEMENT._serialized_start = 8268
+    _NAREPLACE_REPLACEMENT._serialized_end = 8409
+    _TODF._serialized_start = 8411
+    _TODF._serialized_end = 8499
+    _WITHCOLUMNSRENAMED._serialized_start = 8502
+    _WITHCOLUMNSRENAMED._serialized_end = 8741
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8674
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8741
+    _WITHCOLUMNS._serialized_start = 8743
+    _WITHCOLUMNS._serialized_end = 8862
+    _HINT._serialized_start = 8865
+    _HINT._serialized_end = 8997
+    _UNPIVOT._serialized_start = 9000
+    _UNPIVOT._serialized_end = 9327
+    _UNPIVOT_VALUES._serialized_start = 9257
+    _UNPIVOT_VALUES._serialized_end = 9316
+    _TOSCHEMA._serialized_start = 9329
+    _TOSCHEMA._serialized_end = 9435
+    _REPARTITIONBYEXPRESSION._serialized_start = 9438
+    _REPARTITIONBYEXPRESSION._serialized_end = 9641
+    _FRAMEMAP._serialized_start = 9643
+    _FRAMEMAP._serialized_end = 9768
+    _COLLECTMETRICS._serialized_start = 9771
+    _COLLECTMETRICS._serialized_end = 9907
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index b60fd5a1a61..e4d1653f0ba 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -90,6 +90,7 @@ class Relation(google.protobuf.message.Message):
     TO_SCHEMA_FIELD_NUMBER: builtins.int
     REPARTITION_BY_EXPRESSION_FIELD_NUMBER: builtins.int
     FRAME_MAP_FIELD_NUMBER: builtins.int
+    COLLECT_METRICS_FIELD_NUMBER: builtins.int
     FILL_NA_FIELD_NUMBER: builtins.int
     DROP_NA_FIELD_NUMBER: builtins.int
     REPLACE_FIELD_NUMBER: builtins.int
@@ -161,6 +162,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def frame_map(self) -> global___FrameMap: ...
     @property
+    def collect_metrics(self) -> global___CollectMetrics: ...
+    @property
     def fill_na(self) -> global___NAFill:
         """NA functions"""
     @property
@@ -225,6 +228,7 @@ class Relation(google.protobuf.message.Message):
         to_schema: global___ToSchema | None = ...,
         repartition_by_expression: global___RepartitionByExpression | None = ...,
         frame_map: global___FrameMap | None = ...,
+        collect_metrics: global___CollectMetrics | None = ...,
         fill_na: global___NAFill | None = ...,
         drop_na: global___NADrop | None = ...,
         replace: global___NAReplace | None = ...,
@@ -249,6 +253,8 @@ class Relation(google.protobuf.message.Message):
             b"approx_quantile",
             "catalog",
             b"catalog",
+            "collect_metrics",
+            b"collect_metrics",
             "common",
             b"common",
             "corr",
@@ -340,6 +346,8 @@ class Relation(google.protobuf.message.Message):
             b"approx_quantile",
             "catalog",
             b"catalog",
+            "collect_metrics",
+            b"collect_metrics",
             "common",
             b"common",
             "corr",
@@ -452,6 +460,7 @@ class Relation(google.protobuf.message.Message):
         "to_schema",
         "repartition_by_expression",
         "frame_map",
+        "collect_metrics",
         "fill_na",
         "drop_na",
         "replace",
@@ -2702,3 +2711,43 @@ class FrameMap(google.protobuf.message.Message):
     ) -> None: ...
 
 global___FrameMap = FrameMap
+
+class CollectMetrics(google.protobuf.message.Message):
+    """Collect arbitrary (named) metrics from a dataset."""
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    NAME_FIELD_NUMBER: builtins.int
+    METRICS_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) The input relation."""
+    name: builtins.str
+    """(Required) Name of the metrics."""
+    @property
+    def metrics(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        pyspark.sql.connect.proto.expressions_pb2.Expression
+    ]:
+        """(Required) The metric sequence."""
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        name: builtins.str = ...,
+        metrics: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
+        | None = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "input", b"input", "metrics", b"metrics", "name", b"name"
+        ],
+    ) -> None: ...
+
+global___CollectMetrics = CollectMetrics
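
For reference, a minimal sketch of how the new CollectMetrics relation can be assembled directly from the generated classes above. It assumes the pyspark.sql.connect.proto package re-exports Relation, SQL, Expression and CollectMetrics, and the SQL.query field name is an assumption; the table and metric are purely illustrative:

    from pyspark.sql.connect import proto

    # Hypothetical input relation; any relation can serve as CollectMetrics.input.
    input_rel = proto.Relation(sql=proto.SQL(query="SELECT id FROM range(10)"))

    # One metric expression, min(id), built from the same unresolved_function
    # fields that the plan tests below inspect.
    metric = proto.Expression()
    metric.unresolved_function.function_name = "min"
    metric.unresolved_function.arguments.add().unresolved_attribute.unparsed_identifier = "id"

    rel = proto.Relation(
        collect_metrics=proto.CollectMetrics(
            input=input_rel, name="my_metric", metrics=[metric]
        )
    )
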
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 70938459f79..68ac8b1dd7c 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -59,6 +59,7 @@ if should_test_connect:
     import grpc
     import pandas as pd
     import numpy as np
+    from pyspark.sql.connect.proto import Expression as ProtoExpression
     from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
     from pyspark.sql.connect.client import ChannelBuilder
     from pyspark.sql.connect.column import Column
@@ -1631,6 +1632,57 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.assert_eq(relations[i].toPandas(), datasets[i].toPandas())
             i += 1
 
+    def test_observe(self):
+        # SPARK-41527: test DataFrame.observe()
+        from pyspark.sql import functions as SF
+        from pyspark.sql.connect import functions as CF
+
+        observation_name = "my_metric"
+
+        self.assert_eq(
+            self.connect.read.table(self.tbl_name)
+            .filter("id > 3")
+            .observe(observation_name, CF.min("id"), CF.max("id"), CF.sum("id"))
+            .toPandas(),
+            self.spark.read.table(self.tbl_name)
+            .filter("id > 3")
+            .observe(observation_name, SF.min("id"), SF.max("id"), SF.sum("id"))
+            .toPandas(),
+        )
+
+        from pyspark.sql.observation import Observation
+
+        observation = Observation(observation_name)
+
+        cdf = (
+            self.connect.read.table(self.tbl_name)
+            .filter("id > 3")
+            .observe(observation, CF.min("id"), CF.max("id"), CF.sum("id"))
+            .toPandas()
+        )
+        df = (
+            self.spark.read.table(self.tbl_name)
+            .filter("id > 3")
+            .observe(observation, SF.min("id"), SF.max("id"), SF.sum("id"))
+            .toPandas()
+        )
+
+        self.assert_eq(cdf, df)
+
+        observed_metrics = cdf.attrs["observed_metrics"]
+        self.assert_eq(len(observed_metrics), 1)
+        self.assert_eq(observed_metrics[0].name, observation_name)
+        self.assert_eq(len(observed_metrics[0].metrics), 3)
+        for metric in observed_metrics[0].metrics:
+            self.assertIsInstance(metric, ProtoExpression.Literal)
+        values = list(map(lambda metric: metric.long, observed_metrics[0].metrics))
+        self.assert_eq(values, [4, 99, 4944])
+
+        with self.assertRaisesRegex(ValueError, "'exprs' should not be empty"):
+            self.connect.read.table(self.tbl_name).observe(observation_name)
+        with self.assertRaisesRegex(ValueError, "all 'exprs' should be Column"):
+            self.connect.read.table(self.tbl_name).observe(observation_name, CF.lit(1), "id")
+
     def test_with_columns(self):
         # SPARK-41256: test withColumn(s).
         self.assert_eq(
@@ -2766,7 +2818,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             "cache",
             "persist",
             "withWatermark",
-            "observe",
             "foreach",
             "foreachPartition",
             "toLocalIterator",
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index 8c09b9cfaa5..115d189b742 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -31,7 +31,7 @@ if should_test_connect:
     from pyspark.sql.connect.dataframe import DataFrame
     from pyspark.sql.connect.plan import WriteOperation, Read
     from pyspark.sql.connect.readwriter import DataFrameReader
-    from pyspark.sql.connect.functions import col, lit
+    from pyspark.sql.connect.functions import col, lit, max, min, sum
     from pyspark.sql.connect.types import pyspark_types_to_proto_types
     from pyspark.sql.types import (
         StringType,
@@ -294,6 +294,68 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
         relations = df.filter(df.col_name > 3).randomSplit([1.0, 2.0, 3.0])
         checkRelations(relations)
 
+    def test_observe(self):
+        # SPARK-41527: test DataFrame.observe()
+        df = self.connect.readTable(table_name=self.tbl_name)
+
+        plan = (
+            df.filter(df.col_name > 3)
+            .observe("my_metric", min("id"), max("id"), sum("id"))
+            ._plan.to_proto(self.connect)
+        )
+        self.assertEqual(plan.root.collect_metrics.name, "my_metric")
+        self.assertTrue(
+            all(isinstance(c, proto.Expression) for c in plan.root.collect_metrics.metrics)
+        )
+        self.assertEqual(
+            plan.root.collect_metrics.metrics[0].unresolved_function.function_name, "min"
+        )
+        self.assertTrue(
+            len(plan.root.collect_metrics.metrics[0].unresolved_function.arguments) == 1
+        )
+        self.assertTrue(
+            all(
+                isinstance(c, proto.Expression)
+                for c in plan.root.collect_metrics.metrics[0].unresolved_function.arguments
+            )
+        )
+        self.assertEqual(
+            plan.root.collect_metrics.metrics[0]
+            .unresolved_function.arguments[0]
+            .unresolved_attribute.unparsed_identifier,
+            "id",
+        )
+
+        from pyspark.sql.observation import Observation
+
+        plan = (
+            df.filter(df.col_name > 3)
+            .observe(Observation("my_metric"), min("id"), max("id"), sum("id"))
+            ._plan.to_proto(self.connect)
+        )
+        self.assertEqual(plan.root.collect_metrics.name, "my_metric")
+        self.assertTrue(
+            all(isinstance(c, proto.Expression) for c in plan.root.collect_metrics.metrics)
+        )
+        self.assertEqual(
+            plan.root.collect_metrics.metrics[0].unresolved_function.function_name, "min"
+        )
+        self.assertTrue(
+            len(plan.root.collect_metrics.metrics[0].unresolved_function.arguments) == 1
+        )
+        self.assertTrue(
+            all(
+                isinstance(c, proto.Expression)
+                for c in plan.root.collect_metrics.metrics[0].unresolved_function.arguments
+            )
+        )
+        self.assertEqual(
+            plan.root.collect_metrics.metrics[0]
+            .unresolved_function.arguments[0]
+            .unresolved_attribute.unparsed_identifier,
+            "id",
+        )
+
     def test_summary(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         plan = df.filter(df.col_name > 3).summary()._plan.to_proto(self.connect)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
index a2750dd81e9..df2c99dde5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.util.QueryExecutionListener
  * @param name name of the metric
  * @since 3.3.0
  */
-class Observation(name: String) {
+class Observation(val name: String) {
 
   if (name.isEmpty) throw new IllegalArgumentException("Name must not be empty")
 

