You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@hugegraph.apache.org by va...@apache.org on 2023/04/06 10:37:04 UTC

[incubator-hugegraph-toolchain] 01/01: chore: improve spark parallel

This is an automated email from the ASF dual-hosted git repository.

vaughn pushed a commit to branch zy_dev
in repository https://gitbox.apache.org/repos/asf/incubator-hugegraph-toolchain.git

commit 8f1bfe33981cb476f7d6416b795b1dff20389290
Author: vaughn.zhang <va...@zoom.us>
AuthorDate: Thu Apr 6 18:30:03 2023 +0800

    chore: improve spark parallel
---
 .../loader/spark/HugeGraphSparkLoader.java         | 65 +++++++++++++---------
 1 file changed, 38 insertions(+), 27 deletions(-)

diff --git a/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/spark/HugeGraphSparkLoader.java b/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/spark/HugeGraphSparkLoader.java
index 60c7837f..b26003b0 100644
--- a/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/spark/HugeGraphSparkLoader.java
+++ b/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/spark/HugeGraphSparkLoader.java
@@ -63,6 +63,9 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 
 import scala.collection.JavaConverters;
 
@@ -77,11 +80,10 @@ public class HugeGraphSparkLoader implements Serializable {
         HugeGraphSparkLoader loader;
         try {
             loader = new HugeGraphSparkLoader(args);
+            loader.load();
         } catch (Throwable e) {
             Printer.printError("Failed to start loading", e);
-            return;
         }
-        loader.load();
     }
 
     public HugeGraphSparkLoader(String[] args) {
@@ -89,7 +91,7 @@ public class HugeGraphSparkLoader implements Serializable {
         this.builders = new HashMap<>();
     }
 
-    public void load() {
+    public void load() throws ExecutionException, InterruptedException {
         LoadMapping mapping = LoadMapping.of(this.loadOptions.file);
         List<InputStruct> structs = mapping.structs();
         boolean sinkType = this.loadOptions.sinkType;
@@ -123,35 +125,44 @@ public class HugeGraphSparkLoader implements Serializable {
         SparkContext sc = session.sparkContext();
 
         LongAccumulator totalInsertSuccess = sc.longAccumulator("totalInsertSuccess");
+        List<Future<Void>> futures = new ArrayList<>(structs.size());
+
         for (InputStruct struct : structs) {
-            LOG.info("\n Initializes the accumulator corresponding to the  {} ",
-                     struct.input().asFileSource().path());
-            LoadDistributeMetrics loadDistributeMetrics = new LoadDistributeMetrics(struct);
-            loadDistributeMetrics.init(sc);
-            LOG.info("\n  Start to load data, data info is: \t {} ",
-                     struct.input().asFileSource().path());
-            Dataset<Row> ds = read(session, struct);
-            if (sinkType) {
-                LOG.info("\n  Start to load data using spark apis  \n");
-                ds.foreachPartition((Iterator<Row> p) -> {
-                    LoadContext context = initPartition(this.loadOptions, struct);
-                    p.forEachRemaining((Row row) -> {
-                        loadRow(struct, row, p, context);
+            Future<Void> future = Executors.newCachedThreadPool().submit(() -> {
+                LOG.info("\n Initializes the accumulator corresponding to the  {} ",
+                        struct.input().asFileSource().path());
+                LoadDistributeMetrics loadDistributeMetrics = new LoadDistributeMetrics(struct);
+                loadDistributeMetrics.init(sc);
+                LOG.info("\n  Start to load data, data info is: \t {} ",
+                        struct.input().asFileSource().path());
+                Dataset<Row> ds = read(session, struct);
+                if (sinkType) {
+                    LOG.info("\n  Start to load data using spark apis  \n");
+                    ds.foreachPartition((Iterator<Row> p) -> {
+                        LoadContext context = initPartition(this.loadOptions, struct);
+                        p.forEachRemaining((Row row) -> {
+                            loadRow(struct, row, p, context);
+                        });
+                        context.close();
                     });
-                    context.close();
-                });
 
-            } else {
-                LOG.info("\n Start to load data using spark bulkload \n");
-                // gen-hfile
-                HBaseDirectLoader directLoader = new HBaseDirectLoader(loadOptions, struct,
-                                                                       loadDistributeMetrics);
-                directLoader.bulkload(ds);
+                } else {
+                    LOG.info("\n Start to load data using spark bulkload \n");
+                    // gen-hfile
+                    HBaseDirectLoader directLoader = new HBaseDirectLoader(loadOptions, struct,
+                            loadDistributeMetrics);
+                    directLoader.bulkload(ds);
 
-            }
-            collectLoadMetrics(loadDistributeMetrics, totalInsertSuccess);
-            LOG.info("\n Finished  load {}  data ", struct.input().asFileSource().path());
+                }
+                collectLoadMetrics(loadDistributeMetrics, totalInsertSuccess);
+                LOG.info("\n Finished  load {}  data ", struct.input().asFileSource().path());
+            });
+            futures.add(future);
         }
+        for (Future<Void> future : futures) {
+            future.get();
+        }
+
         Long totalInsertSuccessCnt = totalInsertSuccess.value();
         LOG.info("\n ------------The data load task is complete-------------------\n" +
                  "\n insertSuccessCnt:\t {} \n ---------------------------------------------\n",