Posted to commits@spark.apache.org by ir...@apache.org on 2018/09/25 16:57:50 UTC

[3/3] spark git commit: [PYSPARK][SQL] Updates to RowQueue

[PYSPARK][SQL] Updates to RowQueue

DiskRowQueue and HybridRowQueue now take a SerializerManager and wrap
their spill-file streams with wrapForEncryption, so rows buffered to
local disk are encrypted whenever spark.io.encryption.enabled is set.
A new HybridRowQueue companion apply keeps the previous three-argument
signature working by resolving the SerializerManager from SparkEnv.

Tested with updates to RowQueueSuite, which now runs the disk-queue and
hybrid-queue tests both with and without IO encryption.

(cherry picked from commit 6d742d1bd71aa3803dce91a830b37284cb18cf70)
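
A minimal sketch of the stream-wrapping pattern the commit applies, in
isolation, assuming a running SparkEnv (the temp file is a placeholder,
not a path from this commit):

  import java.io._
  import org.apache.spark.SparkEnv
  import org.apache.spark.serializer.SerializerManager

  // Sketch only: wrap both sides of a spill file so the bytes on disk
  // are encrypted when spark.io.encryption.enabled is set.
  val serMgr: SerializerManager = SparkEnv.get.serializerManager
  val file = File.createTempFile("buffer", "")

  val out = new DataOutputStream(serMgr.wrapForEncryption(
    new BufferedOutputStream(new FileOutputStream(file))))
  out.writeLong(42L)            // written through the encrypting stream
  out.close()

  val in = new DataInputStream(serMgr.wrapForEncryption(
    new BufferedInputStream(new FileInputStream(file))))
  assert(in.readLong() == 42L)  // decrypted transparently on read
  in.close()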


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4f10aff4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4f10aff4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4f10aff4

Branch: refs/heads/branch-2.2
Commit: 4f10aff403ccc8287a816cb94ddf7f11e185907a
Parents: dd0e7cf
Author: Imran Rashid <ir...@cloudera.com>
Authored: Thu Sep 6 12:11:47 2018 -0500
Committer: Imran Rashid <ir...@cloudera.com>
Committed: Tue Sep 25 11:46:06 2018 -0500

----------------------------------------------------------------------
 .../spark/sql/execution/python/RowQueue.scala   | 27 ++++++++++++++-----
 .../sql/execution/python/RowQueueSuite.scala    | 28 +++++++++++++++-----
 2 files changed, 41 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4f10aff4/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
index cd1e77f..4d6820c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
@@ -21,9 +21,10 @@ import java.io._
 
 import com.google.common.io.Closeables
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.io.NioBufferedFileInputStream
 import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager}
+import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.memory.MemoryBlock
@@ -108,9 +109,13 @@ private[python] abstract class InMemoryRowQueue(val page: MemoryBlock, numFields
  * A RowQueue that is backed by a file on disk. This queue will stop accepting new rows once any
  * reader has begun reading from the queue.
  */
-private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueue {
-  private var out = new DataOutputStream(
-    new BufferedOutputStream(new FileOutputStream(file.toString)))
+private[python] case class DiskRowQueue(
+    file: File,
+    fields: Int,
+    serMgr: SerializerManager) extends RowQueue {
+
+  private var out = new DataOutputStream(serMgr.wrapForEncryption(
+    new BufferedOutputStream(new FileOutputStream(file.toString))))
   private var unreadBytes = 0L
 
   private var in: DataInputStream = _
@@ -131,7 +136,8 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu
     if (out != null) {
       out.close()
       out = null
-      in = new DataInputStream(new NioBufferedFileInputStream(file))
+      in = new DataInputStream(serMgr.wrapForEncryption(
+        new NioBufferedFileInputStream(file)))
     }
 
     if (unreadBytes > 0) {
@@ -166,7 +172,8 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu
 private[python] case class HybridRowQueue(
     memManager: TaskMemoryManager,
     tempDir: File,
-    numFields: Int)
+    numFields: Int,
+    serMgr: SerializerManager)
   extends MemoryConsumer(memManager) with RowQueue {
 
   // Each buffer should have at least one row
@@ -212,7 +219,7 @@ private[python] case class HybridRowQueue(
   }
 
   private def createDiskQueue(): RowQueue = {
-    DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields)
+    DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields, serMgr)
   }
 
   private def createNewQueue(required: Long): RowQueue = {
@@ -279,3 +286,9 @@ private[python] case class HybridRowQueue(
     }
   }
 }
+
+private[python] object HybridRowQueue {
+  def apply(taskMemoryMgr: TaskMemoryManager, file: File, fields: Int): HybridRowQueue = {
+    HybridRowQueue(taskMemoryMgr, file, fields, SparkEnv.get.serializerManager)
+  }
+}
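
The companion apply above keeps the previous three-argument signature
working by resolving the SerializerManager from the running SparkEnv, so
existing call sites need no change. A hypothetical call site (taskMemoryMgr,
tempDir, numFields and row are assumed to be in scope):

  // Hypothetical call site: the 3-arg apply pulls the
  // SerializerManager from SparkEnv.get.serializerManager.
  val queue = HybridRowQueue(taskMemoryMgr, tempDir, numFields)
  queue.add(row)                 // spills through encrypting streams
  val restored = queue.remove()  // once the in-memory page is full
  queue.close()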

http://git-wip-us.apache.org/repos/asf/spark/blob/4f10aff4/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
index ffda33c..1ec9986 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
@@ -20,12 +20,15 @@ package org.apache.spark.sql.execution.python
 import java.io.File
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.internal.config._
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite}
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.unsafe.memory.MemoryBlock
 import org.apache.spark.util.Utils
 
-class RowQueueSuite extends SparkFunSuite {
+class RowQueueSuite extends SparkFunSuite with EncryptionFunSuite {
 
   test("in-memory queue") {
     val page = MemoryBlock.fromLongArray(new Array[Long](1<<10))
@@ -53,10 +56,20 @@ class RowQueueSuite extends SparkFunSuite {
     queue.close()
   }
 
-  test("disk queue") {
+  private def createSerializerManager(conf: SparkConf): SerializerManager = {
+    val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) {
+      Some(CryptoStreamUtils.createKey(conf))
+    } else {
+      None
+    }
+    new SerializerManager(new JavaSerializer(conf), conf, ioEncryptionKey)
+  }
+
+  encryptionTest("disk queue") { conf =>
+    val serManager = createSerializerManager(conf)
     val dir = Utils.createTempDir().getCanonicalFile
     dir.mkdirs()
-    val queue = DiskRowQueue(new File(dir, "buffer"), 1)
+    val queue = DiskRowQueue(new File(dir, "buffer"), 1, serManager)
     val row = new UnsafeRow(1)
     row.pointTo(new Array[Byte](16), 16)
     val n = 1000
@@ -81,11 +94,12 @@ class RowQueueSuite extends SparkFunSuite {
     queue.close()
   }
 
-  test("hybrid queue") {
-    val mem = new TestMemoryManager(new SparkConf())
+  encryptionTest("hybrid queue") { conf =>
+    val serManager = createSerializerManager(conf)
+    val mem = new TestMemoryManager(conf)
     mem.limit(4<<10)
     val taskM = new TaskMemoryManager(mem, 0)
-    val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1)
+    val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1, serManager)
     val row = new UnsafeRow(1)
     row.pointTo(new Array[Byte](16), 16)
     val n = (4<<10) / 16 * 3
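
A note on the harness: encryptionTest comes from Spark's EncryptionFunSuite
trait and registers the same body twice, once with IO encryption disabled
and once enabled, handing the prepared SparkConf to the body. A paraphrased
sketch of that helper's shape (not part of this diff):

  // Paraphrased sketch: register the test body under both encryption
  // settings so the queues run with plain and with encrypted spills.
  protected def encryptionTest(name: String)(fn: SparkConf => Unit): Unit = {
    Seq(false, true).foreach { encrypt =>
      test(s"$name (encryption = $encrypt)") {
        fn(new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt))
      }
    }
  }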

