You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hive.apache.org by gu...@apache.org on 2014/05/31 08:31:27 UTC

svn commit: r1598829 - in /hive/trunk/ql/src/java/org/apache/hadoop/hive/ql: exec/tez/CustomPartitionVertex.java exec/tez/DagUtils.java exec/tez/HiveSplitGenerator.java exec/tez/SplitGrouper.java plan/PartitionDesc.java

Author: gunther
Date: Sat May 31 06:31:26 2014
New Revision: 1598829

URL: http://svn.apache.org/r1598829
Log:
HIVE-7071: Use custom Tez split generator to support schema evolution (Patch by Gunther Hagleitner, reviewed by Vikram Dixit K)

Added:
    hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java
    hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java
Modified:
    hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java
    hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java
    hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/plan/PartitionDesc.java

Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java?rev=1598829&r1=1598828&r2=1598829&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java Sat May 31 06:31:26 2014
@@ -21,11 +21,7 @@ package org.apache.hadoop.hive.ql.exec.t
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collection;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
@@ -35,18 +31,14 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.ql.io.HiveInputFormat;
 import org.apache.hadoop.io.DataOutputBuffer;
 import org.apache.hadoop.io.serializer.SerializationFactory;
 import org.apache.hadoop.mapred.FileSplit;
 import org.apache.hadoop.mapred.InputSplit;
 import org.apache.hadoop.mapred.split.TezGroupedSplitsInputFormat;
-import org.apache.hadoop.mapred.split.TezMapredSplitsGrouper;
 import org.apache.tez.dag.api.EdgeManagerDescriptor;
 import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.EdgeProperty.DataMovementType;
-import org.apache.tez.dag.api.VertexLocationHint.TaskLocationHint;
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.api.VertexLocationHint;
@@ -61,16 +53,13 @@ import org.apache.tez.runtime.api.events
 import org.apache.tez.runtime.api.events.RootInputUpdatePayloadEvent;
 import org.apache.tez.runtime.api.events.VertexManagerEvent;
 
-import com.google.common.base.Function;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ArrayListMultimap;
 import com.google.common.collect.HashMultimap;
-import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Multimap;
 
-
 /*
  * Only works with old mapred API
  * Will only work with a single MRInput for now.
@@ -78,32 +67,22 @@ import com.google.common.collect.Multima
 public class CustomPartitionVertex implements VertexManagerPlugin {
 
   private static final Log LOG = LogFactory.getLog(CustomPartitionVertex.class.getName());
-  public static final String GROUP_SPLITS = "hive.enable.custom.grouped.splits";
-
 
   VertexManagerPluginContext context;
 
-  private Multimap<Integer, Integer> bucketToTaskMap = HashMultimap.<Integer, Integer>create();
-  private Multimap<Integer, InputSplit> bucketToInitialSplitMap = 
-      ArrayListMultimap.<Integer, InputSplit>create();
-
   private RootInputConfigureVertexTasksEvent configureVertexTaskEvent;
   private List<RootInputDataInformationEvent> dataInformationEvents;
-  private Map<Path, List<FileSplit>> pathFileSplitsMap = new TreeMap<Path, List<FileSplit>>();
   private int numBuckets = -1;
   private Configuration conf = null;
   private boolean rootVertexInitialized = false;
-  Multimap<Integer, InputSplit> bucketToGroupedSplitMap;
-
-
-  private Map<Integer, Integer> bucketToNumTaskMap = new HashMap<Integer, Integer>();
+  private final SplitGrouper grouper = new SplitGrouper();
 
   public CustomPartitionVertex() {
   }
 
   @Override
   public void initialize(VertexManagerPluginContext context) {
-    this.context = context; 
+    this.context = context;
     ByteBuffer byteBuf = ByteBuffer.wrap(context.getUserPayload());
     this.numBuckets = byteBuf.getInt();
   }
@@ -112,7 +91,7 @@ public class CustomPartitionVertex imple
   public void onVertexStarted(Map<String, List<Integer>> completions) {
     int numTasks = context.getVertexNumTasks(context.getVertexName());
     List<Integer> scheduledTasks = new ArrayList<Integer>(numTasks);
-    for (int i=0; i<numTasks; ++i) {
+    for (int i = 0; i < numTasks; ++i) {
       scheduledTasks.add(new Integer(i));
     }
     context.scheduleVertexTasks(scheduledTasks);
@@ -132,55 +111,63 @@ public class CustomPartitionVertex imple
       List<Event> events) {
 
     // Ideally, since there's only 1 Input expected at the moment -
-    // ensure this method is called only once. Tez will call it once per Root Input.
+    // ensure this method is called only once. Tez will call it once per Root
+    // Input.
     Preconditions.checkState(rootVertexInitialized == false);
     rootVertexInitialized = true;
     try {
       // This is using the payload from the RootVertexInitializer corresponding
-      // to InputName. Ideally it should be using it's own configuration class - but that
+      // to InputName. Ideally it should be using its own configuration class -
+      // but that
       // means serializing another instance.
-      MRInputUserPayloadProto protoPayload = 
+      MRInputUserPayloadProto protoPayload =
           MRHelpers.parseMRInputPayload(inputDescriptor.getUserPayload());
       this.conf = MRHelpers.createConfFromByteString(protoPayload.getConfigurationBytes());
 
       /*
-       * Currently in tez, the flow of events is thus: "Generate Splits -> Initialize Vertex"
-       * (with parallelism info obtained from the generate splits phase). The generate splits
-       * phase groups splits using the TezGroupedSplitsInputFormat. However, for bucket map joins
-       * the grouping done by this input format results in incorrect results as the grouper has no
-       * knowledge of buckets. So, we initially set the input format to be HiveInputFormat
-       * (in DagUtils) for the case of bucket map joins so as to obtain un-grouped splits.
-       * We then group the splits corresponding to buckets using the tez grouper which returns
+       * Currently in tez, the flow of events is thus:
+       * "Generate Splits -> Initialize Vertex" (with parallelism info obtained
+       * from the generate splits phase). The generate splits phase groups
+       * splits using the TezGroupedSplitsInputFormat. However, for bucket map
+       * joins the grouping done by this input format results in incorrect
+       * results as the grouper has no knowledge of buckets. So, we initially
+       * set the input format to be HiveInputFormat (in DagUtils) for the case
+       * of bucket map joins so as to obtain un-grouped splits. We then group
+       * the splits corresponding to buckets using the tez grouper which returns
        * TezGroupedSplits.
        */
 
-      if (conf.getBoolean(GROUP_SPLITS, true)) {
-        // Changing the InputFormat - so that the correct one is initialized in MRInput.
-        this.conf.set("mapred.input.format.class", TezGroupedSplitsInputFormat.class.getName());
-        MRInputUserPayloadProto updatedPayload = MRInputUserPayloadProto
-            .newBuilder(protoPayload)
-            .setConfigurationBytes(MRHelpers.createByteStringFromConf(conf))
-            .build();
-        inputDescriptor.setUserPayload(updatedPayload.toByteArray());
-      }
+      // This assumes that Grouping will always be used.
+      // Changing the InputFormat - so that the correct one is initialized in
+      // MRInput.
+      this.conf.set("mapred.input.format.class", TezGroupedSplitsInputFormat.class.getName());
+      MRInputUserPayloadProto updatedPayload =
+          MRInputUserPayloadProto.newBuilder(protoPayload)
+              .setConfigurationBytes(MRHelpers.createByteStringFromConf(conf)).build();
+      inputDescriptor.setUserPayload(updatedPayload.toByteArray());
     } catch (IOException e) {
       e.printStackTrace();
       throw new RuntimeException(e);
     }
+
     boolean dataInformationEventSeen = false;
+    Map<Path, List<FileSplit>> pathFileSplitsMap = new TreeMap<Path, List<FileSplit>>();
+
     for (Event event : events) {
       if (event instanceof RootInputConfigureVertexTasksEvent) {
-        // No tasks should have been started yet. Checked by initial state check.
+        // No tasks should have been started yet. Checked by initial state
+        // check.
         Preconditions.checkState(dataInformationEventSeen == false);
         Preconditions
-        .checkState(
-            context.getVertexNumTasks(context.getVertexName()) == -1,
-            "Parallelism for the vertex should be set to -1 if the InputInitializer is setting parallelism");
+            .checkState(context.getVertexNumTasks(context.getVertexName()) == -1,
+                "Parallelism for the vertex should be set to -1 if the InputInitializer is setting parallelism");
         RootInputConfigureVertexTasksEvent cEvent = (RootInputConfigureVertexTasksEvent) event;
 
-        // The vertex cannot be configured until all DataEvents are seen - to build the routing table.
+        // The vertex cannot be configured until all DataEvents are seen - to
+        // build the routing table.
         configureVertexTaskEvent = cEvent;
-        dataInformationEvents = Lists.newArrayListWithCapacity(configureVertexTaskEvent.getNumTasks());
+        dataInformationEvents =
+            Lists.newArrayListWithCapacity(configureVertexTaskEvent.getNumTasks());
       }
       if (event instanceof RootInputUpdatePayloadEvent) {
         // this event can never occur. If it does, fail.
@@ -195,7 +182,7 @@ public class CustomPartitionVertex imple
         } catch (IOException e) {
           throw new RuntimeException("Failed to get file split for event: " + diEvent);
         }
-        List<FileSplit> fsList = pathFileSplitsMap.get(fileSplit.getPath()); 
+        List<FileSplit> fsList = pathFileSplitsMap.get(fileSplit.getPath());
         if (fsList == null) {
           fsList = new ArrayList<FileSplit>();
           pathFileSplitsMap.put(fileSplit.getPath(), fsList);
@@ -204,17 +191,33 @@ public class CustomPartitionVertex imple
       }
     }
 
-    setBucketNumForPath(pathFileSplitsMap);
+    Multimap<Integer, InputSplit> bucketToInitialSplitMap =
+        getBucketSplitMapForPath(pathFileSplitsMap);
+
     try {
-      groupSplits();
-      processAllEvents(inputName);
+      int totalResource = context.getTotalAVailableResource().getMemory();
+      int taskResource = context.getVertexTaskResource().getMemory();
+      float waves =
+          conf.getFloat(TezConfiguration.TEZ_AM_GROUPING_SPLIT_WAVES,
+              TezConfiguration.TEZ_AM_GROUPING_SPLIT_WAVES_DEFAULT);
+
+      int availableSlots = totalResource / taskResource;
+
+      LOG.info("Grouping splits. " + availableSlots + " available slots, " + waves + " waves.");
+
+      Multimap<Integer, InputSplit> bucketToGroupedSplitMap =
+          grouper.group(conf, bucketToInitialSplitMap, availableSlots, waves);
+
+      processAllEvents(inputName, bucketToGroupedSplitMap);
     } catch (IOException e) {
       throw new RuntimeException(e);
     }
   }
 
-  private void processAllEvents(String inputName) throws IOException {
+  private void processAllEvents(String inputName,
+      Multimap<Integer, InputSplit> bucketToGroupedSplitMap) throws IOException {
 
+    Multimap<Integer, Integer> bucketToTaskMap = HashMultimap.<Integer, Integer> create();
     List<InputSplit> finalSplits = Lists.newLinkedList();
     int taskCount = 0;
     for (Entry<Integer, Collection<InputSplit>> entry : bucketToGroupedSplitMap.asMap().entrySet()) {
@@ -227,9 +230,10 @@ public class CustomPartitionVertex imple
       }
     }
 
-    // Construct the EdgeManager descriptor to be used by all edges which need the routing table.
-    EdgeManagerDescriptor hiveEdgeManagerDesc = new EdgeManagerDescriptor(
-        CustomPartitionEdge.class.getName());    
+    // Construct the EdgeManager descriptor to be used by all edges which need
+    // the routing table.
+    EdgeManagerDescriptor hiveEdgeManagerDesc =
+        new EdgeManagerDescriptor(CustomPartitionEdge.class.getName());
     byte[] payload = getBytePayload(bucketToTaskMap);
     hiveEdgeManagerDesc.setUserPayload(payload);
 
@@ -246,13 +250,14 @@ public class CustomPartitionVertex imple
 
     LOG.info("Task count is " + taskCount);
 
-    List<RootInputDataInformationEvent> taskEvents = Lists.newArrayListWithCapacity(finalSplits.size());
+    List<RootInputDataInformationEvent> taskEvents =
+        Lists.newArrayListWithCapacity(finalSplits.size());
     // Re-serialize the splits after grouping.
     int count = 0;
     for (InputSplit inputSplit : finalSplits) {
       MRSplitProto serializedSplit = MRHelpers.createSplitProto(inputSplit);
-      RootInputDataInformationEvent diEvent = new RootInputDataInformationEvent(
-          count, serializedSplit.toByteArray());
+      RootInputDataInformationEvent diEvent =
+          new RootInputDataInformationEvent(count, serializedSplit.toByteArray());
       diEvent.setTargetIndex(count);
       count++;
       taskEvents.add(diEvent);
@@ -261,7 +266,7 @@ public class CustomPartitionVertex imple
     // Replace the Edge Managers
     context.setVertexParallelism(
         taskCount,
-        new VertexLocationHint(createTaskLocationHintsFromSplits(finalSplits
+        new VertexLocationHint(grouper.createTaskLocationHints(finalSplits
             .toArray(new InputSplit[finalSplits.size()]))), emMap);
 
     // Set the actual events for the tasks.
@@ -269,7 +274,7 @@ public class CustomPartitionVertex imple
   }
 
   private byte[] getBytePayload(Multimap<Integer, Integer> routingTable) throws IOException {
-    CustomEdgeConfiguration edgeConf = 
+    CustomEdgeConfiguration edgeConf =
         new CustomEdgeConfiguration(routingTable.keySet().size(), routingTable);
     DataOutputBuffer dob = new DataOutputBuffer();
     edgeConf.write(dob);
@@ -278,17 +283,14 @@ public class CustomPartitionVertex imple
     return serialized;
   }
 
-  private FileSplit getFileSplitFromEvent(RootInputDataInformationEvent event)
-      throws IOException {
+  private FileSplit getFileSplitFromEvent(RootInputDataInformationEvent event) throws IOException {
     InputSplit inputSplit = null;
     if (event.getDeserializedUserPayload() != null) {
       inputSplit = (InputSplit) event.getDeserializedUserPayload();
     } else {
       MRSplitProto splitProto = MRSplitProto.parseFrom(event.getUserPayload());
-      SerializationFactory serializationFactory = new SerializationFactory(
-          new Configuration());
-      inputSplit = MRHelpers.createOldFormatSplitFromUserPayload(splitProto,
-          serializationFactory);
+      SerializationFactory serializationFactory = new SerializationFactory(new Configuration());
+      inputSplit = MRHelpers.createOldFormatSplitFromUserPayload(splitProto, serializationFactory);
     }
 
     if (!(inputSplit instanceof FileSplit)) {
@@ -301,9 +303,15 @@ public class CustomPartitionVertex imple
   /*
    * This method generates the map of bucket to file splits.
    */
-  private void setBucketNumForPath(Map<Path, List<FileSplit>> pathFileSplitsMap) {
+  private Multimap<Integer, InputSplit> getBucketSplitMapForPath(
+      Map<Path, List<FileSplit>> pathFileSplitsMap) {
+
     int bucketNum = 0;
     int fsCount = 0;
+
+    Multimap<Integer, InputSplit> bucketToInitialSplitMap =
+        ArrayListMultimap.<Integer, InputSplit> create();
+
     for (Map.Entry<Path, List<FileSplit>> entry : pathFileSplitsMap.entrySet()) {
       int bucketId = bucketNum % numBuckets;
       for (FileSplit fsplit : entry.getValue()) {
@@ -313,94 +321,9 @@ public class CustomPartitionVertex imple
       bucketNum++;
     }
 
-    LOG.info("Total number of splits counted: " + fsCount + " and total files encountered: " 
+    LOG.info("Total number of splits counted: " + fsCount + " and total files encountered: "
         + pathFileSplitsMap.size());
-  }
-
-  private void groupSplits () throws IOException {
-    bucketToGroupedSplitMap = 
-      ArrayListMultimap.<Integer, InputSplit>create(bucketToInitialSplitMap);
-    if (conf.getBoolean(GROUP_SPLITS, true)) {
-      estimateBucketSizes();
-      Map<Integer, Collection<InputSplit>> bucketSplitMap = bucketToInitialSplitMap.asMap();
-      for (int bucketId : bucketSplitMap.keySet()) {
-        Collection<InputSplit>inputSplitCollection = bucketSplitMap.get(bucketId);
-        TezMapredSplitsGrouper grouper = new TezMapredSplitsGrouper();
-
-        InputSplit[] groupedSplits = grouper.getGroupedSplits(conf, 
-            inputSplitCollection.toArray(new InputSplit[0]), bucketToNumTaskMap.get(bucketId),
-            HiveInputFormat.class.getName());
-        LOG.info("Original split size is " + 
-            inputSplitCollection.toArray(new InputSplit[0]).length + 
-            " grouped split size is " + groupedSplits.length);
-        bucketToGroupedSplitMap.removeAll(bucketId);
-        for (InputSplit inSplit : groupedSplits) {
-          bucketToGroupedSplitMap.put(bucketId, inSplit);
-        }
-      }
-    }
-  }
-
-  private void estimateBucketSizes() {
-    Map<Integer, Long>bucketSizeMap = new HashMap<Integer, Long>();
-    Map<Integer, Collection<InputSplit>> bucketSplitMap = bucketToInitialSplitMap.asMap();
-    long totalSize = 0;
-    for (int bucketId : bucketSplitMap.keySet()) {
-      Long size = 0L;
-      Collection<InputSplit>inputSplitCollection = bucketSplitMap.get(bucketId);
-      Iterator<InputSplit> iter = inputSplitCollection.iterator();
-      while (iter.hasNext()) {
-        FileSplit fsplit = (FileSplit)iter.next();
-        size += fsplit.getLength();
-        totalSize += fsplit.getLength();
-      }
-      bucketSizeMap.put(bucketId, size);
-    }
-
-    int totalResource = context.getTotalAVailableResource().getMemory();
-    int taskResource = context.getVertexTaskResource().getMemory();
-    float waves = conf.getFloat(
-        TezConfiguration.TEZ_AM_GROUPING_SPLIT_WAVES,
-        TezConfiguration.TEZ_AM_GROUPING_SPLIT_WAVES_DEFAULT);
-
-    int numTasks = (int)((totalResource*waves)/taskResource);
-    LOG.info("Total resource: " + totalResource + " Task Resource: " + taskResource
-        + " waves: " + waves + " total size of splits: " + totalSize + 
-        " total number of tasks: " + numTasks);
-
-    for (int bucketId : bucketSizeMap.keySet()) {
-      int numEstimatedTasks = 0;
-      if (totalSize != 0) {
-        numEstimatedTasks = (int)(numTasks * bucketSizeMap.get(bucketId) / totalSize);
-      }
-      LOG.info("Estimated number of tasks: " + numEstimatedTasks + " for bucket " + bucketId);
-      if (numEstimatedTasks == 0) {
-        numEstimatedTasks = 1;
-      }
-      bucketToNumTaskMap.put(bucketId, numEstimatedTasks);
-    }
-  }
-
-  private static List<TaskLocationHint> createTaskLocationHintsFromSplits(
-      org.apache.hadoop.mapred.InputSplit[] oldFormatSplits) {
-    Iterable<TaskLocationHint> iterable = Iterables.transform(Arrays.asList(oldFormatSplits),
-        new Function<org.apache.hadoop.mapred.InputSplit, TaskLocationHint>() {
-      @Override
-      public TaskLocationHint apply(org.apache.hadoop.mapred.InputSplit input) {
-        try {
-          if (input.getLocations() != null) {
-            return new TaskLocationHint(new HashSet<String>(Arrays.asList(input.getLocations())),
-                null);
-          } else {
-            LOG.info("NULL Location: returning an empty location hint");
-            return new TaskLocationHint(null,null);
-          }
-        } catch (IOException e) {
-          throw new RuntimeException(e);
-        }
-      }
-    });
 
-    return Lists.newArrayList(iterable);
+    return bucketToInitialSplitMap;
   }
 }

Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java?rev=1598829&r1=1598828&r2=1598829&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java Sat May 31 06:31:26 2014
@@ -23,7 +23,6 @@ import java.net.URI;
 import java.net.URISyntaxException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
@@ -51,14 +50,12 @@ import org.apache.hadoop.hive.ql.exec.mr
 import org.apache.hadoop.hive.ql.exec.mr.ExecReducer;
 import org.apache.hadoop.hive.ql.exec.tez.tools.TezMergedLogicalInput;
 import org.apache.hadoop.hive.ql.io.BucketizedHiveInputFormat;
-import org.apache.hadoop.hive.ql.io.CombineHiveInputFormat;
 import org.apache.hadoop.hive.ql.io.HiveInputFormat;
 import org.apache.hadoop.hive.ql.io.HiveKey;
 import org.apache.hadoop.hive.ql.io.HiveOutputFormatImpl;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
 import org.apache.hadoop.hive.ql.plan.MapWork;
-import org.apache.hadoop.hive.ql.plan.PartitionDesc;
 import org.apache.hadoop.hive.ql.plan.ReduceWork;
 import org.apache.hadoop.hive.ql.plan.TezEdgeProperty;
 import org.apache.hadoop.hive.ql.plan.TezEdgeProperty.EdgeType;
@@ -97,10 +94,9 @@ import org.apache.tez.dag.api.OutputDesc
 import org.apache.tez.dag.api.ProcessorDescriptor;
 import org.apache.tez.dag.api.TezException;
 import org.apache.tez.dag.api.Vertex;
-import org.apache.tez.dag.api.VertexGroup;
 import org.apache.tez.dag.api.VertexLocationHint;
 import org.apache.tez.dag.api.VertexManagerPluginDescriptor;
-import org.apache.tez.mapreduce.common.MRInputAMSplitGenerator;
+import org.apache.tez.dag.api.VertexGroup;
 import org.apache.tez.mapreduce.hadoop.InputSplitInfo;
 import org.apache.tez.mapreduce.hadoop.MRHelpers;
 import org.apache.tez.mapreduce.hadoop.MRJobConfig;
@@ -415,7 +411,7 @@ public class DagUtils {
     Vertex map = null;
 
     // use tez to combine splits
-    boolean useTezGroupedSplits = true;
+    boolean useTezGroupedSplits = false;
 
     int numTasks = -1;
     Class amSplitGeneratorClass = null;
@@ -431,44 +427,8 @@ public class DagUtils {
         }
       }
     }
-
-    // we cannot currently allow grouping of splits where each split is a different input format 
-    // or has different deserializers similar to the checks in CombineHiveInputFormat. We do not
-    // need the check for the opList because we will not process different opLists at this time.
-    // Long term fix would be to have a custom input format
-    // logic that groups only the splits that share the same input format
-    Class<?> previousInputFormatClass = null;
-    Class<?> previousDeserializerClass = null;
-    for (String path : mapWork.getPathToPartitionInfo().keySet()) {
-      PartitionDesc pd = mapWork.getPathToPartitionInfo().get(path);
-      Class<?> currentDeserializerClass = pd.getDeserializer(conf).getClass();
-      Class<?> currentInputFormatClass = pd.getInputFileFormatClass();
-      if (previousInputFormatClass == null) {
-        previousInputFormatClass = currentInputFormatClass;
-      }
-      if (previousDeserializerClass == null) {
-        previousDeserializerClass = currentDeserializerClass;
-      }
-      if (LOG.isDebugEnabled()) {
-        LOG.debug("Current input format class = "+currentInputFormatClass+", previous input format class = "
-          + previousInputFormatClass + ", verifying " + " current deserializer class = "
-          + currentDeserializerClass + " previous deserializer class = " + previousDeserializerClass);
-      }
-      if ((currentInputFormatClass != previousInputFormatClass) ||
-          (currentDeserializerClass != previousDeserializerClass)) {
-        useTezGroupedSplits = false;
-        break;
-      }
-    }
     if (vertexHasCustomInput) {
-      // if it is the case of different input formats for different partitions, we cannot group
-      // in the custom vertex for now. Long term, this can be improved to group the buckets that
-      // share the same input format.
-      if (useTezGroupedSplits == false) {
-        conf.setBoolean(CustomPartitionVertex.GROUP_SPLITS, false);
-      } else {
-        conf.setBoolean(CustomPartitionVertex.GROUP_SPLITS, true);
-      }
+      useTezGroupedSplits = false;
       // grouping happens in execution phase. Setting the class to TezGroupedSplitsInputFormat
       // here would cause pre-mature grouping which would be incorrect.
       inputFormatClass = HiveInputFormat.class;
@@ -476,23 +436,19 @@ public class DagUtils {
       // mapreduce.tez.input.initializer.serialize.event.payload should be set to false when using
       // this plug-in to avoid getting a serialized event at run-time.
       conf.setBoolean("mapreduce.tez.input.initializer.serialize.event.payload", false);
-    } else if (useTezGroupedSplits) {
+    } else {
       // we'll set up tez to combine spits for us iff the input format
       // is HiveInputFormat
       if (inputFormatClass == HiveInputFormat.class) {
+        useTezGroupedSplits = true;
         conf.setClass("mapred.input.format.class", TezGroupedSplitsInputFormat.class, InputFormat.class);
-      } else {
-        conf.setClass("mapred.input.format.class", CombineHiveInputFormat.class, InputFormat.class);
-        useTezGroupedSplits = false;
       }
-    } else {
-      conf.setClass("mapred.input.format.class", CombineHiveInputFormat.class, InputFormat.class);
     }
 
     if (HiveConf.getBoolVar(conf, ConfVars.HIVE_AM_SPLIT_GENERATION)) {
       // if we're generating the splits in the AM, we just need to set
       // the correct plugin.
-      amSplitGeneratorClass = MRInputAMSplitGenerator.class;
+      amSplitGeneratorClass = HiveSplitGenerator.class;
     } else {
       // client side split generation means we have to compute them now
       inputSplitInfo = MRHelpers.generateInputSplits(conf,

Added: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java?rev=1598829&view=auto
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java (added)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java Sat May 31 06:31:26 2014
@@ -0,0 +1,199 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.tez;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.ql.exec.Utilities;
+import org.apache.hadoop.hive.ql.io.HiveFileFormatUtils;
+import org.apache.hadoop.hive.ql.plan.MapWork;
+import org.apache.hadoop.hive.ql.plan.PartitionDesc;
+import org.apache.hadoop.mapred.FileSplit;
+import org.apache.hadoop.mapred.InputFormat;
+import org.apache.hadoop.mapred.InputSplit;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.util.ReflectionUtils;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.VertexLocationHint.TaskLocationHint;
+import org.apache.tez.mapreduce.hadoop.InputSplitInfoMem;
+import org.apache.tez.mapreduce.hadoop.MRHelpers;
+import org.apache.tez.mapreduce.protos.MRRuntimeProtos.MRInputUserPayloadProto;
+import org.apache.tez.mapreduce.protos.MRRuntimeProtos.MRSplitProto;
+import org.apache.tez.mapreduce.protos.MRRuntimeProtos.MRSplitsProto;
+import org.apache.tez.runtime.api.Event;
+import org.apache.tez.runtime.api.TezRootInputInitializer;
+import org.apache.tez.runtime.api.TezRootInputInitializerContext;
+import org.apache.tez.runtime.api.events.RootInputConfigureVertexTasksEvent;
+import org.apache.tez.runtime.api.events.RootInputDataInformationEvent;
+
+import com.google.common.collect.ArrayListMultimap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multimap;
+
+/**
+ * This class is used to generate splits inside the AM on the cluster. It
+ * optionally groups together splits based on available head room as well as
+ * making sure that splits from different partitions are only grouped if they
+ * are of the same schema, format and serde
+ */
+public class HiveSplitGenerator implements TezRootInputInitializer {
+
+  private static final Log LOG = LogFactory.getLog(HiveSplitGenerator.class);
+
+  private final SplitGrouper grouper = new SplitGrouper();
+
+  /**
+   * Entry point of the initializer: rebuilds the configuration from the MRInput user payload,
+   * generates the input splits and turns them into the events handed back to the caller.
+   * When the payload names a real input format the splits are produced and grouped by
+   * generateGroupedSplits; otherwise MRHelpers.generateInputSplitsToMem is used.
+   *
+   * @param rootInputContext context that supplies the serialized user payload
+   * @return the events built by createEventList for the generated splits
+   * @throws Exception if parsing the payload or generating the splits fails
+   */
+  @Override
+  public List<Event> initialize(TezRootInputInitializerContext rootInputContext) throws Exception {
+
+    MRInputUserPayloadProto payload =
+        MRHelpers.parseMRInputPayload(rootInputContext.getUserPayload());
+    Configuration conf = MRHelpers.createConfFromByteString(payload.getConfigurationBytes());
+
+    // defaults to serialized events unless the flag was explicitly switched off
+    boolean sendSerializedEvents =
+        conf.getBoolean("mapreduce.tez.input.initializer.serialize.event.payload", true);
+
+    // The JobConf gets the credentials of the current user merged into its own.
+    JobConf jobConf = new JobConf(conf);
+    jobConf.getCredentials().mergeAll(UserGroupInformation.getCurrentUser().getCredentials());
+
+    String realInputFormatName = payload.getInputFormatName();
+    boolean hasRealInputFormat = realInputFormatName != null && !realInputFormatName.isEmpty();
+
+    final InputSplitInfoMem splitInfo;
+    if (!hasRealInputFormat) {
+      splitInfo = MRHelpers.generateInputSplitsToMem(jobConf);
+    } else {
+      splitInfo = generateGroupedSplits(rootInputContext, jobConf, conf, realInputFormatName);
+    }
+
+    return createEventList(sendSerializedEvents, splitInfo);
+  }
+
+  /**
+   * Generates the raw splits with the real input format and groups them. Splits are only
+   * combined within a run of consecutive splits whose partitions share the same input format
+   * class and deserializer class name.
+   *
+   * @param context used to read the total available memory and the memory of a vertex task
+   * @param jobConf job configuration used to look up the MapWork and to compute/group splits
+   * @param conf configuration from which the number of grouping waves is read
+   * @param realInputFormatName class name of the input format that produces the raw splits
+   * @return the grouped splits together with one location hint per grouped split
+   * @throws Exception if the input format class cannot be loaded or a called helper fails
+   */
+  private InputSplitInfoMem generateGroupedSplits(TezRootInputInitializerContext context,
+      JobConf jobConf, Configuration conf, String realInputFormatName) throws Exception {
+
+    // number of tasks that fit into the total available memory
+    int totalResource = context.getTotalAvailableResource().getMemory();
+    int taskResource = context.getVertexTaskResource().getMemory();
+    int availableSlots = totalResource / taskResource;
+
+    float waves =
+        conf.getFloat(TezConfiguration.TEZ_AM_GROUPING_SPLIT_WAVES,
+            TezConfiguration.TEZ_AM_GROUPING_SPLIT_WAVES_DEFAULT);
+
+    MapWork work = Utilities.getMapWork(jobConf);
+
+    LOG.info("Grouping splits for " + work.getName() + ". " + availableSlots + " available slots, "
+        + waves + " waves. Input format is: " + realInputFormatName);
+
+    // Need to instantiate the realInputFormat
+    InputFormat<?, ?> inputFormat =
+        (InputFormat<?, ?>) ReflectionUtils
+            .newInstance(Class.forName(realInputFormatName), jobConf);
+
+    // Create the un-grouped splits
+    InputSplit[] splits = inputFormat.getSplits(jobConf, (int) (availableSlots * waves));
+    LOG.info("Number of input splits: " + splits.length);
+
+    Multimap<Integer, InputSplit> bucketSplitMultiMap =
+        ArrayListMultimap.<Integer, InputSplit> create();
+
+    Class<?> previousInputFormatClass = null;
+    String previousDeserializerClass = null;
+    Map<Map<String, PartitionDesc>, Map<String, PartitionDesc>> cache =
+        new HashMap<Map<String, PartitionDesc>, Map<String, PartitionDesc>>();
+
+    // Id of the current source group. The very first split always opens a new group
+    // (there is no previous format/serde to match), so the ids actually used start at 1.
+    int srcGroup = 0;
+
+    for (InputSplit s : splits) {
+      // this is the bit where we make sure we don't group across partition
+      // schema boundaries
+
+      // NOTE(review): every split is cast to FileSplit - assumes the real input format
+      // only produces file based splits; confirm against the supported formats.
+      Path path = ((FileSplit) s).getPath();
+
+      PartitionDesc pd =
+          HiveFileFormatUtils.getPartitionDescFromPathRecursively(work.getPathToPartitionInfo(),
+              path, cache);
+
+      String currentDeserializerClass = pd.getDeserializerClassName();
+      Class<?> currentInputFormatClass = pd.getInputFileFormatClass();
+
+      // a change of either the file format or the serde starts a new source group
+      if ((currentInputFormatClass != previousInputFormatClass)
+          || (!currentDeserializerClass.equals(previousDeserializerClass))) {
+        ++srcGroup;
+      }
+
+      previousInputFormatClass = currentInputFormatClass;
+      previousDeserializerClass = currentDeserializerClass;
+
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Adding split " + path + " to src group " + srcGroup);
+      }
+      bucketSplitMultiMap.put(srcGroup, s);
+    }
+
+    // Report the number of groups that really exist. Deriving it from the counter as
+    // "counter + 1" over-reports by one, because group ids start at 1 and not at 0.
+    LOG.info("# Src groups for split generation: " + bucketSplitMultiMap.keySet().size());
+
+    // group them into the chunks we want
+    Multimap<Integer, InputSplit> groupedSplits =
+        grouper.group(jobConf, bucketSplitMultiMap, availableSlots, waves);
+
+    // And finally return them in a flat array
+    InputSplit[] flatSplits = groupedSplits.values().toArray(new InputSplit[0]);
+    LOG.info("Number of grouped splits: " + flatSplits.length);
+
+    List<TaskLocationHint> locationHints = grouper.createTaskLocationHints(flatSplits);
+
+    Utilities.clearWork(jobConf);
+
+    return new InputSplitInfoMem(flatSplits, locationHints, flatSplits.length, null, jobConf);
+  }
+
+  /**
+   * Builds the event list for the generated splits: one event that configures the vertex
+   * (number of tasks and location hints) followed by one data information event per split.
+   *
+   * @param sendSerializedEvents if true each split is sent as the byte array of its
+   *          MRSplitProto, otherwise the old format split object itself is attached
+   * @param inputSplitInfo the generated splits
+   * @return the configure event followed by the per-split events, indexed from 0
+   */
+  private List<Event> createEventList(boolean sendSerializedEvents, InputSplitInfoMem inputSplitInfo) {
+
+    int numTasks = inputSplitInfo.getNumTasks();
+
+    // one slot for the configure event plus one per split
+    List<Event> result = Lists.newArrayListWithCapacity(numTasks + 1);
+    result.add(new RootInputConfigureVertexTasksEvent(numTasks,
+        inputSplitInfo.getTaskLocationHints()));
+
+    int index = 0;
+    if (!sendSerializedEvents) {
+      for (InputSplit split : inputSplitInfo.getOldFormatSplits()) {
+        result.add(new RootInputDataInformationEvent(index++, split));
+      }
+      return result;
+    }
+
+    MRSplitsProto serializedSplits = inputSplitInfo.getSplitsProto();
+    for (MRSplitProto serializedSplit : serializedSplits.getSplitsList()) {
+      result.add(new RootInputDataInformationEvent(index++, serializedSplit.toByteArray()));
+    }
+    return result;
+  }
+}

Added: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java?rev=1598829&view=auto
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java (added)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java Sat May 31 06:31:26 2014
@@ -0,0 +1,156 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.tez;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hive.ql.io.HiveInputFormat;
+import org.apache.hadoop.mapred.FileSplit;
+import org.apache.hadoop.mapred.InputSplit;
+import org.apache.hadoop.mapred.split.TezGroupedSplit;
+import org.apache.hadoop.mapred.split.TezMapredSplitsGrouper;
+import org.apache.tez.dag.api.VertexLocationHint.TaskLocationHint;
+
+import com.google.common.collect.ArrayListMultimap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multimap;
+
+/**
+ * SplitGrouper is used to combine splits based on head room and locality. It
+ * also enforces restrictions around schema, file format and bucketing.
+ */
+public class SplitGrouper {
+
+  private static final Log LOG = LogFactory.getLog(SplitGrouper.class);
+
+  private final TezMapredSplitsGrouper tezGrouper = new TezMapredSplitsGrouper();
+
+  /**
+   * Groups the splits of every bucket on its own. The number of grouped splits requested
+   * for a bucket is taken from estimateBucketSizes, which spreads the desired number of
+   * tasks over the buckets.
+   *
+   * @param conf configuration handed to the tez splits grouper
+   * @param bucketSplitMultimap the un-grouped splits keyed by bucket id
+   * @param availableSlots number of available task slots
+   * @param waves factor applied to the available slots
+   * @return the grouped splits keyed by the same bucket ids
+   * @throws IOException declared for failures of the underlying grouper
+   */
+  public Multimap<Integer, InputSplit> group(Configuration conf,
+      Multimap<Integer, InputSplit> bucketSplitMultimap, int availableSlots, float waves)
+      throws IOException {
+
+    Map<Integer, Collection<InputSplit>> splitsPerBucket = bucketSplitMultimap.asMap();
+
+    // desired number of tasks for each bucket
+    Map<Integer, Integer> tasksPerBucket =
+        estimateBucketSizes(availableSlots, waves, splitsPerBucket);
+
+    // bucket id -> grouped splits of that bucket
+    Multimap<Integer, InputSplit> result = ArrayListMultimap.<Integer, InputSplit> create();
+
+    // the tez grouper runs exactly once per bucket
+    for (Map.Entry<Integer, Collection<InputSplit>> bucket : splitsPerBucket.entrySet()) {
+      Integer bucketId = bucket.getKey();
+      InputSplit[] ungrouped = bucket.getValue().toArray(new InputSplit[0]);
+
+      InputSplit[] grouped =
+          tezGrouper.getGroupedSplits(conf, ungrouped, tasksPerBucket.get(bucketId),
+              HiveInputFormat.class.getName());
+
+      LOG.info("Original split size is " + ungrouped.length + " grouped split size is "
+          + grouped.length + ", for bucket: " + bucketId);
+
+      result.putAll(bucketId, Arrays.asList(grouped));
+    }
+
+    return result;
+  }
+
+  /**
+   * Estimates how many tasks each bucket should get. The desired number of tasks
+   * (availableSlots * waves) is distributed proportionally to the summed split length of
+   * each bucket; every bucket gets at least one task.
+   *
+   * @param availableSlots number of available task slots
+   * @param waves factor applied to the available slots
+   * @param bucketSplitMap the splits of each bucket, keyed by bucket id
+   * @return the number of tasks to use per bucket id
+   */
+  private Map<Integer, Integer> estimateBucketSizes(int availableSlots, float waves,
+      Map<Integer, Collection<InputSplit>> bucketSplitMap) {
+
+    // bucket id -> summed length of all splits in the bucket
+    Map<Integer, Long> bytesPerBucket = new HashMap<Integer, Long>();
+    long totalSize = 0;
+
+    for (Map.Entry<Integer, Collection<InputSplit>> bucket : bucketSplitMap.entrySet()) {
+      long bucketBytes = 0;
+      for (InputSplit split : bucket.getValue()) {
+        bucketBytes += ((FileSplit) split).getLength();
+      }
+      totalSize += bucketBytes;
+      bytesPerBucket.put(bucket.getKey(), bucketBytes);
+    }
+
+    // bucket id -> number of tasks to run for the bucket
+    Map<Integer, Integer> tasksPerBucket = new HashMap<Integer, Integer>();
+
+    for (Map.Entry<Integer, Long> bucket : bytesPerBucket.entrySet()) {
+      Integer bucketId = bucket.getKey();
+      long bucketBytes = bucket.getValue();
+
+      // availableSlots * waves is the number of slots we want to fill and
+      // bucketBytes / totalSize is the share of this bucket; the shares sum to 1.
+      // Without any data the estimate stays at 0.
+      int estimate =
+          (totalSize == 0) ? 0 : (int) (availableSlots * waves * bucketBytes / totalSize);
+
+      // the raw estimate is logged before it is raised to the minimum of one task
+      LOG.info("Estimated number of tasks: " + estimate + " for bucket " + bucketId);
+      tasksPerBucket.put(bucketId, (estimate == 0) ? 1 : estimate);
+    }
+
+    return tasksPerBucket;
+  }
+
+  /**
+   * Creates one location hint per split, in the order of the given array. A
+   * TezGroupedSplit that reports a rack yields a rack-only hint; any other split yields a
+   * hint with the hosts returned by getLocations(), or an empty hint if that is null.
+   *
+   * @param splits the splits to create hints for
+   * @return a list with exactly one hint for every split
+   * @throws IOException if the locations of a split cannot be read
+   */
+  public List<TaskLocationHint> createTaskLocationHints(InputSplit[] splits) throws IOException {
+
+    List<TaskLocationHint> hints = Lists.newArrayListWithCapacity(splits.length);
+
+    for (InputSplit split : splits) {
+      String rack = null;
+      if (split instanceof TezGroupedSplit) {
+        rack = ((TezGroupedSplit) split).getRack();
+      }
+
+      // a known rack takes precedence over the individual hosts
+      if (rack != null) {
+        hints.add(new TaskLocationHint(null, Collections.singleton(rack)));
+        continue;
+      }
+
+      String[] hosts = split.getLocations();
+      if (hosts == null) {
+        hints.add(new TaskLocationHint(null, null));
+      } else {
+        hints.add(new TaskLocationHint(new HashSet<String>(Arrays.asList(hosts)), null));
+      }
+    }
+
+    return hints;
+  }
+}

Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/plan/PartitionDesc.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/plan/PartitionDesc.java?rev=1598829&r1=1598828&r2=1598829&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/plan/PartitionDesc.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/plan/PartitionDesc.java Sat May 31 06:31:26 2014
@@ -118,16 +118,23 @@ public class PartitionDesc implements Se
     return inputFileFormatClass;
   }
 
-  /**
-   * Return a deserializer object corresponding to the partitionDesc.
-   */
-  public Deserializer getDeserializer(Configuration conf) throws Exception {
+  public String getDeserializerClassName() {
     Properties schema = getProperties();
     String clazzName = schema.getProperty(serdeConstants.SERIALIZATION_LIB);
     if (clazzName == null) {
       throw new IllegalStateException("Property " + serdeConstants.SERIALIZATION_LIB +
           " cannot be null");
     }
+
+    return clazzName;
+  }
+
+  /**
+   * Return a deserializer object corresponding to the partitionDesc.
+   */
+  public Deserializer getDeserializer(Configuration conf) throws Exception {
+    Properties schema = getProperties();
+    String clazzName = getDeserializerClassName();
     Deserializer deserializer = ReflectionUtils.newInstance(conf.getClassByName(clazzName)
         .asSubclass(Deserializer.class), conf);
     SerDeUtils.initializeSerDe(deserializer, conf, getTableDesc().getProperties(), schema);