You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hive.apache.org by pg...@apache.org on 2022/02/18 19:10:38 UTC

[hive] branch master updated: HIVE-25149: Support parallel load for Fast HT implementations (#3029) (Panagiotis Garefalakis, reviewed by Ramesh Kumar)

This is an automated email from the ASF dual-hosted git repository.

pgaref pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hive.git


The following commit(s) were added to refs/heads/master by this push:
     new 2d05298  HIVE-25149: Support parallel load for Fast HT implementations (#3029) (Panagiotis Garefalakis, reviewed by Ramesh Kumar)
2d05298 is described below

commit 2d05298f5b4a40b5eb339efcd0701e31b694ccc7
Author: Panagiotis Garefalakis <pg...@apache.org>
AuthorDate: Fri Feb 18 11:08:03 2022 -0800

    HIVE-25149: Support parallel load for Fast HT implementations (#3029) (Panagiotis Garefalakis, reviewed by Ramesh Kumar)
    
    * HIVE-25149: Support parallel load for Fast HT implementations
    * Introducing a new HiveConf (hive.mapjoin.hashtable.load.threads) controlling the number of threads used to load map-join (MJ) hash tables (HT) in parallel
    * Extending VectorMapJoinFastHashTableLoader to support parallel HT loading
---
 .../java/org/apache/hadoop/hive/conf/HiveConf.java |   2 +
 .../fast/VectorMapJoinFastHashTableLoader.java     | 215 +++++++++++++++++++--
 2 files changed, 197 insertions(+), 20 deletions(-)

diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index d8398b3..3ac60f3 100644
--- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -1962,6 +1962,8 @@ public class HiveConf extends Configuration {
         "Only works on Tez and Spark, because memory-optimized hashtable cannot be serialized."),
     HIVEMAPJOINOPTIMIZEDTABLEPROBEPERCENT("hive.mapjoin.optimized.hashtable.probe.percent",
         (float) 0.5, "Probing space percentage of the optimized hashtable"),
+    HIVEMAPJOINPARALELHASHTABLETHREADS("hive.mapjoin.hashtable.load.threads", 2, "Number of threads used to " +
+        "load records from a broadcast edge into the map-join hash table. Should be a power of two."),
     HIVEUSEHYBRIDGRACEHASHJOIN("hive.mapjoin.hybridgrace.hashtable", false, "Whether to use hybrid" +
         "grace hash join as the join method for mapjoin. Tez only."),
     HIVEHYBRIDGRACEHASHJOINMEMCHECKFREQ("hive.mapjoin.hybridgrace.memcheckfrequency", 1024, "For " +
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/VectorMapJoinFastHashTableLoader.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/VectorMapJoinFastHashTableLoader.java
index e0d8e8d..38eccb1 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/VectorMapJoinFastHashTableLoader.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/mapjoin/fast/VectorMapJoinFastHashTableLoader.java
@@ -20,12 +20,21 @@ package org.apache.hadoop.hive.ql.exec.vector.mapjoin.fast;
 import java.io.IOException;
 import java.util.Collections;
 import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.LongAccumulator;
 
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.hadoop.hive.common.Pool;
 import org.apache.hadoop.hive.llap.LlapDaemonInfo;
 import org.apache.hadoop.hive.ql.exec.MemoryMonitorInfo;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.exec.mapjoin.MapJoinMemoryExhaustionError;
+import org.apache.hive.common.util.FixedSizedObjectPool;
 import org.apache.tez.common.counters.TezCounter;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -59,6 +68,15 @@ public class VectorMapJoinFastHashTableLoader implements org.apache.hadoop.hive.
   private TezContext tezContext;
   private String cacheKey;
   private TezCounter htLoadCounter;
+  private LongAccumulator totalEntries; // rows inserted by the drain threads, summed across partitions
+
+  // Parallel loading variables (one queue, one in-progress batch and one drain task per partition)
+  private int numLoadThreads;
+  private ExecutorService loadExecService; // fixed-size pool running the drain tasks
+  private HashTableElementBatch[] elementBatches; // per-partition batch being filled by the reading thread
+  private FixedSizedObjectPool<HashTableElementBatch> batchPool; // recycles drained batches
+  private BlockingQueue<HashTableElementBatch>[] loadBatchQueues; // batches awaiting a drain thread
+  private static final HashTableElementBatch DONE_SENTINEL = new HashTableElementBatch(); // end-of-input marker
 
   @Override
   public void init(ExecMapperContext context, MapredContext mrContext,
@@ -73,6 +91,136 @@ public class VectorMapJoinFastHashTableLoader implements org.apache.hadoop.hive.
     this.htLoadCounter = tezContext.getTezProcessorContext().getCounters().findCounter(counterGroup, counterName);
   }
 
+  /**
+   * Prepares parallel hash-table loading for one small-table input: picks the number of drain threads and
+   * creates the thread pool, the batch pool, and one queue plus one in-progress batch per partition.
+   *
+   * @param estKeyCount estimated number of keys; inputs below FIRST_SIZE_UP are loaded by a single thread
+   */
+  private void initHTLoadingService(long estKeyCount) {
+    // Avoid many small HTs that will rehash multiple times causing GCs
+    int requestedThreads = (estKeyCount < VectorMapJoinFastHashTable.FIRST_SIZE_UP) ? 1 :
+        HiveConf.getIntVar(hconf, HiveConf.ConfVars.HIVEMAPJOINPARALELHASHTABLETHREADS);
+    // Rows are routed with (numLoadThreads - 1) & hashCode, which only uses every partition when
+    // numLoadThreads is a power of two, and Executors.newFixedThreadPool rejects values < 1.
+    // Clamp to at least 1 and round down to a power of two.
+    this.numLoadThreads = Integer.highestOneBit(Math.max(1, requestedThreads));
+    if (this.numLoadThreads != requestedThreads) {
+      LOG.warn("Adjusted {} from {} to {} HT load threads (must be a positive power of two)",
+          HiveConf.ConfVars.HIVEMAPJOINPARALELHASHTABLETHREADS.varname, requestedThreads, this.numLoadThreads);
+    }
+    this.totalEntries = new LongAccumulator(Long::sum, 0L);
+    this.loadExecService = Executors.newFixedThreadPool(numLoadThreads,
+        new ThreadFactoryBuilder()
+            .setDaemon(true)
+            .setPriority(Thread.NORM_PRIORITY)
+            .setNameFormat("HT-Load-Thread-%d")
+            .build());
+    // Reuse HashTableElementBatches to reduce GC pressure
+    this.batchPool = new FixedSizedObjectPool<>(8 * numLoadThreads, new Pool.PoolObjectHelper<HashTableElementBatch>() {
+      @Override
+      public HashTableElementBatch create() {
+        return new HashTableElementBatch();
+      }
+
+      @Override
+      public void resetBeforeOffer(HashTableElementBatch elementBatch) {
+        elementBatch.reset();
+      }
+    });
+    this.elementBatches = new HashTableElementBatch[numLoadThreads];
+    // Generic arrays cannot be instantiated directly; the array only ever holds typed queues.
+    @SuppressWarnings("unchecked")
+    BlockingQueue<HashTableElementBatch>[] queues = new BlockingQueue[numLoadThreads];
+    this.loadBatchQueues = queues;
+    for (int i = 0; i < numLoadThreads; i++) {
+      loadBatchQueues[i] = new LinkedBlockingQueue<>();
+      elementBatches[i] = batchPool.take();
+    }
+  }
+
+  /**
+   * Submits one drain task per partition to the load executor; each task inserts the batches of its queue
+   * into the given container until it takes {@link #DONE_SENTINEL}.
+   *
+   * <p>The {@code Future}s returned by {@link ExecutorService#submit} are not kept, so a failing task is
+   * logged here; otherwise its exception would only be stored in the discarded {@code Future}.
+   */
+  private void submitQueueDrainThreads(VectorMapJoinFastTableContainer vectorMapJoinFastTableContainer)
+      throws InterruptedException, IOException, SerDeException {
+    for (int partitionId = 0; partitionId < numLoadThreads; partitionId++) {
+      final int finalPartitionId = partitionId;
+      this.loadExecService.submit(() -> {
+        try {
+          LOG.info("Partition id {} with Queue size {}", finalPartitionId, loadBatchQueues[finalPartitionId].size());
+          drainAndLoadForPartition(finalPartitionId, vectorMapJoinFastTableContainer);
+        } catch (IOException | InterruptedException | SerDeException | HiveException e) {
+          if (e instanceof InterruptedException) {
+            Thread.currentThread().interrupt();
+          }
+          LOG.error("HT load thread for partition {} failed", finalPartitionId, e);
+          throw new RuntimeException("Failed to load HT partition " + finalPartitionId, e);
+        } catch (RuntimeException | Error e) {
+          LOG.error("HT load thread for partition {} failed", finalPartitionId, e);
+          throw e;
+        }
+      });
+    }
+  }
+
+  /**
+   * Drains the queue of one partition, inserting every queued row into the container, until
+   * {@link #DONE_SENTINEL} is taken. Drained batches are returned to {@code batchPool} for reuse.
+   *
+   * @param partitionId index of the queue to drain
+   * @param tableContainer container receiving the rows
+   * @throws InterruptedException if interrupted while waiting for the next batch
+   * @throws HiveException if inserting a row fails
+   */
+  private void drainAndLoadForPartition(int partitionId, VectorMapJoinFastTableContainer tableContainer)
+      throws InterruptedException, IOException, HiveException, SerDeException {
+    LOG.info("Starting draining thread {}", partitionId);
+    long totalProcessedEntries = 0;
+    while (true) {
+      HashTableElementBatch batch = this.loadBatchQueues[partitionId].take();
+      // Stop on the (shared, static) sentinel before it is processed or offered back into the pool,
+      // where take() could later hand it out as a writable batch.
+      if (batch == DONE_SENTINEL) {
+        break;
+      }
+      final int batchSize = batch.getSize();
+      LOG.debug("Draining thread {} batchSize {}", partitionId, batchSize);
+      for (int i = 0; i < batchSize; i++) {
+        try {
+          HashTableElement h = batch.getBatch(i);
+          tableContainer.putRow(h.getHashCode(), h.getKey(), h.getValue());
+        } catch (Exception e) {
+          throw new HiveException("Exception in draining thread put row", e);
+        }
+      }
+      totalProcessedEntries += batchSize;
+      LOG.debug("Draining thread {} added {} entries", partitionId, batchSize);
+      totalEntries.accumulate(batchSize);
+      // Offer must be last: the pool resets the batch (resetBeforeOffer) when it is returned.
+      this.batchPool.offer(batch);
+    }
+
+    LOG.info("Terminating draining thread {} after processing Entries {}", partitionId, totalProcessedEntries);
+  }
+
+  /**
+   * Signals end of input to every drain thread: enqueues each partition's last, possibly partially filled,
+   * batch followed by {@link #DONE_SENTINEL}. Call it only after the last row has been routed, because a
+   * drain thread stops at the sentinel and never sees anything queued after it.
+   */
+  private void addQueueDoneSentinel() {
+    for (int i = 0; i < numLoadThreads; i++) {
+      // Flush the in-progress batch first; the sentinel goes into the queue, not into the batch.
+      loadBatchQueues[i].add(elementBatches[i]);
+      loadBatchQueues[i].add(DONE_SENTINEL);
+    }
+  }
+
   @Override
   public void load(MapJoinTableContainer[] mapJoinTables,
       MapJoinTableContainerSerDe[] mapJoinTableSerdes)
@@ -106,7 +206,6 @@ public class VectorMapJoinFastHashTableLoader implements org.apache.hadoop.hive.
         continue;
       }
 
-      long numEntries = 0;
       String inputName = parentToInput.get(pos);
       LogicalInput input = tezContext.getInput(inputName);
 
@@ -135,25 +234,38 @@ public class VectorMapJoinFastHashTableLoader implements org.apache.hadoop.hive.
           LOG.debug("Failed to get value for counter APPROXIMATE_INPUT_RECORDS", e);
         }
         long keyCount = Math.max(estKeyCount, inputRecords);
+        initHTLoadingService(keyCount);
 
-        VectorMapJoinFastTableContainer vectorMapJoinFastTableContainer =
-                new VectorMapJoinFastTableContainer(desc, hconf, keyCount, 1);
+        VectorMapJoinFastTableContainer tableContainer =
+            new VectorMapJoinFastTableContainer(desc, hconf, keyCount, numLoadThreads);
 
         LOG.info("Loading hash table for input: {} cacheKey: {} tableContainer: {} smallTablePos: {} " +
                 "estKeyCount : {} keyCount : {}", inputName, cacheKey,
-                vectorMapJoinFastTableContainer.getClass().getSimpleName(), pos, estKeyCount, keyCount);
+                tableContainer.getClass().getSimpleName(), pos, estKeyCount, keyCount);
+
+        tableContainer.setSerde(null, null); // No SerDes here.
+        // Submit parallel loading Threads
+        submitQueueDrainThreads(tableContainer);
 
-        vectorMapJoinFastTableContainer.setSerde(null, null); // No SerDes here.
+        long receivedEntries = 0;
         long startTime = System.currentTimeMillis();
         while (kvReader.next()) {
-          vectorMapJoinFastTableContainer.putRow((BytesWritable)kvReader.getCurrentKey(),
-              (BytesWritable)kvReader.getCurrentValue());
-          numEntries++;
-          if (doMemCheck && (numEntries % memoryMonitorInfo.getMemoryCheckInterval() == 0)) {
-              final long estMemUsage = vectorMapJoinFastTableContainer.getEstimatedMemorySize();
-              if (estMemUsage > effectiveThreshold) {
-                String msg = "Hash table loading exceeded memory limits for input: " + inputName +
-                  " numEntries: " + numEntries + " estimatedMemoryUsage: " + estMemUsage +
+          BytesWritable currentKey = (BytesWritable) kvReader.getCurrentKey();
+          BytesWritable currentValue = (BytesWritable) kvReader.getCurrentValue();
+          long hashCode = tableContainer.getHashCode(currentKey);
+          int partitionId = (int) ((numLoadThreads - 1) & hashCode); // numLoadThreads divisor must be a power of 2!
+          // call getBytes as copy is called later
+          HashTableElement h = new HashTableElement(hashCode, currentValue.copyBytes(), currentKey.copyBytes());
+          if (elementBatches[partitionId].addElement(h)) {
+            loadBatchQueues[partitionId].add(elementBatches[partitionId]);
+            elementBatches[partitionId] = batchPool.take();
+          }
+          receivedEntries++;
+          if (doMemCheck && (receivedEntries % memoryMonitorInfo.getMemoryCheckInterval() == 0)) {
+            final long estMemUsage = tableContainer.getEstimatedMemorySize();
+            if (estMemUsage > effectiveThreshold) {
+              String msg = "Hash table loading exceeded memory limits for input: " + inputName +
+                  " numEntries: " + receivedEntries + " estimatedMemoryUsage: " + estMemUsage +
                   " effectiveThreshold: " + effectiveThreshold + " memoryMonitorInfo: " + memoryMonitorInfo;
                 LOG.error(msg);
                 throw new MapJoinMemoryExhaustionError(msg);
@@ -161,23 +273,34 @@ public class VectorMapJoinFastHashTableLoader implements org.apache.hadoop.hive.
               LOG.info(
                   "Checking hash table loader memory usage for input: {} numEntries: {} "
                       + "estimatedMemoryUsage: {} effectiveThreshold: {}",
-                  inputName, numEntries, estMemUsage, effectiveThreshold);
-              }
+                  inputName, receivedEntries, estMemUsage, effectiveThreshold);
+            }
           }
         }
+
+        LOG.info("Finished loading the queue for input: {} waiting {} minutes for TPool shutdown", inputName, 2);
+        addQueueDoneSentinel();
+        loadExecService.shutdown();
+        loadExecService.awaitTermination(2, TimeUnit.MINUTES);
+        batchPool.clear();
+        LOG.info("Total received entries: {} Threads {} HT entries: {}", receivedEntries, numLoadThreads, totalEntries.get());
+
         long delta = System.currentTimeMillis() - startTime;
         htLoadCounter.increment(delta);
 
-        vectorMapJoinFastTableContainer.seal();
-        mapJoinTables[pos] = vectorMapJoinFastTableContainer;
+        tableContainer.seal();
+        mapJoinTables[pos] = tableContainer;
         if (doMemCheck) {
           LOG.info("Finished loading hash table for input: {} cacheKey: {} numEntries: {} " +
-              "estimatedMemoryUsage: {} Load Time : {} ", inputName, cacheKey, numEntries,
-            vectorMapJoinFastTableContainer.getEstimatedMemorySize(), delta);
+              "estimatedMemoryUsage: {} Load Time : {} ", inputName, cacheKey, receivedEntries,
+            tableContainer.getEstimatedMemorySize(), delta);
         } else {
           LOG.info("Finished loading hash table for input: {} cacheKey: {} numEntries: {} Load Time : {} ",
-                  inputName, cacheKey, numEntries, delta);
+                  inputName, cacheKey, receivedEntries, delta);
         }
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        throw new HiveException(e);
       } catch (IOException e) {
         throw new HiveException(e);
       } catch (SerDeException e) {
@@ -187,4 +310,77 @@ public class VectorMapJoinFastHashTableLoader implements org.apache.hadoop.hive.
       }
     }
   }
+
+  /** One row handed to a drain thread: its hash code plus its key and value bytes. */
+  private static class HashTableElement {
+    private final long hashCode;
+    private final byte[] keyBytes;
+    private final byte[] valueBytes;
+
+    // NOTE: valueBytes precedes keyBytes in the parameter list; the loader passes arguments in this order.
+    public HashTableElement(long hashCode, byte[] valueBytes, byte[] keyBytes) {
+      this.hashCode = hashCode;
+      this.keyBytes = keyBytes;
+      this.valueBytes = valueBytes;
+    }
+
+    /** Returns the key bytes wrapped in a new BytesWritable. */
+    public BytesWritable getKey() {
+      return new BytesWritable(this.keyBytes, this.keyBytes.length);
+    }
+
+    /** Returns the value bytes wrapped in a new BytesWritable. */
+    public BytesWritable getValue() {
+      return new BytesWritable(this.valueBytes, this.valueBytes.length);
+    }
+
+    /** Returns the hash code computed by the loader for the key. */
+    public long getHashCode() {
+      return this.hashCode;
+    }
+  }
+
+  /**
+   * Up to BATCH_SIZE rows queued to a drain thread as one unit. Instances are pooled and reused, so
+   * {@link #reset()} must leave the batch empty and free of references to rows that were already loaded.
+   */
+  private static class HashTableElementBatch {
+    private static final int BATCH_SIZE = 1024;
+    private final HashTableElement[] batch;
+    private int currentIndex;
+
+    public HashTableElementBatch() {
+      this.batch = new HashTableElement[BATCH_SIZE];
+      this.currentIndex = 0;
+    }
+
+    /**
+     * Appends a row.
+     *
+     * @return true if the batch is now full; the caller must queue it before adding further rows
+     */
+    public boolean addElement(HashTableElement h) {
+      this.batch[this.currentIndex++] = h;
+      return (this.currentIndex == BATCH_SIZE);
+    }
+
+    /** Returns the {@code i}-th row; valid for {@code 0 <= i < getSize()}. */
+    public HashTableElement getBatch(int i) {
+      return this.batch[i];
+    }
+
+    /** Returns the number of rows currently in the batch. */
+    public int getSize() {
+      return this.currentIndex;
+    }
+
+    /**
+     * Empties the batch for reuse. The used slots are cleared as well, so a pooled batch does not keep
+     * already-loaded rows (and their key/value byte arrays) reachable until they happen to be overwritten.
+     */
+    public void reset() {
+      java.util.Arrays.fill(this.batch, 0, this.currentIndex, null);
+      this.currentIndex = 0;
+    }
+  }
 }