Posted to commits@spark.apache.org by we...@apache.org on 2023/03/30 11:44:14 UTC

[spark] branch branch-3.4 updated: Revert "[SPARK-41765][SQL] Pull out v1 write metrics to WriteFiles"

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new e586527f79c Revert "[SPARK-41765][SQL] Pull out v1 write metrics to WriteFiles"
e586527f79c is described below

commit e586527f79c93519467d202a4e258864fda6e8f8
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Thu Mar 30 19:43:36 2023 +0800

    Revert "[SPARK-41765][SQL] Pull out v1 write metrics to WriteFiles"
    
    This reverts commit a111a02de1a814c5f335e0bcac4cffb0515557dc.
    
    ### What changes were proposed in this pull request?
    
    SQLMetrics is not only used in the UI; it is also a programming API, as users can write a listener, get the physical plan, and read the SQLMetrics values directly.
    
    We could ask users to update their code to read SQLMetrics from the new `WriteFiles` node instead, but this is troublesome: sometimes they need both the write metrics and the commit metrics, which would force them to look at two physical plan nodes. Given that https://github.com/apache/spark/pull/39428 is mostly a cleanup and does not bring many benefits, reverting is the better option.
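    
    For context, a minimal sketch of the listener pattern this revert keeps working (assuming an active `SparkSession` named `spark`; the metric keys shown are the ones checked in the tests in this diff):
    
    ```scala
    import org.apache.spark.sql.execution.QueryExecution
    import org.apache.spark.sql.execution.command.DataWritingCommandExec
    import org.apache.spark.sql.util.QueryExecutionListener
    
    val listener = new QueryExecutionListener {
      override def onFailure(funcName: String, qe: QueryExecution, e: Exception): Unit = {}
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
        qe.executedPlan match {
          case exec: DataWritingCommandExec =>
            // With the revert, the write command node holds both the write metrics
            // (e.g. "numFiles", "numOutputBytes", "numOutputRows") and the commit metrics.
            val metrics = exec.cmd.metrics
            println(metrics.map { case (k, m) => s"$k=${m.value}" }.mkString(", "))
          case _ =>
        }
      }
    }
    spark.listenerManager.register(listener)
    ```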
    
    ### Why are the changes needed?
    
    Avoid breaking changes.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, users can programmatically get the write command metrics as before.
    
    ### How was this patch tested?
    
    N/A
    
    Closes #40604 from cloud-fan/revert.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit f4af6a05c0879887e7db2377a174e7b7d7bab693)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/execution/command/DataWritingCommand.scala | 47 ++++++++-------
 .../datasources/BasicWriteStatsTracker.scala       | 17 +-----
 .../execution/datasources/FileFormatWriter.scala   | 69 +++++++++-------------
 .../sql/execution/datasources/WriteFiles.scala     |  3 -
 .../BasicWriteJobStatsTrackerMetricSuite.scala     | 19 +++---
 .../sql/execution/metric/SQLMetricsSuite.scala     | 28 ++++-----
 .../sql/execution/metric/SQLMetricsTestUtils.scala | 21 +------
 .../hive/execution/InsertIntoHiveDirCommand.scala  |  7 ---
 .../spark/sql/hive/execution/SQLMetricsSuite.scala | 42 +++++--------
 9 files changed, 94 insertions(+), 159 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
index 58c3fca4ad7..338ce8cac42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala
@@ -21,13 +21,14 @@ import java.net.URI
 
 import org.apache.hadoop.conf.Configuration
 
+import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Row, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryCommand}
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
-import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.SerializableConfiguration
 
@@ -51,19 +52,11 @@ trait DataWritingCommand extends UnaryCommand {
   def outputColumns: Seq[Attribute] =
     DataWritingCommand.logicalPlanOutputWithNames(query, outputColumnNames)
 
-  lazy val metrics: Map[String, SQLMetric] = {
-    // If planned write is enable, we have pulled out write files metrics from `V1WriteCommand`
-    // to `WriteFiles`. `DataWritingCommand` should only holds the task commit metric and driver
-    // commit metric.
-    if (conf.getConf(SQLConf.PLANNED_WRITE_ENABLED)) {
-      BasicWriteJobStatsTracker.writeCommitMetrics
-    } else {
-      BasicWriteJobStatsTracker.metrics
-    }
-  }
+  lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics
 
   def basicWriteJobStatsTracker(hadoopConf: Configuration): BasicWriteJobStatsTracker = {
-    DataWritingCommand.basicWriteJobStatsTracker(metrics, hadoopConf)
+    val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
+    new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
   }
 
   def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row]
@@ -86,6 +79,27 @@ object DataWritingCommand {
     }
   }
 
+  /**
+   * When execute CTAS operators, Spark will use [[InsertIntoHadoopFsRelationCommand]]
+   * or [[InsertIntoHiveTable]] command to write data, they both inherit metrics from
+   * [[DataWritingCommand]], but after running [[InsertIntoHadoopFsRelationCommand]]
+   * or [[InsertIntoHiveTable]], we only update metrics in these two command through
+   * [[BasicWriteJobStatsTracker]], we also need to propogate metrics to the command
+   * that actually calls [[InsertIntoHadoopFsRelationCommand]] or [[InsertIntoHiveTable]].
+   *
+   * @param sparkContext Current SparkContext.
+   * @param command Command to execute writing data.
+   * @param metrics Metrics of real DataWritingCommand.
+   */
+  def propogateMetrics(
+      sparkContext: SparkContext,
+      command: DataWritingCommand,
+      metrics: Map[String, SQLMetric]): Unit = {
+    command.metrics.foreach { case (key, metric) => metrics(key).set(metric.value) }
+    SQLMetrics.postDriverMetricUpdates(sparkContext,
+      sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY),
+      metrics.values.toSeq)
+  }
   /**
    * When execute CTAS operators, and the location is not empty, throw [[AnalysisException]].
    * For CTAS, the SaveMode is always [[ErrorIfExists]]
@@ -106,11 +120,4 @@ object DataWritingCommand {
       }
     }
   }
-
-  def basicWriteJobStatsTracker(
-      metrics: Map[String, SQLMetric],
-      hadoopConf: Configuration): BasicWriteJobStatsTracker = {
-    val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
-    new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
index 980659b9fab..47685899784 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala
@@ -197,11 +197,6 @@ class BasicWriteJobStatsTracker(
     this(serializableHadoopConf, metrics - TASK_COMMIT_TIME, metrics(TASK_COMMIT_TIME))
   }
 
-  def writeCommitMetrics(): Map[String, SQLMetric] = {
-    Map(TASK_COMMIT_TIME -> taskCommitTimeMetric,
-      JOB_COMMIT_TIME -> driverSideMetrics(JOB_COMMIT_TIME))
-  }
-
   override def newTaskInstance(): WriteTaskStatsTracker = {
     new BasicWriteTaskStatsTracker(serializableHadoopConf.value, Some(taskCommitTimeMetric))
   }
@@ -244,22 +239,12 @@ object BasicWriteJobStatsTracker {
   val FILE_LENGTH_XATTR = "header.x-hadoop-s3a-magic-data-length"
 
   def metrics: Map[String, SQLMetric] = {
-    writeFilesMetrics ++ writeCommitMetrics
-  }
-
-  def writeFilesMetrics: Map[String, SQLMetric] = {
     val sparkContext = SparkContext.getActive.get
     Map(
       NUM_FILES_KEY -> SQLMetrics.createMetric(sparkContext, "number of written files"),
       NUM_OUTPUT_BYTES_KEY -> SQLMetrics.createSizeMetric(sparkContext, "written output"),
       NUM_OUTPUT_ROWS_KEY -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
-      NUM_PARTS_KEY -> SQLMetrics.createMetric(sparkContext, "number of dynamic part")
-    )
-  }
-
-  def writeCommitMetrics: Map[String, SQLMetric] = {
-    val sparkContext = SparkContext.getActive.get
-    Map(
+      NUM_PARTS_KEY -> SQLMetrics.createMetric(sparkContext, "number of dynamic part"),
       TASK_COMMIT_TIME -> SQLMetrics.createTimingMetric(sparkContext, "task commit time"),
       JOB_COMMIT_TIME -> SQLMetrics.createTimingMetric(sparkContext, "job commit time")
     )
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 8f1eaf4ab77..8321b1fac71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -40,7 +40,6 @@ import org.apache.spark.sql.connector.write.WriterCommitMessage
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
-import org.apache.spark.sql.execution.command.DataWritingCommand
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
@@ -123,6 +122,29 @@ object FileFormatWriter extends Logging {
     val outputWriterFactory =
       fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema)
 
+    val description = new WriteJobDescription(
+      uuid = UUID.randomUUID.toString,
+      serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
+      outputWriterFactory = outputWriterFactory,
+      allColumns = finalOutputSpec.outputColumns,
+      dataColumns = dataColumns,
+      partitionColumns = partitionColumns,
+      bucketSpec = writerBucketSpec,
+      path = finalOutputSpec.outputPath,
+      customPartitionLocations = finalOutputSpec.customPartitionLocations,
+      maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
+        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
+      timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
+        .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone),
+      statsTrackers = statsTrackers
+    )
+
+    // We should first sort by dynamic partition columns, then bucket id, and finally sorting
+    // columns.
+    val requiredOrdering = partitionColumns.drop(numStaticPartitionCols) ++
+        writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns
+    val writeFilesOpt = V1WritesUtils.getWriteFilesOpt(plan)
+
     // SPARK-40588: when planned writing is disabled and AQE is enabled,
     // plan contains an AdaptiveSparkPlanExec, which does not know
     // its final plan's ordering, so we have to materialize that plan first
@@ -132,17 +154,18 @@ object FileFormatWriter extends Logging {
       case p: SparkPlan => p.withNewChildren(p.children.map(materializeAdaptiveSparkPlan))
     }
 
-    // We should first sort by dynamic partition columns, then bucket id, and finally sorting
-    // columns.
-    val requiredOrdering = partitionColumns.drop(numStaticPartitionCols) ++
-      writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns
-    val writeFilesOpt = V1WritesUtils.getWriteFilesOpt(plan)
     // the sort order doesn't matter
     val actualOrdering = writeFilesOpt.map(_.child)
       .getOrElse(materializeAdaptiveSparkPlan(plan))
       .outputOrdering
     val orderingMatched = V1WritesUtils.isOrderingMatched(requiredOrdering, actualOrdering)
 
+    SQLExecution.checkSQLExecutionId(sparkSession)
+
+    // propagate the description UUID into the jobs, so that committers
+    // get an ID guaranteed to be unique.
+    job.getConfiguration.set("spark.sql.sources.writeJobUUID", description.uuid)
+
     // When `PLANNED_WRITE_ENABLED` is true, the optimizer rule V1Writes will add logical sort
     // operator based on the required ordering of the V1 write command. So the output
     // ordering of the physical plan should always match the required ordering. Here
@@ -153,40 +176,6 @@ object FileFormatWriter extends Logging {
     //    V1 write command will be empty).
     if (Utils.isTesting) outputOrderingMatched = orderingMatched
 
-    SQLExecution.checkSQLExecutionId(sparkSession)
-
-    val finalStatsTrackers = if (writeFilesOpt.isDefined) {
-      val writeFilesMetrics = writeFilesOpt.get.metrics
-      statsTrackers.map {
-        case tracker: BasicWriteJobStatsTracker =>
-          val finalMetrics = writeFilesMetrics ++ tracker.writeCommitMetrics()
-          DataWritingCommand.basicWriteJobStatsTracker(finalMetrics, hadoopConf)
-        case other => other
-      }
-    } else {
-      statsTrackers
-    }
-    val description = new WriteJobDescription(
-      uuid = UUID.randomUUID.toString,
-      serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
-      outputWriterFactory = outputWriterFactory,
-      allColumns = finalOutputSpec.outputColumns,
-      dataColumns = dataColumns,
-      partitionColumns = partitionColumns,
-      bucketSpec = writerBucketSpec,
-      path = finalOutputSpec.outputPath,
-      customPartitionLocations = finalOutputSpec.customPartitionLocations,
-      maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
-        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
-      timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
-        .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone),
-      statsTrackers = finalStatsTrackers
-    )
-
-    // propagate the description UUID into the jobs, so that committers
-    // get an ID guaranteed to be unique.
-    job.getConfiguration.set("spark.sql.sources.writeJobUUID", description.uuid)
-
     if (writeFilesOpt.isDefined) {
       // build `WriteFilesSpec` for `WriteFiles`
       val concurrentOutputWriterSpecFunc = (plan: SparkPlan) => {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteFiles.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteFiles.scala
index 2437bf4e580..d0ed6b02fef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteFiles.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteFiles.scala
@@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 import org.apache.spark.sql.connector.write.WriterCommitMessage
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec
-import org.apache.spark.sql.execution.metric.SQLMetric
 
 /**
  * The write files spec holds all information of [[V1WriteCommand]] if its provider is
@@ -70,8 +69,6 @@ case class WriteFilesExec(
     staticPartitions: TablePartitionSpec) extends UnaryExecNode {
   override def output: Seq[Attribute] = Seq.empty
 
-  override lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.writeFilesMetrics
-
   override protected def doExecuteWrite(
       writeFilesSpec: WriteFilesSpec): RDD[WriterCommitMessage] = {
     val rdd = child.execute()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteJobStatsTrackerMetricSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteJobStatsTrackerMetricSuite.scala
index c538757b2d1..3e58c225d8c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteJobStatsTrackerMetricSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteJobStatsTrackerMetricSuite.scala
@@ -17,10 +17,6 @@
 
 package org.apache.spark.sql.execution.datasources
 
-import org.scalatest.concurrent.Eventually.eventually
-import org.scalatest.concurrent.Futures.timeout
-import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{LocalSparkSession, SparkSession}
 
@@ -48,14 +44,13 @@ class BasicWriteJobStatsTrackerMetricSuite extends SparkFunSuite with LocalSpark
       // but the executionId is indeterminate in maven test,
       // so the `statusStore.execution(executionId)` API is not used.
       assert(statusStore.executionsCount() == 2)
-      eventually(timeout(10.seconds)) {
-        val executionData = statusStore.executionsList()(1)
-        val accumulatorIdOpt =
-          executionData.metrics.find(_.name == "number of dynamic part").map(_.accumulatorId)
-        assert(accumulatorIdOpt.isDefined)
-        val numPartsOpt = executionData.metricValues.get(accumulatorIdOpt.get)
-        assert(numPartsOpt.isDefined && numPartsOpt.get == partitions)
-      }
+      val executionData = statusStore.executionsList()(1)
+      val accumulatorIdOpt =
+        executionData.metrics.find(_.name == "number of dynamic part").map(_.accumulatorId)
+      assert(accumulatorIdOpt.isDefined)
+      val numPartsOpt = executionData.metricValues.get(accumulatorIdOpt.get)
+      assert(numPartsOpt.isDefined && numPartsOpt.get == partitions)
+
     } finally {
       spark.sql("drop table if exists dynamic_partition")
       spark.stop()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 5b61c046941..26e61c6b58d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
-import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, SQLHadoopMapReduceCommitProtocol, WriteFilesExec}
+import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, InsertIntoHadoopFsRelationCommand, SQLHadoopMapReduceCommitProtocol, V1WriteCommand}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec}
 import org.apache.spark.sql.expressions.Window
@@ -830,33 +830,29 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
 
   test("SPARK-34567: Add metrics for CTAS operator") {
     withTable("t") {
-      var dataWriting: DataWritingCommandExec = null
+      var v1WriteCommand: V1WriteCommand = null
       val listener = new QueryExecutionListener {
         override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {}
         override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
           qe.executedPlan match {
             case dataWritingCommandExec: DataWritingCommandExec =>
-              dataWriting = dataWritingCommandExec
+              val createTableAsSelect = dataWritingCommandExec.cmd
+              v1WriteCommand = createTableAsSelect.asInstanceOf[InsertIntoHadoopFsRelationCommand]
             case _ =>
           }
         }
       }
       spark.listenerManager.register(listener)
       try {
-        sql("CREATE TABLE t USING PARQUET AS SELECT 1 as a")
+        val df = sql("CREATE TABLE t USING PARQUET AS SELECT 1 as a")
         sparkContext.listenerBus.waitUntilEmpty()
-        assert(dataWriting != null)
-        val metrics = if (conf.plannedWriteEnabled) {
-          dataWriting.child.asInstanceOf[WriteFilesExec].metrics
-        } else {
-          dataWriting.cmd.metrics
-        }
-        assert(metrics.contains("numFiles"))
-        assert(metrics("numFiles").value == 1)
-        assert(metrics.contains("numOutputBytes"))
-        assert(metrics("numOutputBytes").value > 0)
-        assert(metrics.contains("numOutputRows"))
-        assert(metrics("numOutputRows").value == 1)
+        assert(v1WriteCommand != null)
+        assert(v1WriteCommand.metrics.contains("numFiles"))
+        assert(v1WriteCommand.metrics("numFiles").value == 1)
+        assert(v1WriteCommand.metrics.contains("numOutputBytes"))
+        assert(v1WriteCommand.metrics("numOutputBytes").value > 0)
+        assert(v1WriteCommand.metrics.contains("numOutputRows"))
+        assert(v1WriteCommand.metrics("numOutputRows").value == 1)
       } finally {
         spark.listenerManager.unregister(listener)
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
index 0952b1c0c10..81667d52e16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.execution.{SparkPlan, SparkPlanInfo}
 import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.WHOLESTAGE_CODEGEN_ENABLED
 import org.apache.spark.sql.test.SQLTestUtils
 
@@ -81,13 +80,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
     assert(executionIds.size == 1)
     val executionId = executionIds.head
 
-    val executedNode = if (conf.plannedWriteEnabled) {
-      val executedNodeOpt = statusStore.planGraph(executionId).nodes.find(_.name == "WriteFiles")
-      assert(executedNodeOpt.isDefined)
-      executedNodeOpt.get
-    } else {
-      statusStore.planGraph(executionId).nodes.head
-    }
+    val executedNode = statusStore.planGraph(executionId).nodes.head
 
     val metricsNames = Seq(
       "number of written files",
@@ -111,17 +104,9 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
     assert(totalNumBytes > 0)
   }
 
-  protected def withPlannedWrite(f: => Unit): Unit = {
-    Seq(true, false).foreach { plannedWrite =>
-      withSQLConf(SQLConf.PLANNED_WRITE_ENABLED.key -> plannedWrite.toString) {
-        f
-      }
-    }
-  }
-
   protected def testMetricsNonDynamicPartition(
       dataFormat: String,
-      tableName: String): Unit = withPlannedWrite {
+      tableName: String): Unit = {
     withTable(tableName) {
       Seq((1, 2)).toDF("i", "j")
         .write.format(dataFormat).mode("overwrite").saveAsTable(tableName)
@@ -141,7 +126,7 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
   protected def testMetricsDynamicPartition(
       provider: String,
       dataFormat: String,
-      tableName: String): Unit = withPlannedWrite {
+      tableName: String): Unit = {
     withTable(tableName) {
       withTempPath { dir =>
         spark.sql(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala
index fed7f2cc7a3..bd6278473a7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala
@@ -30,8 +30,6 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.DDLUtils
-import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
-import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.hive.client.HiveClientImpl
 import org.apache.spark.sql.util.SchemaUtils
 
@@ -59,11 +57,6 @@ case class InsertIntoHiveDirCommand(
     overwrite: Boolean,
     outputColumnNames: Seq[String]) extends SaveAsHiveFile with V1WritesHiveUtils {
 
-  // We did not pull out `InsertIntoHiveDirCommand` to `V1WriteCommand`,
-  // so there is no `WriteFiles`. It should always hold all metrics by itself.
-  override lazy val metrics: Map[String, SQLMetric] =
-    BasicWriteJobStatsTracker.metrics
-
   override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
     assert(storage.locationUri.nonEmpty)
     SchemaUtils.checkColumnNameDuplication(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
index cb19f01a522..c5a84b930a9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.hive.execution
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
-import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, WriteFilesExec}
+import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, V1WriteCommand}
 import org.apache.spark.sql.execution.metric.SQLMetricsTestUtils
 import org.apache.spark.sql.hive.HiveUtils
 import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -43,22 +43,21 @@ class SQLMetricsSuite extends SQLMetricsTestUtils with TestHiveSingleton
   }
 
   test("SPARK-34567: Add metrics for CTAS operator") {
-    withPlannedWrite {
-      checkWriteMetrics()
-    }
-  }
-
-  private def checkWriteMetrics(): Unit = {
     Seq(false, true).foreach { canOptimized =>
       withSQLConf(HiveUtils.CONVERT_METASTORE_CTAS.key -> canOptimized.toString) {
         withTable("t") {
-          var dataWriting: DataWritingCommandExec = null
+          var v1WriteCommand: V1WriteCommand = null
           val listener = new QueryExecutionListener {
             override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {}
             override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
               qe.executedPlan match {
                 case dataWritingCommandExec: DataWritingCommandExec =>
-                  dataWriting = dataWritingCommandExec
+                  val createTableAsSelect = dataWritingCommandExec.cmd
+                  v1WriteCommand = if (canOptimized) {
+                    createTableAsSelect.asInstanceOf[InsertIntoHadoopFsRelationCommand]
+                  } else {
+                    createTableAsSelect.asInstanceOf[InsertIntoHiveTable]
+                  }
                 case _ =>
               }
             }
@@ -67,24 +66,13 @@ class SQLMetricsSuite extends SQLMetricsTestUtils with TestHiveSingleton
           try {
             sql(s"CREATE TABLE t STORED AS PARQUET AS SELECT 1 as a")
             sparkContext.listenerBus.waitUntilEmpty()
-            assert(dataWriting != null)
-            val v1WriteCommand = if (canOptimized) {
-              dataWriting.cmd.asInstanceOf[InsertIntoHadoopFsRelationCommand]
-            } else {
-              dataWriting.cmd.asInstanceOf[InsertIntoHiveTable]
-            }
-            val metrics = if (conf.plannedWriteEnabled) {
-              dataWriting.child.asInstanceOf[WriteFilesExec].metrics
-            } else {
-              v1WriteCommand.metrics
-            }
-
-            assert(metrics.contains("numFiles"))
-            assert(metrics("numFiles").value == 1)
-            assert(metrics.contains("numOutputBytes"))
-            assert(metrics("numOutputBytes").value > 0)
-            assert(metrics.contains("numOutputRows"))
-            assert(metrics("numOutputRows").value == 1)
+            assert(v1WriteCommand != null)
+            assert(v1WriteCommand.metrics.contains("numFiles"))
+            assert(v1WriteCommand.metrics("numFiles").value == 1)
+            assert(v1WriteCommand.metrics.contains("numOutputBytes"))
+            assert(v1WriteCommand.metrics("numOutputBytes").value > 0)
+            assert(v1WriteCommand.metrics.contains("numOutputRows"))
+            assert(v1WriteCommand.metrics("numOutputRows").value == 1)
           } finally {
             spark.listenerManager.unregister(listener)
           }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org