Posted to commits@seatunnel.apache.org by wa...@apache.org on 2023/04/08 16:01:36 UTC

[incubator-seatunnel] branch dev updated: [Improve][Core/Spark-Starter] Push transform operation from Spark Driver to Executors (#4503)

This is an automated email from the ASF dual-hosted git repository.

wanghailin pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/incubator-seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new e96bbf50e [Improve][Core/Spark-Starter] Push transform operation from Spark Driver to Executors (#4503)
e96bbf50e is described below

commit e96bbf50e102dc579dcb5d4f00f7317498d63848
Author: Chengyu Yan <ch...@hotmail.com>
AuthorDate: Sun Apr 9 00:01:30 2023 +0800

    [Improve][Core/Spark-Starter] Push transform operation from Spark Driver to Executors (#4503)
---
 release-note.md                                    |  3 +-
 .../spark/execution/TransformExecuteProcessor.java | 54 +++++++++++++++++----
 .../spark/execution/TransformExecuteProcessor.java | 55 ++++++++++++++++++----
 3 files changed, 92 insertions(+), 20 deletions(-)
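
For readers skimming the diff below: the change replaces a driver-side loop (stream.toLocalIterator() feeding a List<Row> that was rebuilt into a DataFrame via createDataFrame) with Dataset.mapPartitions, so each executor applies the transform to its own partition and the driver no longer pulls every row into memory. A minimal sketch of the pattern, using hypothetical names (TransformSketch, applyOnExecutors) that are not part of SeaTunnel; only the Spark calls mirror the commit:

    import java.util.Iterator;

    import org.apache.spark.api.java.function.FilterFunction;
    import org.apache.spark.api.java.function.MapFunction;
    import org.apache.spark.api.java.function.MapPartitionsFunction;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
    import org.apache.spark.sql.catalyst.encoders.RowEncoder;
    import org.apache.spark.sql.types.StructType;

    final class TransformSketch {

        // Hypothetical helper: run a per-row transform on the executors instead
        // of collecting rows to the driver. 'transform' returns null to drop a row.
        static Dataset<Row> applyOnExecutors(
                Dataset<Row> input, StructType outputSchema, MapFunction<Row, Row> transform) {
            // The encoder tells Spark how to serialize rows of the output schema.
            ExpressionEncoder<Row> encoder = RowEncoder.apply(outputSchema);
            return input.mapPartitions(
                            (MapPartitionsFunction<Row, Row>)
                                    rows ->
                                            new Iterator<Row>() {
                                                @Override
                                                public boolean hasNext() {
                                                    return rows.hasNext();
                                                }

                                                @Override
                                                public Row next() {
                                                    try {
                                                        return transform.call(rows.next());
                                                    } catch (Exception e) {
                                                        throw new RuntimeException(e);
                                                    }
                                                }
                                            },
                            encoder)
                    // Dropped rows surface as nulls; remove them, as the commit does.
                    .filter((FilterFunction<Row>) row -> row != null);
        }
    }

The transform is passed as a MapFunction here because Spark's java function interfaces extend Serializable, which lets Spark ship the closure to the executors.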

diff --git a/release-note.md b/release-note.md
index 870d26cd5..d781e3fe6 100644
--- a/release-note.md
+++ b/release-note.md
@@ -23,6 +23,8 @@
 - [Canal]Support read canal format message #3950
 
 ## Improves
+### Core
+- [Starter][Spark] Push transform operation from Spark Driver to Executors #4502
 ### Connectors
 - [CDC]Add mysql-cdc source factory #3791
 - [JDBC]Fix the problem that the exception cannot be thrown normally #3796
@@ -58,7 +60,6 @@
 - [Storage] Remove seatunnel-api from engine storage. #3834
 - [Core] change queue to disruptor. #3847
 - [Improve] Statistics server job and system resource usage. #3982
-- 
 ## Bug Fixes
 ### Connectors
 - [ClickHouse File] Fix ClickhouseFile Committer Serializable Problems #3803
diff --git a/seatunnel-core/seatunnel-spark-starter/seatunnel-spark-2-starter/src/main/java/org/apache/seatunnel/core/starter/spark/execution/TransformExecuteProcessor.java b/seatunnel-core/seatunnel-spark-starter/seatunnel-spark-2-starter/src/main/java/org/apache/seatunnel/core/starter/spark/execution/TransformExecuteProcessor.java
index 76fa2d0e8..f82c465ec 100644
--- a/seatunnel-core/seatunnel-spark-starter/seatunnel-spark-2-starter/src/main/java/org/apache/seatunnel/core/starter/spark/execution/TransformExecuteProcessor.java
+++ b/seatunnel-core/seatunnel-spark-starter/seatunnel-spark-2-starter/src/main/java/org/apache/seatunnel/core/starter/spark/execution/TransformExecuteProcessor.java
@@ -28,8 +28,11 @@ import org.apache.seatunnel.plugin.discovery.PluginIdentifier;
 import org.apache.seatunnel.plugin.discovery.seatunnel.SeaTunnelTransformPluginDiscovery;
 import org.apache.seatunnel.translation.spark.utils.TypeConverterUtils;
 
+import org.apache.spark.api.java.function.MapPartitionsFunction;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.catalyst.encoders.RowEncoder;
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
 import org.apache.spark.sql.types.StructType;
 
@@ -37,6 +40,7 @@ import com.google.common.collect.Lists;
 import lombok.extern.slf4j.Slf4j;
 
 import java.io.IOException;
+import java.io.Serializable;
 import java.net.URL;
 import java.util.ArrayList;
 import java.util.Iterator;
@@ -120,18 +124,50 @@ public class TransformExecuteProcessor
         transform.setTypeInfo(seaTunnelDataType);
         StructType structType =
                 (StructType) TypeConverterUtils.convert(transform.getProducedType());
-        SeaTunnelRow seaTunnelRow;
-        List<Row> outputRows = new ArrayList<>();
-        Iterator<Row> rowIterator = stream.toLocalIterator();
-        while (rowIterator.hasNext()) {
-            Row row = rowIterator.next();
-            seaTunnelRow = new SeaTunnelRow(((GenericRowWithSchema) row).values());
+        ExpressionEncoder<Row> encoder = RowEncoder.apply(structType);
+        return stream.mapPartitions(
+                        (MapPartitionsFunction<Row, Row>)
+                                (Iterator<Row> rowIterator) -> {
+                                    TransformIterator iterator =
+                                            new TransformIterator(
+                                                    rowIterator, transform, structType);
+                                    return iterator;
+                                },
+                        encoder)
+                .filter(
+                        (Row row) -> {
+                            return row != null;
+                        });
+    }
+
+    private static class TransformIterator implements Iterator<Row>, Serializable {
+        private Iterator<Row> sourceIterator;
+        private SeaTunnelTransform<SeaTunnelRow> transform;
+        private StructType structType;
+
+        public TransformIterator(
+                Iterator<Row> sourceIterator,
+                SeaTunnelTransform<SeaTunnelRow> transform,
+                StructType structType) {
+            this.sourceIterator = sourceIterator;
+            this.transform = transform;
+            this.structType = structType;
+        }
+
+        @Override
+        public boolean hasNext() {
+            return sourceIterator.hasNext();
+        }
+
+        @Override
+        public Row next() {
+            Row row = sourceIterator.next();
+            SeaTunnelRow seaTunnelRow = new SeaTunnelRow(((GenericRowWithSchema) row).values());
             seaTunnelRow = (SeaTunnelRow) transform.map(seaTunnelRow);
             if (seaTunnelRow == null) {
-                continue;
+                return null;
             }
-            outputRows.add(new GenericRowWithSchema(seaTunnelRow.getFields(), structType));
+            return new GenericRowWithSchema(seaTunnelRow.getFields(), structType);
         }
-        return sparkRuntimeEnvironment.getSparkSession().createDataFrame(outputRows, structType);
     }
 }
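
A note on the TransformIterator design above: Iterator.next() must return a value whenever hasNext() is true, so the old loop's continue becomes "return null", and the trailing filter strips the nulls back out. A hypothetical alternative, not what the commit does, is a look-ahead iterator that skips dropped rows itself so no null ever reaches Spark:

    import java.util.Iterator;
    import java.util.NoSuchElementException;
    import java.util.function.Function;

    // Hypothetical look-ahead variant: pre-compute the next surviving row so
    // hasNext()/next() never expose nulls. (A real Spark closure would need a
    // Serializable function type instead of java.util.function.Function.)
    final class SkippingIterator<T> implements Iterator<T> {
        private final Iterator<T> source;
        private final Function<T, T> transform; // returns null to drop an element
        private T buffered; // look-ahead buffer

        SkippingIterator(Iterator<T> source, Function<T, T> transform) {
            this.source = source;
            this.transform = transform;
        }

        @Override
        public boolean hasNext() {
            while (buffered == null && source.hasNext()) {
                buffered = transform.apply(source.next()); // null means "dropped"
            }
            return buffered != null;
        }

        @Override
        public T next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            T result = buffered;
            buffered = null;
            return result;
        }
    }

The commit's null-then-filter version is simpler and keeps the iterator stateless; the trade-off is that dropped rows still travel, as nulls, to the downstream filter before disappearing.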
diff --git a/seatunnel-core/seatunnel-spark-starter/seatunnel-spark-3-starter/src/main/java/org/apache/seatunnel/core/starter/spark/execution/TransformExecuteProcessor.java b/seatunnel-core/seatunnel-spark-starter/seatunnel-spark-3-starter/src/main/java/org/apache/seatunnel/core/starter/spark/execution/TransformExecuteProcessor.java
index 908c6cffc..f82c465ec 100644
--- a/seatunnel-core/seatunnel-spark-starter/seatunnel-spark-3-starter/src/main/java/org/apache/seatunnel/core/starter/spark/execution/TransformExecuteProcessor.java
+++ b/seatunnel-core/seatunnel-spark-starter/seatunnel-spark-3-starter/src/main/java/org/apache/seatunnel/core/starter/spark/execution/TransformExecuteProcessor.java
@@ -28,8 +28,11 @@ import org.apache.seatunnel.plugin.discovery.PluginIdentifier;
 import org.apache.seatunnel.plugin.discovery.seatunnel.SeaTunnelTransformPluginDiscovery;
 import org.apache.seatunnel.translation.spark.utils.TypeConverterUtils;
 
+import org.apache.spark.api.java.function.MapPartitionsFunction;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.catalyst.encoders.RowEncoder;
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
 import org.apache.spark.sql.types.StructType;
 
@@ -37,6 +40,7 @@ import com.google.common.collect.Lists;
 import lombok.extern.slf4j.Slf4j;
 
 import java.io.IOException;
+import java.io.Serializable;
 import java.net.URL;
 import java.util.ArrayList;
 import java.util.Iterator;
@@ -120,19 +124,50 @@ public class TransformExecuteProcessor
         transform.setTypeInfo(seaTunnelDataType);
         StructType structType =
                 (StructType) TypeConverterUtils.convert(transform.getProducedType());
-        SeaTunnelRow seaTunnelRow;
-        List<Row> outputRows = new ArrayList<>();
-        Iterator<Row> rowIterator = stream.toLocalIterator();
-        while (rowIterator.hasNext()) {
-            Row row = rowIterator.next();
-            seaTunnelRow = new SeaTunnelRow(((GenericRowWithSchema) row).values());
+        ExpressionEncoder<Row> encoder = RowEncoder.apply(structType);
+        return stream.mapPartitions(
+                        (MapPartitionsFunction<Row, Row>)
+                                (Iterator<Row> rowIterator) -> {
+                                    TransformIterator iterator =
+                                            new TransformIterator(
+                                                    rowIterator, transform, structType);
+                                    return iterator;
+                                },
+                        encoder)
+                .filter(
+                        (Row row) -> {
+                            return row != null;
+                        });
+    }
+
+    private static class TransformIterator implements Iterator<Row>, Serializable {
+        private Iterator<Row> sourceIterator;
+        private SeaTunnelTransform<SeaTunnelRow> transform;
+        private StructType structType;
+
+        public TransformIterator(
+                Iterator<Row> sourceIterator,
+                SeaTunnelTransform<SeaTunnelRow> transform,
+                StructType structType) {
+            this.sourceIterator = sourceIterator;
+            this.transform = transform;
+            this.structType = structType;
+        }
+
+        @Override
+        public boolean hasNext() {
+            return sourceIterator.hasNext();
+        }
+
+        @Override
+        public Row next() {
+            Row row = sourceIterator.next();
+            SeaTunnelRow seaTunnelRow = new SeaTunnelRow(((GenericRowWithSchema) row).values());
             seaTunnelRow = (SeaTunnelRow) transform.map(seaTunnelRow);
             if (seaTunnelRow == null) {
-                continue;
+                return null;
             }
-            outputRows.add(new GenericRowWithSchema(seaTunnelRow.getFields(), structType));
+            return new GenericRowWithSchema(seaTunnelRow.getFields(), structType);
         }
-
-        return sparkRuntimeEnvironment.getSparkSession().createDataFrame(outputRows, structType);
     }
 }
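
The Spark 2 and Spark 3 starters receive the same change. One practical consequence in both: the mapPartitions lambda captures the SeaTunnelTransform and the StructType, so both must be serializable for Spark to ship the closure to the executors, which is presumably also why TransformIterator is declared Serializable.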