Posted to commits@spark.apache.org by we...@apache.org on 2018/05/08 15:43:08 UTC

spark git commit: [SPARK-24117][SQL] Unified the getSizePerRow

Repository: spark
Updated Branches:
  refs/heads/master 2f6fe7d67 -> 487faf17a


[SPARK-24117][SQL] Unified the getSizePerRow

## What changes were proposed in this pull request?

This PR unifies per-row size estimation behind a single `getSizePerRow` helper, because the same computation is currently duplicated across several call sites (a usage sketch follows the list below). For example:

1. [LocalRelation.scala#L80](https://github.com/wangyum/spark/blob/f70f46d1e5bc503e9071707d837df618b7696d32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala#L80)
2. [SizeInBytesOnlyStatsPlanVisitor.scala#L36](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala#L36)
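As a rough usage sketch (not part of the patch; the column names and row count below are only illustrative, and `spark-catalyst` must be on the classpath), callers now derive their `sizeInBytes` estimate from the shared helper instead of re-summing `defaultSize` by hand:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.types.{IntegerType, StringType}

object GetSizePerRowSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical two-column output: an Int key and a String value.
    val output = Seq(
      AttributeReference("key", IntegerType)(),
      AttributeReference("value", StringType)())

    // 8-byte row overhead + defaultSize of each column: 8 + 4 + 20 = 32.
    val sizePerRow = EstimationUtils.getSizePerRow(output)

    // A relation with 100 rows would then report 3200 bytes of sizeInBytes.
    println(sizePerRow * BigInt(100))
  }
}
```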

## How was this patch tested?
Existing tests.

Author: Yuming Wang <yu...@ebay.com>

Closes #21189 from wangyum/SPARK-24117.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/487faf17
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/487faf17
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/487faf17

Branch: refs/heads/master
Commit: 487faf17ab96c8edb729501dfb1ff82f7b2c6031
Parents: 2f6fe7d
Author: Yuming Wang <yu...@ebay.com>
Authored: Tue May 8 23:43:02 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue May 8 23:43:02 2018 +0800

----------------------------------------------------------------------
 .../sql/catalyst/plans/logical/LocalRelation.scala    |  3 ++-
 .../logical/statsEstimation/EstimationUtils.scala     | 14 ++++++++------
 .../SizeInBytesOnlyStatsPlanVisitor.scala             |  4 ++--
 .../apache/spark/sql/execution/streaming/memory.scala | 10 ++++------
 .../sql/execution/streaming/sources/memoryV2.scala    |  3 ++-
 .../apache/spark/sql/StatisticsCollectionSuite.scala  |  2 +-
 .../sql/execution/streaming/MemorySinkSuite.scala     |  4 ++--
 7 files changed, 21 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/487faf17/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index 720d42a..8c4828a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.types.{StructField, StructType}
 
 object LocalRelation {
@@ -77,7 +78,7 @@ case class LocalRelation(
   }
 
   override def computeStats(): Statistics =
-    Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)
+    Statistics(sizeInBytes = EstimationUtils.getSizePerRow(output) * data.length)
 
   def toSQL(inlineTableName: String): String = {
     require(data.nonEmpty)

http://git-wip-us.apache.org/repos/asf/spark/blob/487faf17/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index 0f147f0..211a2a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 
-import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.math.BigDecimal.RoundingMode
 
@@ -25,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.{DecimalType, _}
 
-
 object EstimationUtils {
 
   /** Check if each plan has rowCount in its statistics. */
@@ -73,13 +71,12 @@ object EstimationUtils {
     AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
   }
 
-  def getOutputSize(
+  def getSizePerRow(
       attributes: Seq[Attribute],
-      outputRowCount: BigInt,
       attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
     // We assign a generic overhead for a Row object, the actual overhead is different for different
     // Row format.
-    val sizePerRow = 8 + attributes.map { attr =>
+    8 + attributes.map { attr =>
       if (attrStats.get(attr).map(_.avgLen.isDefined).getOrElse(false)) {
         attr.dataType match {
           case StringType =>
@@ -92,10 +89,15 @@ object EstimationUtils {
         attr.dataType.defaultSize
       }
     }.sum
+  }
 
+  def getOutputSize(
+      attributes: Seq[Attribute],
+      outputRowCount: BigInt,
+      attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
     // Output size can't be zero, or sizeInBytes of BinaryNode will also be zero
     // (simple computation of statistics returns product of children).
-    if (outputRowCount > 0) outputRowCount * sizePerRow else 1
+    if (outputRowCount > 0) outputRowCount * getSizePerRow(attributes, attrStats) else 1
   }
 
   /**

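The net effect of this hunk is that `getOutputSize` becomes a thin wrapper: the per-row width comes from `getSizePerRow`, while the multiplication keeps the existing guard that an estimate is never zero (so a `BinaryNode`'s product-of-children statistic cannot collapse to zero). A hedged sketch of the resulting behaviour, with an illustrative single-column output:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.types.LongType

object GetOutputSizeSketch {
  def main(args: Array[String]): Unit = {
    val output = Seq(AttributeReference("id", LongType)())

    // Per row: 8-byte overhead + 8 bytes for the LongType column = 16 bytes.
    println(EstimationUtils.getOutputSize(output, outputRowCount = BigInt(10))) // 160
    // Zero output rows still report 1 byte, never 0.
    println(EstimationUtils.getOutputSize(output, outputRowCount = BigInt(0)))  // 1
  }
}
```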
http://git-wip-us.apache.org/repos/asf/spark/blob/487faf17/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
index 85f67c7..ee43f91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala
@@ -33,8 +33,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
   private def visitUnaryNode(p: UnaryNode): Statistics = {
     // There should be some overhead in Row object, the size should not be zero when there is
     // no columns, this help to prevent divide-by-zero error.
-    val childRowSize = p.child.output.map(_.dataType.defaultSize).sum + 8
-    val outputRowSize = p.output.map(_.dataType.defaultSize).sum + 8
+    val childRowSize = EstimationUtils.getSizePerRow(p.child.output)
+    val outputRowSize = EstimationUtils.getSizePerRow(p.output)
     // Assume there will be the same number of rows as child has.
     var sizeInBytes = (p.child.stats.sizeInBytes * outputRowSize) / childRowSize
     if (sizeInBytes == 0) {

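Numerically, `visitUnaryNode` rescales the child's `sizeInBytes` by the ratio of output row width to child row width, and both widths now include the 8-byte row overhead from `getSizePerRow`. A toy calculation with made-up numbers (not taken from the patch):

```scala
object UnaryNodeSizingSketch {
  def main(args: Array[String]): Unit = {
    val childSizeInBytes = BigInt(1000)    // whatever the child plan reports
    val childRowSize = BigInt(8 + 4 + 20)  // e.g. child output is (Int, String)
    val outputRowSize = BigInt(8 + 4)      // e.g. a Project keeping only the Int
    var sizeInBytes = (childSizeInBytes * outputRowSize) / childRowSize
    if (sizeInBytes == 0) sizeInBytes = 1  // the visitor never reports 0 bytes
    println(sizeInBytes)                   // 1000 * 12 / 32 = 375
  }
}
```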
http://git-wip-us.apache.org/repos/asf/spark/blob/487faf17/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 2225827..6720cdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -24,23 +24,21 @@ import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
-import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
 import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
-import org.apache.spark.sql.streaming.{OutputMode, Trigger}
+import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
-
 object MemoryStream {
   protected val currentBlockId = new AtomicInteger(0)
   protected val memoryStreamId = new AtomicInteger(0)
@@ -307,7 +305,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink
 case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode {
   def this(sink: MemorySink) = this(sink, sink.schema.toAttributes)
 
-  private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum
+  private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes)
 
   override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/487faf17/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
index 0d6c239..468313b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update}
 import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink}
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
@@ -182,7 +183,7 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode)
  * Used to query the data that has been written into a [[MemorySinkV2]].
  */
 case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode {
-  private val sizePerRow = output.map(_.dataType.defaultSize).sum
+  private val sizePerRow = EstimationUtils.getSizePerRow(output)
 
   override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/487faf17/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index b91712f..60fa951 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -50,7 +50,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
       }
 
       assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
-      assert(sizes.head === BigInt(96),
+      assert(sizes.head === BigInt(128),
         s"expected exact size 96 for table 'test', got: ${sizes.head}")
     }
   }

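The larger expected value follows directly from the new row overhead: assuming the 'test' table here holds 4 rows of an Int plus a String column, the old estimate was 4 * (4 + 20) = 96 bytes, while `getSizePerRow` adds 8 bytes per row, giving 4 * (8 + 4 + 20) = 128 bytes.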
http://git-wip-us.apache.org/repos/asf/spark/blob/487faf17/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
index e8420ee..3bc36ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
@@ -220,11 +220,11 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
     sink.addBatch(0, 1 to 3)
     plan.invalidateStatsCache()
-    assert(plan.stats.sizeInBytes === 12)
+    assert(plan.stats.sizeInBytes === 36)
 
     sink.addBatch(1, 4 to 6)
     plan.invalidateStatsCache()
-    assert(plan.stats.sizeInBytes === 24)
+    assert(plan.stats.sizeInBytes === 72)
   }
 
   ignore("stress test") {

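The memory sink constants change for the same reason. Judging from the old expectation of 3 * 4 = 12, the sink's schema here is a single Int column, so each row is now estimated at 8 + 4 = 12 bytes instead of 4. A quick check of the new values:

```scala
object MemorySinkSizesSketch {
  def main(args: Array[String]): Unit = {
    val sizePerRow = 8 + 4   // row overhead + IntegerType.defaultSize
    println(3 * sizePerRow)  // first batch of 3 rows: 36 (was 12)
    println(6 * sizePerRow)  // after the second batch of 3 rows: 72 (was 24)
  }
}
```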
