You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by tdas <gi...@git.apache.org> on 2018/04/05 22:37:11 UTC

[GitHub] spark pull request #20828: [SPARK-23687][SS] Add a memory source for continu...

Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20828#discussion_r179618466
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousMemoryStream.scala ---
    @@ -0,0 +1,212 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.streaming.continuous
    +
    +import java.{util => ju}
    +import java.util.Optional
    +import java.util.concurrent.ArrayBlockingQueue
    +import javax.annotation.concurrent.GuardedBy
    +
    +import scala.collection.JavaConverters._
    +import scala.collection.mutable.ListBuffer
    +import scala.reflect.ClassTag
    +
    +import org.json4s.NoTypeHints
    +import org.json4s.jackson.Serialization
    +
    +import org.apache.spark.SparkEnv
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
    +import org.apache.spark.sql.{Dataset, Encoder, Row, SQLContext}
    +import org.apache.spark.sql.catalyst.encoders.encoderFor
    +import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    +import org.apache.spark.sql.execution.streaming._
    +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions}
    +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
    +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
    +import org.apache.spark.sql.types.StructType
    +import org.apache.spark.util.RpcUtils
    +
    +/**
    + * The overall strategy here is:
    + *  * ContinuousMemoryStream maintains a list of records for each partition. addData() will
    + *    distribute records evenly-ish across partitions.
    + *  * ContinuousMemoryStreamRecordBuffer is set up as an endpoint for partition-level
    + *    ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified
    + *    offset within the list, or None if that offset doesn't yet have a record.
    + */
    +
    +/** Request sent to the record buffer endpoint asking for the record stored at `offset`. */
    +private case class GetRecord(offset: ContinuousMemoryStreamPartitionOffset)
    +
    +/**
    + * RPC endpoint that serves individual records out of the per-partition buffers.
    + *
    + * @param stream the owning stream; its monitor is held while a buffer is read
    + * @param partitionBuffers one record list per partition, indexed by partition number
    + */
    +private class ContinuousMemoryStreamRecordBuffer[A](
    +    stream: ContinuousMemoryStream[A],
    +    partitionBuffers: Seq[ListBuffer[A]]) extends ThreadSafeRpcEndpoint {
    +  override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv
    +
    +  /**
    +   * Answers a [[GetRecord]] request with `Some(Row(record))` when the requested position is
    +   * already populated in the partition's buffer, and with `None` otherwise.
    +   */
    +  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    +    case GetRecord(ContinuousMemoryStreamPartitionOffset(partition, position)) =>
    +      // Read under the stream's monitor, the same lock addData() takes when appending.
    +      stream.synchronized {
    +        val partitionRecords = partitionBuffers(partition)
    +        val reply =
    +          if (position < partitionRecords.size) Some(Row(partitionRecords(position))) else None
    +        context.reply(reply)
    +      }
    +  }
    +}
    +
    +class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
    +    extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport {
    +  private implicit val formats = Serialization.formats(NoTypeHints)
    +  val NUM_PARTITIONS = 2
    +
    +  // ContinuousReader implementation
    +
    +  @GuardedBy("this")
    +  private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A])
    +
    +  private val recordBuffer = new ContinuousMemoryStreamRecordBuffer(this, records)
    +
    +  /**
    +   * Appends `data` to the partition buffers in round-robin order. The rotation restarts at
    +   * partition 0 on every call, so the distribution is only approximately even.
    +   *
    +   * @param data the records to append
    +   * @return an offset mapping every partition to the current size of its buffer, i.e. the
    +   *         position at which all records added so far have been processed
    +   */
    +  def addData(data: TraversableOnce[A]): Offset = synchronized {
    +    // Iterate strictly with foreach. The appends are side effects, and mapping over the result
    +    // of toSeq can be lazy (an Iterator's toSeq is a Stream), which could leave records unadded.
    +    data.toIterator.zipWithIndex.foreach {
    +      case (item, index) => records(index % NUM_PARTITIONS) += item
    +    }
    +
    +    // The new target offset is the offset where all records in all partitions have been processed.
    +    ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap)
    +  }
    +
    +  private var startOffset: ContinuousMemoryStreamOffset = _
    +
    +  /**
    +   * Stores the offset to start reading from.
    +   *
    +   * @param start the requested start offset; when absent, position 0 of every partition is used
    +   */
    +  override def setStartOffset(start: Optional[Offset]): Unit = synchronized {
    +    // Fallback used when no explicit start offset is supplied.
    +    val zeroPositions = (0 until NUM_PARTITIONS).map(partition => partition -> 0).toMap
    +    val inferredStart: Offset = ContinuousMemoryStreamOffset(zeroPositions)
    +    startOffset = start.orElse(inferredStart).asInstanceOf[ContinuousMemoryStreamOffset]
    +  }
    +
    +  /** Returns the offset last stored by setStartOffset (unset until that has been called). */
    +  override def getStartOffset: Offset = startOffset
    +
    +  /**
    +   * Rebuilds an offset from its JSON form.
    +   *
    +   * @param json a JSON object read as a `Map[Int, Int]`
    +   * @return the offset wrapping the parsed map
    +   */
    +  override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = {
    +    val parsedPositions = Serialization.read[Map[Int, Int]](json)
    +    ContinuousMemoryStreamOffset(parsedPositions)
    +  }
    +
    +  /**
    +   * Combines per-partition offsets into a single offset for the whole stream.
    +   *
    +   * @param offsets partition offsets; each must be a ContinuousMemoryStreamPartitionOffset,
    +   *                anything else fails the pattern match
    +   * @return an offset keyed by the first field of each partition offset
    +   */
    +  override def mergeOffsets(offsets: Array[PartitionOffset]): ContinuousMemoryStreamOffset = {
    +    val positionsByPartition = offsets.map {
    +      case ContinuousMemoryStreamPartitionOffset(partition, position) => partition -> position
    +    }.toMap
    +    ContinuousMemoryStreamOffset(positionsByPartition)
    +  }
    +
    +  var endpointRef: RpcEndpointRef = _
    +  /**
    +   * Registers the record buffer as an RPC endpoint named after this stream's id and builds one
    +   * reader factory per entry of the start offset.
    +   *
    +   * NOTE(review): reads `startOffset`, so this presumably has to run after setStartOffset.
    +   */
    +  override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = synchronized {
    +    endpointRef =
    +      recordBuffer.rpcEnv.setupEndpoint(ContinuousMemoryStream.recordBufferName(id), recordBuffer)
    +
    +    val factories = startOffset.partitionNums.toList.map {
    +      case (part, index) =>
    +        new ContinuousMemoryStreamDataReaderFactory(id, part, index): DataReaderFactory[Row]
    +    }
    +    factories.asJava
    +  }
    +
    +  /** Stops the record buffer endpoint if one has been registered; otherwise does nothing. */
    +  override def stop(): Unit = {
    +    Option(endpointRef).foreach(ref => recordBuffer.rpcEnv.stop(ref))
    +  }
    +
    +  /** No-op: nothing is done when an offset is committed. */
    +  override def commit(end: Offset): Unit = {}
    +
    +  // ContinuousReadSupport implementation
    +  // This is necessary because of how StreamTest finds the source for AddDataMemory steps.
    +
    +  /**
    +   * Returns this stream itself as the reader; all three arguments are ignored.
    +   */
    +  def createContinuousReader(
    +      schema: Optional[StructType],
    +      checkpointLocation: String,
    +      options: DataSourceOptions): ContinuousReader = this
    +
    +  // MemoryStreamBase implementation
    +
    +  protected val logicalPlan =
    --- End diff --
    
    nit: same comment as above, consolidate vals, unless there is a really good reason.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org