Posted to commits@doris.apache.org by di...@apache.org on 2023/11/07 03:13:08 UTC

(doris-spark-connector) branch master updated: [improvement] support two-phase commit in structured streaming (#156)

This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new df4f107  [improvement] support two-phase commit in structured streaming (#156)
df4f107 is described below

commit df4f107f6cf9ee4f75bd33c74aad7e43dde5058f
Author: gnehil <ad...@gmail.com>
AuthorDate: Tue Nov 7 11:13:04 2023 +0800

    [improvement] support two-phase commit in structured streaming (#156)
---
 .../spark/listener/DorisTransactionListener.scala  |  83 -----------------
 .../doris/spark/sql/DorisSourceProvider.scala      |   4 +-
 .../doris/spark/sql/DorisStreamLoadSink.scala      |  11 ++-
 .../doris/spark/txn/TransactionHandler.scala       | 100 +++++++++++++++++++++
 .../txn/listener/DorisTransactionListener.scala    |  66 ++++++++++++++
 .../listener/DorisTxnStreamingQueryListener.scala  |  69 ++++++++++++++
 .../apache/doris/spark/writer/DorisWriter.scala    |  42 ++++-----
 7 files changed, 262 insertions(+), 113 deletions(-)
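
In short, this change extends two-phase commit (2PC) to structured streaming:
each micro-batch is stream-loaded as a pre-committed transaction whose id is
recorded in a Spark CollectionAccumulator, and a StreamingQueryListener then
commits those transactions when the batch's progress event fires, or aborts
them if the query terminates with an exception. A hedged usage sketch follows;
only doris.sink.enable-2pc appears in this diff, and the remaining option keys
and the "doris" format name are assumed connector conventions:

    // Hypothetical usage: enable 2PC on a structured streaming write.
    // Only "doris.sink.enable-2pc" is taken from this commit; the other
    // options and the "doris" format name are assumptions.
    val query = df.writeStream
      .format("doris")
      .option("doris.fenodes", "fe_host:8030")
      .option("doris.table.identifier", "db.tbl")
      .option("user", "root")
      .option("password", "")
      .option("doris.sink.enable-2pc", "true")
      .option("checkpointLocation", "/tmp/ckpt")
      .start()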

diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
deleted file mode 100644
index b1e9d84..0000000
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
+++ /dev/null
@@ -1,83 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.spark.listener
-
-import org.apache.doris.spark.load.DorisStreamLoad
-import org.apache.doris.spark.sql.Utils
-import org.apache.spark.scheduler._
-import org.apache.spark.util.CollectionAccumulator
-import org.slf4j.{Logger, LoggerFactory}
-
-import java.time.Duration
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-import scala.util.{Failure, Success}
-
-class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Long], dorisStreamLoad: DorisStreamLoad, sinkTnxIntervalMs: Int, sinkTxnRetries: Int)
-  extends SparkListener {
-
-  val logger: Logger = LoggerFactory.getLogger(classOf[DorisTransactionListener])
-
-  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
-    val txnIds: mutable.Buffer[Long] = preCommittedTxnAcc.value.asScala
-    val failedTxnIds = mutable.Buffer[Long]()
-    jobEnd.jobResult match {
-      // if job succeed, commit all transactions
-      case JobSucceeded =>
-        if (txnIds.isEmpty) {
-          logger.warn("job run succeed, but there is no pre-committed txn ids")
-          return
-        }
-        logger.info("job run succeed, start committing transactions")
-        txnIds.foreach(txnId =>
-          Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTnxIntervalMs), logger) {
-            dorisStreamLoad.commit(txnId)
-          } () match {
-            case Success(_) => // do nothing
-            case Failure(_) => failedTxnIds += txnId
-          }
-        )
-
-        if (failedTxnIds.nonEmpty) {
-          logger.error("uncommitted txn ids: {}", failedTxnIds.mkString(","))
-        } else {
-          logger.info("commit transaction success")
-        }
-      // if job failed, abort all pre committed transactions
-      case _ =>
-        if (txnIds.isEmpty) {
-          logger.warn("job run failed, but there is no pre-committed txn ids")
-          return
-        }
-        logger.info("job run failed, start aborting transactions")
-        txnIds.foreach(txnId =>
-          Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTnxIntervalMs), logger) {
-            dorisStreamLoad.abortById(txnId)
-          } () match {
-            case Success(_) => // do nothing
-            case Failure(_) => failedTxnIds += txnId
-          })
-        if (failedTxnIds.nonEmpty) {
-          logger.error("not aborted txn ids: {}", failedTxnIds.mkString(","))
-        } else {
-          logger.info("abort transaction success")
-        }
-    }
-  }
-
-}
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
index ac04401..995bd41 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
@@ -67,8 +67,10 @@ private[sql] class DorisSourceProvider extends DataSourceRegister
       case _: SaveMode => // do nothing
     }
 
+    // accumulator for transaction handling
+    val acc = sqlContext.sparkContext.collectionAccumulator[Long]("BatchTxnAcc")
     // init stream loader
-    val writer = new DorisWriter(sparkSettings)
+    val writer = new DorisWriter(sparkSettings, acc)
     writer.write(data)
 
     new BaseRelation {
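
The accumulator threaded into DorisWriter here is Spark's standard mechanism
for collecting values from executors on the driver: the writer adds each
pre-committed transaction id as partitions finish, and the driver-side
listener reads the merged list. A minimal standalone sketch of the pattern
(variable names are illustrative):

    // Executors add values; the driver reads the merged result.
    val acc = spark.sparkContext.collectionAccumulator[Long]("BatchTxnAcc")
    spark.sparkContext.parallelize(Seq(1001L, 1002L)).foreach(id => acc.add(id))
    val txnIds = acc.value // java.util.List[Long], available on the driver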
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
index d1a2b74..9a80fa8 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
@@ -17,7 +17,9 @@
 
 package org.apache.doris.spark.sql
 
-import org.apache.doris.spark.cfg.SparkSettings
+import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
+import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
+import org.apache.doris.spark.txn.listener.DorisTxnStreamingQueryListener
 import org.apache.doris.spark.writer.DorisWriter
 import org.apache.spark.sql.execution.streaming.Sink
 import org.apache.spark.sql.{DataFrame, SQLContext}
@@ -28,7 +30,12 @@ private[sql] class DorisStreamLoadSink(sqlContext: SQLContext, settings: SparkSe
   private val logger: Logger = LoggerFactory.getLogger(classOf[DorisStreamLoadSink].getName)
   @volatile private var latestBatchId = -1L
 
-  private val writer = new DorisWriter(settings)
+  // accumulator for transaction handling
+  private val acc = sqlContext.sparkContext.collectionAccumulator[Long]("StreamTxnAcc")
+  private val writer = new DorisWriter(settings, acc)
+
+  // add listener for structured streaming
+  sqlContext.streams.addListener(new DorisTxnStreamingQueryListener(acc, settings))
 
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
     if (batchId <= latestBatchId) {
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/TransactionHandler.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/TransactionHandler.scala
new file mode 100644
index 0000000..deeb40b
--- /dev/null
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/TransactionHandler.scala
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.txn
+
+import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
+import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
+import org.apache.doris.spark.sql.Utils
+import org.apache.spark.internal.Logging
+
+import java.time.Duration
+import scala.collection.mutable
+import scala.util.{Failure, Success}
+
+/**
+ * Stream load transaction handler
+ *
+ * @param settings job settings
+ */
+class TransactionHandler(settings: SparkSettings) extends Logging {
+
+  private val sinkTxnIntervalMs: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS,
+    ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS_DEFAULT)
+  private val sinkTxnRetries: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_RETRIES,
+    ConfigurationOptions.DORIS_SINK_TXN_RETRIES_DEFAULT)
+  private val dorisStreamLoad: DorisStreamLoad = CachedDorisStreamLoadClient.getOrCreate(settings)
+
+  /**
+   * commit transactions
+   *
+   * @param txnIds transaction id list
+   */
+  def commitTransactions(txnIds: List[Long]): Unit = {
+    log.debug(s"start to commit transactions, count ${txnIds.size}")
+    val (failedTxnIds, ex) = txnIds.map(commitTransaction).filter(_._1.nonEmpty)
+      .map(e => (e._1.get, e._2.get))
+      .aggregate((mutable.Buffer[Long](), new Exception))(
+        (z, r) => ((z._1 += r._1).asInstanceOf[mutable.Buffer[Long]], r._2), (r1, r2) => (r1._1 ++ r2._1, r2._2))
+    if (failedTxnIds.nonEmpty) {
+      log.error("uncommitted txn ids: {}", failedTxnIds.mkString("[", ",", "]"))
+      throw ex
+    }
+  }
+
+  /**
+   * commit single transaction
+   *
+   * @param txnId transaction id
+   * @return
+   */
+  private def commitTransaction(txnId: Long): (Option[Long], Option[Exception]) = {
+    Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTxnIntervalMs), log) {
+      dorisStreamLoad.commit(txnId)
+    }() match {
+      case Success(_) => (None, None)
+      case Failure(e: Exception) => (Option(txnId), Option(e))
+    }
+  }
+
+  /**
+   * abort transactions
+   *
+   * @param txnIds transaction id list
+   */
+  def abortTransactions(txnIds: List[Long]): Unit = {
+    log.debug(s"start to abort transactions, count ${txnIds.size}")
+    var ex: Option[Exception] = None
+    val failedTxnIds = txnIds.map(txnId =>
+      Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTxnIntervalMs), log) {
+        dorisStreamLoad.abortById(txnId)
+      }() match {
+        case Success(_) => None
+        case Failure(e: Exception) =>
+          ex = Option(e)
+          Option(txnId)
+      }).filter(_.nonEmpty).map(_.get)
+    if (failedTxnIds.nonEmpty) {
+      log.error("not aborted txn ids: {}", failedTxnIds.mkString("[", ",", "]"))
+    }
+  }
+
+}
+
+object TransactionHandler {
+  def apply(settings: SparkSettings): TransactionHandler = new TransactionHandler(settings)
+}
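
The aggregate call in commitTransactions above is dense; it folds the failed
(txnId, exception) pairs into a buffer of failed ids plus the last exception
seen. A behavior-equivalent sketch in plainer Scala (commitAll is an
illustrative name, not part of the commit):

    // Collect (failed id, exception) pairs, log the ids, rethrow the last error.
    def commitAll(txnIds: List[Long]): Unit = {
      val failures = txnIds.flatMap { id =>
        commitTransaction(id) match {
          case (Some(failed), Some(e)) => Some(failed -> e)
          case _ => None
        }
      }
      if (failures.nonEmpty) {
        log.error("uncommitted txn ids: {}", failures.map(_._1).mkString("[", ",", "]"))
        throw failures.last._2
      }
    }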
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTransactionListener.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTransactionListener.scala
new file mode 100644
index 0000000..b23dcae
--- /dev/null
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTransactionListener.scala
@@ -0,0 +1,66 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.txn.listener
+
+import org.apache.doris.spark.cfg.SparkSettings
+import org.apache.doris.spark.txn.TransactionHandler
+import org.apache.spark.internal.Logging
+import org.apache.spark.scheduler._
+import org.apache.spark.util.CollectionAccumulator
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Long], settings: SparkSettings)
+  extends SparkListener with Logging {
+
+  val txnHandler: TransactionHandler = TransactionHandler(settings)
+
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+    val txnIds: mutable.Buffer[Long] = preCommittedTxnAcc.value.asScala
+    jobEnd.jobResult match {
+      // if the job succeeded, commit all transactions
+      case JobSucceeded =>
+        if (txnIds.isEmpty) {
+          log.debug("job run succeed, but there is no pre-committed txn ids")
+          return
+        }
+        log.info("job run succeed, start committing transactions")
+        try txnHandler.commitTransactions(txnIds.toList)
+        finally preCommittedTxnAcc.reset()
+        log.info("transactions committed successfully")
+      // if the job failed, abort all pre-committed transactions
+      case _ =>
+        if (txnIds.isEmpty) {
+          log.debug("job run failed, but there is no pre-committed txn ids")
+          return
+        }
+        log.info("job run failed, start aborting transactions")
+        try txnHandler.abortTransactions(txnIds.toList)
+        finally preCommittedTxnAcc.reset()
+        log.info("transactions aborted successfully")
+    }
+  }
+
+}
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTxnStreamingQueryListener.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTxnStreamingQueryListener.scala
new file mode 100644
index 0000000..77ac9c3
--- /dev/null
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTxnStreamingQueryListener.scala
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.txn.listener
+
+import org.apache.doris.spark.cfg.SparkSettings
+import org.apache.doris.spark.txn.TransactionHandler
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.streaming.StreamingQueryListener
+import org.apache.spark.util.CollectionAccumulator
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+class DorisTxnStreamingQueryListener(preCommittedTxnAcc: CollectionAccumulator[Long], settings: SparkSettings)
+  extends StreamingQueryListener with Logging {
+
+  private val txnHandler = TransactionHandler(settings)
+
+  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {}
+
+  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
+    // commit transactions when each batch ends
+    val txnIds: mutable.Buffer[Long] = preCommittedTxnAcc.value.asScala
+    if (txnIds.isEmpty) {
+      log.warn("job run succeed, but there is no pre-committed txn ids")
+      return
+    }
+    log.info(s"batch[${event.progress.batchId}] run succeed, start committing transactions")
+    try txnHandler.commitTransactions(txnIds.toList)
+    finally preCommittedTxnAcc.reset()
+    log.info(s"batch[${event.progress.batchId}] transactions committed successfully")
+  }
+
+  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
+    val txnIds: mutable.Buffer[Long] = preCommittedTxnAcc.value.asScala
+    // if the job failed, abort all pre-committed transactions
+    if (event.exception.nonEmpty) {
+      if (txnIds.isEmpty) {
+        log.warn("job run failed, but there is no pre-committed txn ids")
+        return
+      }
+      log.info("job run failed, start aborting transactions")
+      try txnHandler.abortTransactions(txnIds.toList)
+      finally preCommittedTxnAcc.reset()
+      log.info("transactions aborted successfully")
+    }
+  }
+
+}
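
For reference, the hook points this listener relies on are part of Spark's
public StreamingQueryListener API: onQueryProgress fires after each successful
micro-batch (hence commit), and onQueryTerminated fires once when the query
stops, with event.exception set if it failed (hence abort). A bare-bones
listener showing the shape of the API:

    import org.apache.spark.sql.streaming.StreamingQueryListener
    import org.apache.spark.sql.streaming.StreamingQueryListener._

    // Minimal listener: log batch completion and query termination.
    class LoggingQueryListener extends StreamingQueryListener {
      override def onQueryStarted(e: QueryStartedEvent): Unit = ()
      override def onQueryProgress(e: QueryProgressEvent): Unit =
        println(s"batch ${e.progress.batchId} finished")
      override def onQueryTerminated(e: QueryTerminatedEvent): Unit =
        println(s"query ${e.id} terminated, exception = ${e.exception}")
    }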
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index 55f4d73..59092f6 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -18,9 +18,10 @@
 package org.apache.doris.spark.writer
 
 import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
-import org.apache.doris.spark.listener.DorisTransactionListener
 import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
 import org.apache.doris.spark.sql.Utils
+import org.apache.doris.spark.txn.TransactionHandler
+import org.apache.doris.spark.txn.listener.DorisTransactionListener
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types.StructType
@@ -32,11 +33,10 @@ import java.time.Duration
 import java.util
 import java.util.Objects
 import scala.collection.JavaConverters._
-import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.util.{Failure, Success}
 
-class DorisWriter(settings: SparkSettings) extends Serializable {
+class DorisWriter(settings: SparkSettings, preCommittedTxnAcc: CollectionAccumulator[Long]) extends Serializable {
 
   private val logger: Logger = LoggerFactory.getLogger(classOf[DorisWriter])
 
@@ -57,13 +57,11 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
 
   private val enable2PC: Boolean = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
     ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT)
-  private val sinkTxnIntervalMs: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS,
-    ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS_DEFAULT)
-  private val sinkTxnRetries: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_RETRIES,
-    ConfigurationOptions.DORIS_SINK_TXN_RETRIES_DEFAULT)
 
   private val dorisStreamLoader: DorisStreamLoad = CachedDorisStreamLoadClient.getOrCreate(settings)
 
+  private var isStreaming = false
+
   /**
    * write data in batch mode
    *
@@ -79,19 +77,14 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
    * @param dataFrame source dataframe
    */
   def writeStream(dataFrame: DataFrame): Unit = {
-    if (enable2PC) {
-      val errMsg = "two phrase commit is not supported in stream mode, please set doris.sink.enable-2pc to false."
-      throw new UnsupportedOperationException(errMsg)
-    }
+    isStreaming = true
     doWrite(dataFrame, dorisStreamLoader.loadStream)
   }
 
   private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType) => Long): Unit = {
-
-    val sc = dataFrame.sqlContext.sparkContext
-    val preCommittedTxnAcc = sc.collectionAccumulator[Long]("preCommittedTxnAcc")
-    if (enable2PC) {
-      sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader, sinkTxnIntervalMs, sinkTxnRetries))
+    // do not add a Spark listener when the job is in streaming mode
+    if (enable2PC && !isStreaming) {
+      dataFrame.sparkSession.sparkContext.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, settings))
     }
 
     var resultRdd = dataFrame.queryExecution.toRdd
@@ -132,17 +125,12 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
       logger.info("no pre-committed transactions, skip abort")
       return
     }
-    val abortFailedTxnIds = mutable.Buffer[Long]()
-    acc.value.asScala.foreach(txnId => {
-      Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
-        dorisStreamLoader.abortById(txnId)
-      }() match {
-        case Success(_) => // do nothing
-        case Failure(_) => abortFailedTxnIds += txnId
-      }
-    })
-    if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(","))
-    acc.reset()
+
+    try TransactionHandler(settings).abortTransactions(acc.value.asScala.toList)
+    finally acc.reset()
   }
 
   /**


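Throughout the commit, Utils.retry(n, interval, log) { body }() is the
connector's existing retry helper, returning a scala.util.Try. A minimal
sketch of a helper with that shape (an assumption for illustration, not the
connector's actual implementation):

    import java.time.Duration
    import scala.annotation.tailrec
    import scala.util.{Failure, Success, Try}

    // Retry `body` up to n additional times, sleeping `interval` between
    // attempts; return the first Success or the last Failure.
    def retry[T](n: Int, interval: Duration)(body: => T): Try[T] = {
      @tailrec
      def loop(remaining: Int): Try[T] = Try(body) match {
        case s @ Success(_) => s
        case Failure(_) if remaining > 0 =>
          Thread.sleep(interval.toMillis)
          loop(remaining - 1)
        case f => f
      }
      loop(n)
    }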