You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/01/19 00:10:07 UTC

spark git commit: [SPARK-12882][SQL] simplify bucket tests and add more comments

Repository: spark
Updated Branches:
  refs/heads/master 4f11e3f2a -> 404190221


[SPARK-12882][SQL] simplify bucket tests and add more comments

Right now, the bucket tests are kind of hard to understand; this PR simplifies them and adds more comments.

Author: Wenchen Fan <we...@databricks.com>

Closes #10813 from cloud-fan/bucket-comment.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/40419022
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/40419022
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/40419022

Branch: refs/heads/master
Commit: 404190221a788ebc3a0cbf5cb47cf532436ce965
Parents: 4f11e3f
Author: Wenchen Fan <we...@databricks.com>
Authored: Mon Jan 18 15:10:04 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Mon Jan 18 15:10:04 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/sources/BucketedReadSuite.scala   | 56 ++++++++++------
 .../spark/sql/sources/BucketedWriteSuite.scala  | 68 ++++++++++++--------
 2 files changed, 78 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/40419022/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 58ecdd3..150d0c7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLC
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.Exchange
+import org.apache.spark.sql.execution.datasources.BucketSpec
 import org.apache.spark.sql.execution.joins.SortMergeJoin
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -61,15 +62,30 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
   private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
   private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")
 
+  /**
+   * A helper method to test the bucket read functionality using join.  It will save `df1` and `df2`
+   * to hive tables, bucketed or not, according to the given bucket specifics.  Next we will join
+   * these 2 tables, first make sure the answer is correct, and then check whether the shuffle
+   * exists as the user expected according to `shuffleLeft` and `shuffleRight`.
+   */
   private def testBucketing(
-      bucketing1: DataFrameWriter => DataFrameWriter,
-      bucketing2: DataFrameWriter => DataFrameWriter,
+      bucketSpecLeft: Option[BucketSpec],
+      bucketSpecRight: Option[BucketSpec],
       joinColumns: Seq[String],
       shuffleLeft: Boolean,
       shuffleRight: Boolean): Unit = {
     withTable("bucketed_table1", "bucketed_table2") {
-      bucketing1(df1.write.format("parquet")).saveAsTable("bucketed_table1")
-      bucketing2(df2.write.format("parquet")).saveAsTable("bucketed_table2")
+      def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): DataFrameWriter = {
+        bucketSpec.map { spec =>
+          writer.bucketBy(
+            spec.numBuckets,
+            spec.bucketColumnNames.head,
+            spec.bucketColumnNames.tail: _*)
+        }.getOrElse(writer)
+      }
+
+      withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1")
+      withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2")
 
       withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
         val t1 = hiveContext.table("bucketed_table1")
@@ -95,42 +111,42 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
   }
 
   test("avoid shuffle when join 2 bucketed tables") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-    testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+    val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+    testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
   }
 
   // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
   ignore("avoid shuffle when join keys are a super-set of bucket keys") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-    testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+    val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
+    testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
   }
 
   test("only shuffle one side when join bucketed table and non-bucketed table") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-    testBucketing(bucketing, identity, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+    val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+    testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
   }
 
   test("only shuffle one side when 2 bucketed tables have different bucket number") {
-    val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-    val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(5, "i", "j")
-    testBucketing(bucketing1, bucketing2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+    val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil))
+    val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil))
+    testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
   }
 
   test("only shuffle one side when 2 bucketed tables have different bucket keys") {
-    val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-    val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(8, "j")
-    testBucketing(bucketing1, bucketing2, Seq("i"), shuffleLeft = false, shuffleRight = true)
+    val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil))
+    val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil))
+    testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true)
   }
 
   test("shuffle when join keys are not equal to bucket keys") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-    testBucketing(bucketing, bucketing, Seq("j"), shuffleLeft = true, shuffleRight = true)
+    val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
+    testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true)
   }
 
   test("shuffle when join 2 bucketed tables with bucketing disabled") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+    val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
     withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
-      testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
+      testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/40419022/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index e812439..dad1fc1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -65,39 +65,55 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
 
   private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
 
+  /**
+   * A helper method to check the bucket write functionality at a low level, i.e. check the written
+   * bucket files to see if the data are correct.  The user should pass in the data dir that these
+   * bucket files are written to, the data format (parquet, json, etc.), and the bucketing
+   * information.
+   */
   private def testBucketing(
       dataDir: File,
       source: String,
+      numBuckets: Int,
       bucketCols: Seq[String],
       sortCols: Seq[String] = Nil): Unit = {
     val allBucketFiles = dataDir.listFiles().filterNot(f =>
       f.getName.startsWith(".") || f.getName.startsWith("_")
     )
-    val groupedBucketFiles = allBucketFiles.groupBy(f => BucketingUtils.getBucketId(f.getName).get)
-    assert(groupedBucketFiles.size <= 8)
-
-    for ((bucketId, bucketFiles) <- groupedBucketFiles) {
-      for (bucketFilePath <- bucketFiles.map(_.getAbsolutePath)) {
-        val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
-        val columns = (bucketCols ++ sortCols).zip(types).map {
-          case (colName, dt) => col(colName).cast(dt)
-        }
-        val readBack = sqlContext.read.format(source).load(bucketFilePath).select(columns: _*)
 
-        if (sortCols.nonEmpty) {
-          checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
-        }
+    for (bucketFile <- allBucketFiles) {
+      val bucketId = BucketingUtils.getBucketId(bucketFile.getName).get
+      assert(bucketId >= 0 && bucketId < numBuckets)
 
-        val qe = readBack.select(bucketCols.map(col): _*).queryExecution
-        val rows = qe.toRdd.map(_.copy()).collect()
-        val getBucketId = UnsafeProjection.create(
-          HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil,
-          qe.analyzed.output)
+      // We may lose the type information after writing (e.g. json format doesn't keep schema
+      // information), so here we get the types from the original dataframe.
+      val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
+      val columns = (bucketCols ++ sortCols).zip(types).map {
+        case (colName, dt) => col(colName).cast(dt)
+      }
 
-        for (row <- rows) {
-          val actualBucketId = getBucketId(row).getInt(0)
-          assert(actualBucketId == bucketId)
-        }
+      // Read the bucket file into a dataframe, so that it's easier to test.
+      val readBack = sqlContext.read.format(source)
+        .load(bucketFile.getAbsolutePath)
+        .select(columns: _*)
+
+      // If we specified sort columns while writing bucket table, make sure the data in this
+      // bucket file is already sorted.
+      if (sortCols.nonEmpty) {
+        checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
+      }
+
+      // Go through all rows in this bucket file, calculate bucket id according to bucket column
+      // values, and make sure it equals the expected bucket id inferred from the file name.
+      val qe = readBack.select(bucketCols.map(col): _*).queryExecution
+      val rows = qe.toRdd.map(_.copy()).collect()
+      val getBucketId = UnsafeProjection.create(
+        HashPartitioning(qe.analyzed.output, numBuckets).partitionIdExpression :: Nil,
+        qe.analyzed.output)
+
+      for (row <- rows) {
+        val actualBucketId = getBucketId(row).getInt(0)
+        assert(actualBucketId == bucketId)
       }
     }
   }
@@ -113,7 +129,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
 
         val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
         for (i <- 0 until 5) {
-          testBucketing(new File(tableDir, s"i=$i"), source, Seq("j", "k"))
+          testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k"))
         }
       }
     }
@@ -131,7 +147,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
 
         val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
         for (i <- 0 until 5) {
-          testBucketing(new File(tableDir, s"i=$i"), source, Seq("j"), Seq("k"))
+          testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j"), Seq("k"))
         }
       }
     }
@@ -146,7 +162,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
           .saveAsTable("bucketed_table")
 
         val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
-        testBucketing(tableDir, source, Seq("i", "j"))
+        testBucketing(tableDir, source, 8, Seq("i", "j"))
       }
     }
   }
@@ -161,7 +177,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
           .saveAsTable("bucketed_table")
 
         val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
-        testBucketing(tableDir, source, Seq("i", "j"), Seq("k"))
+        testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k"))
       }
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org