Posted to commits@spark.apache.org by ir...@apache.org on 2018/09/13 14:20:41 UTC
[3/4] spark git commit: [PYSPARK][SQL] Updates to RowQueue
[PYSPARK][SQL] Updates to RowQueue
This change wraps DiskRowQueue's on-disk output and input streams with SerializerManager.wrapForEncryption, so rows that HybridRowQueue spills to disk are encrypted whenever IO encryption is enabled. A new HybridRowQueue companion object keeps existing call sites working by defaulting to SparkEnv.get.serializerManager.

Tested with updates to RowQueueSuite, which now runs the disk and hybrid queue tests both with and without IO encryption.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6d742d1b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6d742d1b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6d742d1b
Branch: refs/heads/branch-2.3
Commit: 6d742d1bd71aa3803dce91a830b37284cb18cf70
Parents: 09dd34c
Author: Imran Rashid <ir...@cloudera.com>
Authored: Thu Sep 6 12:11:47 2018 -0500
Committer: Imran Rashid <ir...@cloudera.com>
Committed: Thu Sep 13 09:19:56 2018 -0500
----------------------------------------------------------------------
.../spark/sql/execution/python/RowQueue.scala | 27 ++++++++++++++-----
.../sql/execution/python/RowQueueSuite.scala | 28 +++++++++++++++-----
2 files changed, 41 insertions(+), 14 deletions(-)
----------------------------------------------------------------------
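For context before the diff: the heart of the change is routing the queue's spill file through SerializerManager. wrapForEncryption is a pass-through when spark.io.encryption.enabled is off, and layers the configured crypto streams on top of the file streams when it is on. A minimal sketch of the round trip, assuming an active Spark application (SparkEnv.get is null otherwise) and using a plain FileInputStream where the real queue uses Spark's NioBufferedFileInputStream:

import java.io._
import org.apache.spark.SparkEnv

// Obtain the SerializerManager the same way the new HybridRowQueue
// companion object below does.
val serMgr = SparkEnv.get.serializerManager
val file = File.createTempFile("buffer", "")

// Write side: the encryption wrapper sits between the DataOutputStream
// and the buffered file stream, exactly as in DiskRowQueue.
val out = new DataOutputStream(serMgr.wrapForEncryption(
  new BufferedOutputStream(new FileOutputStream(file))))
out.writeLong(42L)
out.close()

// Read side: the input stream must be wrapped symmetrically, or the
// reader would see ciphertext when encryption is enabled.
val in = new DataInputStream(serMgr.wrapForEncryption(
  new FileInputStream(file)))
assert(in.readLong() == 42L)
in.close()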
http://git-wip-us.apache.org/repos/asf/spark/blob/6d742d1b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
index e2fa6e7..d2820ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
@@ -21,9 +21,10 @@ import java.io._
import com.google.common.io.Closeables
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.io.NioBufferedFileInputStream
import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager}
+import org.apache.spark.serializer.SerializerManager
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.memory.MemoryBlock
@@ -108,9 +109,13 @@ private[python] abstract class InMemoryRowQueue(val page: MemoryBlock, numFields
* A RowQueue that is backed by a file on disk. This queue will stop accepting new rows once any
* reader has begun reading from the queue.
*/
-private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueue {
- private var out = new DataOutputStream(
- new BufferedOutputStream(new FileOutputStream(file.toString)))
+private[python] case class DiskRowQueue(
+ file: File,
+ fields: Int,
+ serMgr: SerializerManager) extends RowQueue {
+
+ private var out = new DataOutputStream(serMgr.wrapForEncryption(
+ new BufferedOutputStream(new FileOutputStream(file.toString))))
private var unreadBytes = 0L
private var in: DataInputStream = _
@@ -131,7 +136,8 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu
if (out != null) {
out.close()
out = null
- in = new DataInputStream(new NioBufferedFileInputStream(file))
+ in = new DataInputStream(serMgr.wrapForEncryption(
+ new NioBufferedFileInputStream(file)))
}
if (unreadBytes > 0) {
@@ -166,7 +172,8 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu
private[python] case class HybridRowQueue(
memManager: TaskMemoryManager,
tempDir: File,
- numFields: Int)
+ numFields: Int,
+ serMgr: SerializerManager)
extends MemoryConsumer(memManager) with RowQueue {
// Each buffer should have at least one row
@@ -212,7 +219,7 @@ private[python] case class HybridRowQueue(
}
private def createDiskQueue(): RowQueue = {
- DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields)
+ DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields, serMgr)
}
private def createNewQueue(required: Long): RowQueue = {
@@ -279,3 +286,9 @@ private[python] case class HybridRowQueue(
}
}
}
+
+private[python] object HybridRowQueue {
+ def apply(taskMemoryMgr: TaskMemoryManager, file: File, fields: Int): HybridRowQueue = {
+ HybridRowQueue(taskMemoryMgr, file, fields, SparkEnv.get.serializerManager)
+ }
+}
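The companion object above keeps existing call sites source-compatible: production code still calls HybridRowQueue(taskMemoryMgr, file, fields) and transparently picks up SparkEnv.get.serializerManager, while the four-argument constructor lets tests inject a manager of their own. A minimal sketch of building such a manager by hand, mirroring the createSerializerManager helper added to the test suite below:

import org.apache.spark.SparkConf
import org.apache.spark.internal.config._
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}

val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, true)

// When IO encryption is on, the manager needs an explicit key;
// CryptoStreamUtils.createKey derives one from the configured
// key-generation algorithm and key length.
val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) {
  Some(CryptoStreamUtils.createKey(conf))
} else {
  None
}
val serMgr = new SerializerManager(new JavaSerializer(conf), conf, ioEncryptionKey)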
http://git-wip-us.apache.org/repos/asf/spark/blob/6d742d1b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
index ffda33c..1ec9986 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
@@ -20,12 +20,15 @@ package org.apache.spark.sql.execution.python
import java.io.File
import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.internal.config._
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite}
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.unsafe.memory.MemoryBlock
import org.apache.spark.util.Utils
-class RowQueueSuite extends SparkFunSuite {
+class RowQueueSuite extends SparkFunSuite with EncryptionFunSuite {
test("in-memory queue") {
val page = MemoryBlock.fromLongArray(new Array[Long](1<<10))
@@ -53,10 +56,20 @@ class RowQueueSuite extends SparkFunSuite {
queue.close()
}
- test("disk queue") {
+ private def createSerializerManager(conf: SparkConf): SerializerManager = {
+ val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) {
+ Some(CryptoStreamUtils.createKey(conf))
+ } else {
+ None
+ }
+ new SerializerManager(new JavaSerializer(conf), conf, ioEncryptionKey)
+ }
+
+ encryptionTest("disk queue") { conf =>
+ val serManager = createSerializerManager(conf)
val dir = Utils.createTempDir().getCanonicalFile
dir.mkdirs()
- val queue = DiskRowQueue(new File(dir, "buffer"), 1)
+ val queue = DiskRowQueue(new File(dir, "buffer"), 1, serManager)
val row = new UnsafeRow(1)
row.pointTo(new Array[Byte](16), 16)
val n = 1000
@@ -81,11 +94,12 @@ class RowQueueSuite extends SparkFunSuite {
queue.close()
}
- test("hybrid queue") {
- val mem = new TestMemoryManager(new SparkConf())
+ encryptionTest("hybrid queue") { conf =>
+ val serManager = createSerializerManager(conf)
+ val mem = new TestMemoryManager(conf)
mem.limit(4<<10)
val taskM = new TaskMemoryManager(mem, 0)
- val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1)
+ val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1, serManager)
val row = new UnsafeRow(1)
row.pointTo(new Array[Byte](16), 16)
val n = (4<<10) / 16 * 3
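The switch from test(...) to encryptionTest(...) comes with the EncryptionFunSuite trait mixed in above: it runs the test body twice, once with IO encryption disabled and once enabled, handing the prepared SparkConf to the body. A rough sketch of that shape inside a SparkFunSuite (the generated test names here are an assumption, not the trait's exact output):

import org.apache.spark.SparkConf
import org.apache.spark.internal.config._

Seq(false, true).foreach { encrypt =>
  test(s"disk queue (encryption = $encrypt)") {
    val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt)
    // The encryptionTest body receives this conf and builds its
    // SerializerManager from it via createSerializerManager(conf).
  }
}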