You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@tez.apache.org by ss...@apache.org on 2014/03/29 05:20:15 UTC

git commit: TEZ-932. Add a weighted scaling initial memory allocator. (sseth)

Repository: incubator-tez
Updated Branches:
  refs/heads/master 34a557d13 -> 792cda593


TEZ-932. Add a weighted scaling initial memory allocator. (sseth)


Project: http://git-wip-us.apache.org/repos/asf/incubator-tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-tez/commit/792cda59
Tree: http://git-wip-us.apache.org/repos/asf/incubator-tez/tree/792cda59
Diff: http://git-wip-us.apache.org/repos/asf/incubator-tez/diff/792cda59

Branch: refs/heads/master
Commit: 792cda59370ff493b7672da4f5c80e1929eecd61
Parents: 34a557d
Author: Siddharth Seth <ss...@apache.org>
Authored: Fri Mar 28 21:19:01 2014 -0700
Committer: Siddharth Seth <ss...@apache.org>
Committed: Fri Mar 28 21:19:01 2014 -0700

----------------------------------------------------------------------
 pom.xml                                         |   6 +
 .../org/apache/tez/common/TezJobConfig.java     |  51 +++-
 .../apache/tez/dag/api/TezConfiguration.java    |  17 --
 .../resources/InitialMemoryAllocator.java       |  29 ++
 .../resources/InitialMemoryRequestContext.java  |  62 ++++
 .../common/resources/MemoryDistributor.java     | 196 +++----------
 .../common/resources/ScalingAllocator.java      | 107 +++++++
 .../common/resources/TestMemoryDistributor.java |  60 ++--
 tez-runtime-library/pom.xml                     |   6 +
 .../WeightedScalingMemoryDistributor.java       | 283 +++++++++++++++++++
 .../TestWeightedScalingMemoryDistributor.java   | 181 ++++++++++++
 11 files changed, 787 insertions(+), 211 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/792cda59/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index f2b562c..c8767d9 100644
--- a/pom.xml
+++ b/pom.xml
@@ -110,6 +110,12 @@
       </dependency>
       <dependency>
         <groupId>org.apache.tez</groupId>
+        <artifactId>tez-runtime-internals</artifactId>
+        <version>${project.version}</version>
+        <type>test-jar</type>
+      </dependency>
+      <dependency>
+        <groupId>org.apache.tez</groupId>
         <artifactId>tez-runtime-library</artifactId>
         <version>${project.version}</version>
       </dependency>

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/792cda59/tez-api/src/main/java/org/apache/tez/common/TezJobConfig.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/common/TezJobConfig.java b/tez-api/src/main/java/org/apache/tez/common/TezJobConfig.java
index 0af0ffb..e5e409b 100644
--- a/tez-api/src/main/java/org/apache/tez/common/TezJobConfig.java
+++ b/tez-api/src/main/java/org/apache/tez/common/TezJobConfig.java
@@ -19,6 +19,8 @@ package org.apache.tez.common;
 
 import org.apache.hadoop.classification.InterfaceAudience.Private;
 import org.apache.hadoop.classification.InterfaceStability.Evolving;
+import org.apache.hadoop.classification.InterfaceStability.Unstable;
+import org.apache.tez.dag.api.TezConfiguration;
 
 
 /**
@@ -31,6 +33,7 @@ import org.apache.hadoop.classification.InterfaceStability.Evolving;
 public class TezJobConfig {
 
 
+  public static final String TEZ_TASK_PREFIX = TezConfiguration.TEZ_TASK_PREFIX;
 
 
   /** The number of milliseconds between progress reports. */
@@ -263,7 +266,7 @@ public class TezJobConfig {
   public static final String TEZ_RUNTIME_SHUFFLE_MERGE_PERCENT = 
       "tez.runtime.shuffle.merge.percent";
   public static final float DEFAULT_TEZ_RUNTIME_SHUFFLE_MERGE_PERCENT = 0.90f;
-  
+
   /**
    * TODO TEZAM3 default value ?
    */
@@ -344,6 +347,52 @@ public class TezJobConfig {
   
   /** Defines the ProcessTree implementation which will be used to collect resource utilization. */
   public static final String TEZ_RESOURCE_CALCULATOR_PROCESS_TREE_CLASS = "tez.resource.calculator.process-tree.class";
+
+  /**
+   * Whether to scale down memory requested by each component if the total
+   * exceeds the available JVM memory.
+   */
+  @Private @Unstable
+  public static final String TEZ_TASK_SCALE_MEMORY_ENABLED = TEZ_TASK_PREFIX
+      + "scale.memory.enabled";
+  public static final boolean TEZ_TASK_SCALE_MEMORY_ENABLED_DEFAULT = true;
+
+  /**
+   * Fully qualified class name of the InitialMemoryAllocator used for initial memory allocation.
+   */
+  @Private @Unstable
+  public static final String TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS = TEZ_TASK_PREFIX
+      + "scale.memory.allocator.class";
+  public static final String TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS_DEFAULT = "org.apache.tez.runtime.common.resources.ScalingAllocator";
+
+  /**
+   * The fraction of the JVM memory which will not be considered for allocation.
+   * No default is defined here; the allocator applies its own (ScalingAllocator: 0.3).
+   */
+  @Private @Unstable
+  public static final String TEZ_TASK_SCALE_MEMORY_RESERVE_FRACTION = TEZ_TASK_PREFIX
+      + "scale.memory.reserve-fraction";
+
+  /**
+   * Fraction of available memory to reserve per input/output. This amount is removed from
+   * the total available pool before allocation and accounts for per-IO overheads.
+   */
+  @Private @Unstable
+  public static final String TEZ_TASK_SCALE_MEMORY_ADDITIONAL_RESERVATION_FRACTION_PER_IO = TEZ_TASK_PREFIX
+      + "scale.memory.additional.reservation.fraction.per-io";
+
+  /** Max cumulative total reservation for additional IOs. */
+  @Private @Unstable
+  public static final String TEZ_TASK_SCALE_MEMORY_ADDITIONAL_RESERVATION_FRACTION_MAX = TEZ_TASK_PREFIX
+      + "scale.memory.additional.reservation.fraction.max";
+
+  /**
+   * Weighted ratios for individual component types in the RuntimeLibrary.
+   * e.g. PARTITIONED_UNSORTED_OUTPUT:0,UNSORTED_INPUT:1,SORTED_OUTPUT:2,SORTED_MERGED_INPUT:3,PROCESSOR:1,OTHER:1
+   */
+  @Private @Unstable
+  public static final String TEZ_TASK_SCALE_MEMORY_WEIGHTED_RATIOS = TEZ_TASK_PREFIX
+      + "initial.memory.scale.ratios";
   
   /**
    * Path to a credentials file located on the local file system with serialized credentials

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/792cda59/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java b/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
index 3d42577..63b0335 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
@@ -174,23 +174,6 @@ public class TezConfiguration extends Configuration {
   public static final String TEZ_TASK_MAX_EVENTS_PER_HEARTBEAT = TEZ_TASK_PREFIX
       + "max-events-per-heartbeat.max";
   public static final int TEZ_TASK_MAX_EVENTS_PER_HEARTBEAT_DEFAULT = 100;
-  
-  /**
-   * Whether to scale down memory requested by each component if the total
-   * exceeds the available JVM memory
-   */
-  @Unstable
-  public static final String TEZ_TASK_SCALE_MEMORY_ENABLED = TEZ_TASK_PREFIX
-      + "scale.memory.enabled";
-  public static final boolean TEZ_TASK_SCALE_MEMORY_ENABLED_DEFAULT = true;
-  
-  /**
-   * The fraction of the JVM memory which will not be considered for allocation.
-   * No defaults, since there are pre-existing defaults based on different scenarios.
-   */
-  @Unstable
-  public static final String TEZ_TASK_SCALE_MEMORY_RESERVE_FRACTION = TEZ_TASK_PREFIX
-      + "scale.memory.reserve-fraction";
 
   /**
    * Whether to generate counters per IO or not. Enabling this will rename

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/792cda59/tez-api/src/main/java/org/apache/tez/runtime/common/resources/InitialMemoryAllocator.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/runtime/common/resources/InitialMemoryAllocator.java b/tez-api/src/main/java/org/apache/tez/runtime/common/resources/InitialMemoryAllocator.java
new file mode 100644
index 0000000..62b9433
--- /dev/null
+++ b/tez-api/src/main/java/org/apache/tez/runtime/common/resources/InitialMemoryAllocator.java
@@ -0,0 +1,29 @@
+package org.apache.tez.runtime.common.resources;
+
+import org.apache.hadoop.classification.InterfaceAudience.Private;
+import org.apache.hadoop.conf.Configurable;
+
+/**
+ * Used to balance memory requests before a task starts executing.
+ * {@code MemoryDistributor} instantiates the class named by
+ * {@code TezJobConfig.TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS} and calls setConf before assignMemory.
+ */
+@Private
+public interface InitialMemoryAllocator extends Configurable {
+
+  /**
+   * Assigns an initial memory allocation to each of the given requests.
+   * @param availableForAllocation
+   *          memory available for allocation
+   * @param numTotalInputs
+   *          number of inputs for the task
+   * @param numTotalOutputs
+   *          number of outputs for the task
+   * @param requests
+   *          Iterable view of requests received
+   * @return allocations, one per request, which must be in the same order as {@code requests}
+   */
+  Iterable<Long> assignMemory(long availableForAllocation, int numTotalInputs,
+      int numTotalOutputs, Iterable<InitialMemoryRequestContext> requests);
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/792cda59/tez-api/src/main/java/org/apache/tez/runtime/common/resources/InitialMemoryRequestContext.java
----------------------------------------------------------------------
diff --git a/tez-api/src/main/java/org/apache/tez/runtime/common/resources/InitialMemoryRequestContext.java b/tez-api/src/main/java/org/apache/tez/runtime/common/resources/InitialMemoryRequestContext.java
new file mode 100644
index 0000000..4f3fc46
--- /dev/null
+++ b/tez-api/src/main/java/org/apache/tez/runtime/common/resources/InitialMemoryRequestContext.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.common.resources;
+
+import org.apache.hadoop.classification.InterfaceAudience.Private;
+
+/** An initial memory request made by a task's input, output or processor. */
+@Private
+public class InitialMemoryRequestContext {
+
+  public static enum ComponentType {
+    INPUT, OUTPUT, PROCESSOR
+  }
+
+  private final long requestedSize;
+  // TODO Replace this with the entire descriptor at some point. ComponentType
+  // automatically goes away.
+  private final String componentClassName;
+  private final ComponentType componentType;
+  private final String componentVertexName;
+
+  public InitialMemoryRequestContext(long requestedSize, String componentClassName,
+      ComponentType componentType, String componentVertexName) {
+    this.requestedSize = requestedSize;
+    this.componentClassName = componentClassName;
+    this.componentType = componentType;
+    this.componentVertexName = componentVertexName;
+  }
+
+  public long getRequestedSize() {
+    return requestedSize;
+  }
+
+  public String getComponentClassName() {
+    return componentClassName;
+  }
+
+  public ComponentType getComponentType() {
+    return componentType;
+  }
+
+  public String getComponentVertexName() {
+    return componentVertexName;
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/792cda59/tez-runtime-internals/src/main/java/org/apache/tez/runtime/common/resources/MemoryDistributor.java
----------------------------------------------------------------------
diff --git a/tez-runtime-internals/src/main/java/org/apache/tez/runtime/common/resources/MemoryDistributor.java b/tez-runtime-internals/src/main/java/org/apache/tez/runtime/common/resources/MemoryDistributor.java
index 91a0b22..7126577 100644
--- a/tez-runtime-internals/src/main/java/org/apache/tez/runtime/common/resources/MemoryDistributor.java
+++ b/tez-runtime-internals/src/main/java/org/apache/tez/runtime/common/resources/MemoryDistributor.java
@@ -18,7 +18,6 @@
 
 package org.apache.tez.runtime.common.resources;
 
-import java.text.DecimalFormat;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.LinkedList;
@@ -31,7 +30,8 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.classification.InterfaceAudience.Private;
 import org.apache.hadoop.conf.Configuration;
-import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.common.RuntimeUtils;
+import org.apache.tez.common.TezJobConfig;
 import org.apache.tez.dag.api.TezEntityDescriptor;
 import org.apache.tez.dag.api.TezUncheckedException;
 import org.apache.tez.runtime.api.MemoryUpdateCallback;
@@ -44,7 +44,6 @@ import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Function;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Iterables;
-import com.google.common.collect.Lists;
 
 // Not calling this a MemoryManager explicitly. Not yet anyway.
 @Private
@@ -54,24 +53,16 @@ public class MemoryDistributor {
 
   private final int numTotalInputs;
   private final int numTotalOutputs;
+  private final Configuration conf;
   
   private AtomicInteger numInputsSeen = new AtomicInteger(0);
   private AtomicInteger numOutputsSeen = new AtomicInteger(0);
 
   private long totalJvmMemory;
-  private volatile long totalAssignableMemory;
   private final boolean isEnabled;
-  private final boolean reserveFractionConfigured;
-  private float reserveFraction;
   private final Set<TezTaskContext> dupSet = Collections
       .newSetFromMap(new ConcurrentHashMap<TezTaskContext, Boolean>());
   private final List<RequestorInfo> requestList;
-  
-  // Maybe make the reserve fraction configurable. Or scale it based on JVM heap.
-  @VisibleForTesting
-  static final float RESERVE_FRACTION_NO_PROCESSOR = 0.3f;
-  @VisibleForTesting
-  static final float RESERVE_FRACTION_WITH_PROCESSOR = 0.05f;
 
   /**
    * @param numInputs
@@ -82,28 +73,18 @@ public class MemoryDistributor {
    *          Tez specific task configuration
    */
   public MemoryDistributor(int numTotalInputs, int numTotalOutputs, Configuration conf) {
-    isEnabled = conf.getBoolean(TezConfiguration.TEZ_TASK_SCALE_MEMORY_ENABLED,
-        TezConfiguration.TEZ_TASK_SCALE_MEMORY_ENABLED_DEFAULT);
-    if (conf.get(TezConfiguration.TEZ_TASK_SCALE_MEMORY_RESERVE_FRACTION) != null) {
-      reserveFractionConfigured = true;
-      reserveFraction = conf.getFloat(TezConfiguration.TEZ_TASK_SCALE_MEMORY_RESERVE_FRACTION,
-          RESERVE_FRACTION_NO_PROCESSOR);
-      Preconditions.checkArgument(reserveFraction >= 0.0f && reserveFraction <= 1.0f);
-    } else {
-      reserveFractionConfigured = false;
-      reserveFraction = RESERVE_FRACTION_NO_PROCESSOR;
-    }
+    this.conf = conf;
+    isEnabled = conf.getBoolean(TezJobConfig.TEZ_TASK_SCALE_MEMORY_ENABLED,
+        TezJobConfig.TEZ_TASK_SCALE_MEMORY_ENABLED_DEFAULT);
+    
 
     this.numTotalInputs = numTotalInputs;
     this.numTotalOutputs = numTotalOutputs;
     this.totalJvmMemory = Runtime.getRuntime().maxMemory();
-    computeAssignableMemory();
     this.requestList = Collections.synchronizedList(new LinkedList<RequestorInfo>());
     LOG.info("InitialMemoryDistributor (isEnabled=" + isEnabled + ") invoked with: numInputs="
         + numTotalInputs + ", numOutputs=" + numTotalOutputs
-        + ". Configuration: reserveFractionSpecified= " + reserveFractionConfigured
-        + ", reserveFraction=" + reserveFraction + ", JVM.maxFree=" + totalJvmMemory
-        + ", assignableMemory=" + totalAssignableMemory);
+        + ", JVM.maxFree=" + totalJvmMemory);
   }
 
 
@@ -123,9 +104,9 @@ public class MemoryDistributor {
   public void makeInitialAllocations() {
     Preconditions.checkState(numInputsSeen.get() == numTotalInputs, "All inputs are expected to ask for memory");
     Preconditions.checkState(numOutputsSeen.get() == numTotalOutputs, "All outputs are expected to ask for memory");
-    Iterable<RequestContext> requestContexts = Iterables.transform(requestList,
-        new Function<RequestorInfo, RequestContext>() {
-          public RequestContext apply(RequestorInfo requestInfo) {
+    Iterable<InitialMemoryRequestContext> requestContexts = Iterables.transform(requestList,
+        new Function<RequestorInfo, InitialMemoryRequestContext>() {
+          public InitialMemoryRequestContext apply(RequestorInfo requestInfo) {
             return requestInfo.getRequestContext();
           }
         });
@@ -138,8 +119,12 @@ public class MemoryDistributor {
         }
       });
     } else {
-      InitialMemoryAllocator allocator = new ScalingAllocator();
-      allocations = allocator.assignMemory(totalAssignableMemory, numTotalInputs, numTotalOutputs,
+      String allocatorClassName = conf.get(TezJobConfig.TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS,
+          TezJobConfig.TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS_DEFAULT);
+      LOG.info("Using Allocator class: " + allocatorClassName);
+      InitialMemoryAllocator allocator = RuntimeUtils.createClazzInstance(allocatorClassName);
+      allocator.setConf(conf);
+      allocations = allocator.assignMemory(totalJvmMemory, numTotalInputs, numTotalOutputs,
           Iterables.unmodifiableIterable(requestContexts));
       validateAllocations(allocations, requestList.size());
     }
@@ -166,11 +151,6 @@ public class MemoryDistributor {
   @VisibleForTesting
   void setJvmMemory(long size) {
     this.totalJvmMemory = size;
-    computeAssignableMemory();
-  }
-  
-  private void computeAssignableMemory() {
-    this.totalAssignableMemory = totalJvmMemory - ((long) (reserveFraction * totalJvmMemory));
   }
 
   private long registerRequest(long requestSize, MemoryUpdateCallback callback,
@@ -203,13 +183,6 @@ public class MemoryDistributor {
       break;
     }
     requestList.add(requestInfo);
-    if (!reserveFractionConfigured
-        && requestInfo.getRequestContext().getComponentType() == RequestContext.ComponentType.PROCESSOR) {
-      reserveFraction = RESERVE_FRACTION_WITH_PROCESSOR;
-      computeAssignableMemory();
-      LOG.info("Processor request for initial memory. Updating assignableMemory to : "
-          + totalAssignableMemory);
-    }
     return -1;
   }
 
@@ -224,155 +197,50 @@ public class MemoryDistributor {
     Preconditions.checkState(numAllocations == numRequestors,
         "Number of allocations must match number of requestors. Allocated=" + numAllocations
             + ", Requests: " + numRequestors);
-    Preconditions.checkState(totalAllocated <= totalAssignableMemory,
+    Preconditions.checkState(totalAllocated <= totalJvmMemory,
         "Total allocation should be <= availableMem. TotalAllocated: " + totalAllocated
-            + ", totalAssignable: " + totalAssignableMemory);
-  }
-
-  /**
-   * Used to balance memory requests before a task starts executing.
-   */
-  private static interface InitialMemoryAllocator {
-    /**
-     * @param availableForAllocation
-     *          memory available for allocation
-     * @param numTotalInputs
-     *          number of inputs for the task
-     * @param numTotalOutputs
-     *          number of outputs for the tasks
-     * @param requests
-     *          Iterable view of requests received
-     * @return list of allocations, one per request. This must be ordered in the
-     *         same order of the requests.
-     */
-    Iterable<Long> assignMemory(long availableForAllocation, int numTotalInputs,
-        int numTotalOutputs, Iterable<RequestContext> requests);
+            + ", totalJvmMemory: " + totalJvmMemory);
   }
 
-  // Make this a public class if pulling the interface out.
-  // Custom allocator based on The classes being used. Broadcast typically needs
-  // a lot less than sort etc.
-  private static class RequestContext {
-
-    private static enum ComponentType {
-      INPUT, OUTPUT, PROCESSOR
-    }
-
-    private long requestedSize;
-    private String componentClassName;
-    private ComponentType componentType;
-    private String componentVertexName;
 
-    public RequestContext(long requestedSize, String componentClassName,
-        ComponentType componentType, String componentVertexName) {
-      this.requestedSize = requestedSize;
-      this.componentClassName = componentClassName;
-      this.componentType = componentType;
-      this.componentVertexName = componentVertexName;
-    }
-
-    public long getRequestedSize() {
-      return requestedSize;
-    }
-
-    public String getComponentClassName() {
-      return componentClassName;
-    }
-
-    public ComponentType getComponentType() {
-      return componentType;
-    }
+  private static class RequestorInfo {
 
-    public String getComponentVertexName() {
-      return componentVertexName;
-    }
-  }
+    private static final Log LOG = LogFactory.getLog(RequestorInfo.class);
 
-  @Private
-  private static class RequestorInfo {
     private final MemoryUpdateCallback callback;
-    private final RequestContext requestContext;
+    private final InitialMemoryRequestContext requestContext;
 
-    RequestorInfo(TezTaskContext taskContext, long requestSize,
+    public RequestorInfo(TezTaskContext taskContext, long requestSize,
         final MemoryUpdateCallback callback, TezEntityDescriptor descriptor) {
-      RequestContext.ComponentType type;
+      InitialMemoryRequestContext.ComponentType type;
       String componentVertexName;
       if (taskContext instanceof TezInputContext) {
-        type = RequestContext.ComponentType.INPUT;
+        type = InitialMemoryRequestContext.ComponentType.INPUT;
         componentVertexName = ((TezInputContext) taskContext).getSourceVertexName();
       } else if (taskContext instanceof TezOutputContext) {
-        type = RequestContext.ComponentType.OUTPUT;
+        type = InitialMemoryRequestContext.ComponentType.OUTPUT;
         componentVertexName = ((TezOutputContext) taskContext).getDestinationVertexName();
       } else if (taskContext instanceof TezProcessorContext) {
-        type = RequestContext.ComponentType.PROCESSOR;
+        type = InitialMemoryRequestContext.ComponentType.PROCESSOR;
         componentVertexName = ((TezProcessorContext) taskContext).getTaskVertexName();
       } else {
         throw new IllegalArgumentException("Unknown type of entityContext: "
             + taskContext.getClass().getName());
       }
-      this.requestContext = new RequestContext(requestSize, descriptor.getClassName(), type,
-          componentVertexName);
+      this.requestContext = new InitialMemoryRequestContext(requestSize, descriptor.getClassName(),
+          type, componentVertexName);
       this.callback = callback;
-      LOG.info("Received request: " + requestSize + ", type: " + type
-          + ", componentVertexName: " + componentVertexName);
+      LOG.info("Received request: " + requestSize + ", type: " + type + ", componentVertexName: "
+          + componentVertexName);
     }
 
     public MemoryUpdateCallback getCallback() {
       return callback;
     }
 
-    public RequestContext getRequestContext() {
+    public InitialMemoryRequestContext getRequestContext() {
       return requestContext;
     }
   }
 
-  private static class ScalingAllocator implements InitialMemoryAllocator {
-
-    @Override
-    public Iterable<Long> assignMemory(long availableForAllocation, int numTotalInputs,
-        int numTotalOutputs, Iterable<RequestContext> requests) {
-      int numRequests = 0;
-      long totalRequested = 0;
-      for (RequestContext context : requests) {
-        totalRequested += context.getRequestedSize();
-        numRequests++;
-      }
-
-      long totalJvmMem = Runtime.getRuntime().maxMemory();
-      double ratio = totalRequested / (double) totalJvmMem;
-      LOG.info("Scaling Requests. TotalRequested: " + totalRequested + ", TotalJVMMem: "
-          + totalJvmMem + ", TotalAvailable: " + availableForAllocation
-          + ", TotalRequested/TotalHeap:" + new DecimalFormat("0.00").format(ratio));
-
-      if (totalRequested < availableForAllocation || totalRequested == 0) {
-        // Not scaling up requests. Assuming things were setup correctly by
-        // users in this case, keeping Processor, caching etc in mind.
-        return Lists.newArrayList(Iterables.transform(requests,
-            new Function<RequestContext, Long>() {
-              public Long apply(RequestContext requestContext) {
-                return requestContext.getRequestedSize();
-              }
-            }));
-      }
-
-      List<Long> allocations = Lists.newArrayListWithCapacity(numRequests);
-      for (RequestContext request : requests) {
-        long requestedSize = request.getRequestedSize();
-        if (requestedSize == 0) {
-          allocations.add(0l);
-          if (LOG.isDebugEnabled()) {
-            LOG.debug("Scaling requested: 0 to allocated: 0");
-          }
-        } else {
-          long allocated = (long) ((requestedSize / (double) totalRequested) * availableForAllocation);
-          allocations.add(allocated);
-          if (LOG.isDebugEnabled()) {
-            LOG.debug("Scaling requested: " + requestedSize + " to allocated: " + allocated);  
-          }
-          
-        }
-      }
-      return allocations;
-    }
-  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/792cda59/tez-runtime-internals/src/main/java/org/apache/tez/runtime/common/resources/ScalingAllocator.java
----------------------------------------------------------------------
diff --git a/tez-runtime-internals/src/main/java/org/apache/tez/runtime/common/resources/ScalingAllocator.java b/tez-runtime-internals/src/main/java/org/apache/tez/runtime/common/resources/ScalingAllocator.java
new file mode 100644
index 0000000..655521a
--- /dev/null
+++ b/tez-runtime-internals/src/main/java/org/apache/tez/runtime/common/resources/ScalingAllocator.java
@@ -0,0 +1,107 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.common.resources;
+
+import java.text.DecimalFormat;
+import java.util.List;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.common.TezJobConfig;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+
+public class ScalingAllocator implements InitialMemoryAllocator {
+
+  private static final Log LOG = LogFactory.getLog(ScalingAllocator.class);
+
+  @VisibleForTesting
+  static final double DEFAULT_RESERVE_FRACTION = 0.3d;
+
+  private Configuration conf;
+
+  @Override
+  public Iterable<Long> assignMemory(long availableForAllocation, int numTotalInputs,
+      int numTotalOutputs, Iterable<InitialMemoryRequestContext> requests) {
+    int numRequests = 0;
+    long totalRequested = 0;
+    for (InitialMemoryRequestContext context : requests) {
+      totalRequested += context.getRequestedSize();
+      numRequests++;
+    }
+
+    // Take a certain amount of memory away for general usage (default if setConf wasn't called).
+    double reserveFraction = conf == null ? DEFAULT_RESERVE_FRACTION : conf.getDouble(
+        TezJobConfig.TEZ_TASK_SCALE_MEMORY_RESERVE_FRACTION, DEFAULT_RESERVE_FRACTION);
+    Preconditions.checkState(reserveFraction >= 0.0d && reserveFraction <= 1.0d,
+        "Reserve fraction must be within [0, 1]: " + reserveFraction);
+    // Report against the heap size handed in by the caller, not Runtime's global view.
+    long totalJvmMem = availableForAllocation;
+    availableForAllocation = (long) (availableForAllocation - (reserveFraction * availableForAllocation));
+    double ratio = totalRequested / (double) totalJvmMem;
+    LOG.info("Scaling Requests. TotalRequested: " + totalRequested + ", TotalJVMHeap: "
+        + totalJvmMem + ", TotalAvailable: " + availableForAllocation
+        + ", TotalRequested/TotalJVMHeap:" + new DecimalFormat("0.00").format(ratio));
+
+    if (totalRequested < availableForAllocation || totalRequested == 0) {
+      // Not scaling up requests. Assuming things were setup correctly by
+      // users in this case, keeping Processor, caching etc in mind.
+      return Lists.newArrayList(Iterables.transform(requests,
+          new Function<InitialMemoryRequestContext, Long>() {
+            public Long apply(InitialMemoryRequestContext requestContext) {
+              return requestContext.getRequestedSize();
+            }
+          }));
+    }
+
+    // Otherwise scale each request by its share of the total requested.
+    List<Long> allocations = Lists.newArrayListWithCapacity(numRequests);
+    for (InitialMemoryRequestContext request : requests) {
+      long requestedSize = request.getRequestedSize();
+      if (requestedSize == 0) {
+        allocations.add(0L);
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Scaling requested: 0 to allocated: 0");
+        }
+      } else {
+        long allocated = (long) ((requestedSize / (double) totalRequested) * availableForAllocation);
+        allocations.add(allocated);
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Scaling requested: " + requestedSize + " to allocated: " + allocated);
+        }
+      }
+    }
+    return allocations;
+  }
+
+  @Override
+  public void setConf(Configuration conf) {
+    this.conf = conf;
+  }
+
+  @Override
+  public Configuration getConf() {
+    return this.conf;
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/792cda59/tez-runtime-internals/src/test/java/org/apache/tez/runtime/common/resources/TestMemoryDistributor.java
----------------------------------------------------------------------
diff --git a/tez-runtime-internals/src/test/java/org/apache/tez/runtime/common/resources/TestMemoryDistributor.java b/tez-runtime-internals/src/test/java/org/apache/tez/runtime/common/resources/TestMemoryDistributor.java
index f5bc8d6..61c2e86 100644
--- a/tez-runtime-internals/src/test/java/org/apache/tez/runtime/common/resources/TestMemoryDistributor.java
+++ b/tez-runtime-internals/src/test/java/org/apache/tez/runtime/common/resources/TestMemoryDistributor.java
@@ -24,23 +24,30 @@ import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 
 import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.common.TezJobConfig;
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.OutputDescriptor;
 import org.apache.tez.dag.api.ProcessorDescriptor;
-import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.runtime.api.MemoryUpdateCallback;
 import org.apache.tez.runtime.api.TezInputContext;
 import org.apache.tez.runtime.api.TezOutputContext;
 import org.apache.tez.runtime.api.TezProcessorContext;
+import org.junit.Before;
 import org.junit.Test;
 
 public class TestMemoryDistributor {
 
-
+  protected Configuration conf = new Configuration(); // Base config for all tests; protected so subclasses can adjust it in setup().
+  
+  @Before // Runs before each test: enables memory scaling and selects ScalingAllocator.
+  public void setup() {
+    conf.setBoolean(TezJobConfig.TEZ_TASK_SCALE_MEMORY_ENABLED, true);
+    conf.set(TezJobConfig.TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS,
+        ScalingAllocator.class.getName());
+  }
+  
   @Test(timeout = 5000)
   public void testScalingNoProcessor() {
-    Configuration conf = new Configuration();
-    conf.setBoolean(TezConfiguration.TEZ_TASK_SCALE_MEMORY_ENABLED, true);
     MemoryDistributor dist = new MemoryDistributor(2, 1, conf);
     
     dist.setJvmMemory(10000l);
@@ -76,11 +83,9 @@ public class TestMemoryDistributor {
   @Test(timeout = 5000)
   public void testScalingNoProcessor2() {
     // Real world values
-    Configuration conf = new Configuration();
-    conf.setBoolean(TezConfiguration.TEZ_TASK_SCALE_MEMORY_ENABLED, true);
     MemoryDistributor dist = new MemoryDistributor(2, 0, conf);
     
-    dist.setJvmMemory(207093760l);
+    dist.setJvmMemory(209715200l);
 
     // First request
     MemoryUpdateCallbackForTest e1Callback = new MemoryUpdateCallbackForTest();
@@ -92,18 +97,16 @@ public class TestMemoryDistributor {
     MemoryUpdateCallbackForTest e2Callback = new MemoryUpdateCallbackForTest();
     TezInputContext e2InputContext2 = createTestInputContext();
     InputDescriptor e2InDesc2 = createTestInputDescriptor();
-    dist.requestMemory(144965632l, e2Callback, e2InputContext2, e2InDesc2);
+    dist.requestMemory(157286400l, e2Callback, e2InputContext2, e2InDesc2);
     
     dist.makeInitialAllocations();
 
-    assertEquals(60846013, e1Callback.assigned);
-    assertEquals(84119614, e2Callback.assigned);
+    assertEquals(58720256l, e1Callback.assigned);
+    assertEquals(88080384l, e2Callback.assigned);
   }
   
   @Test(timeout = 5000)
   public void testScalingProcessor() {
-    Configuration conf = new Configuration();
-    conf.setBoolean(TezConfiguration.TEZ_TASK_SCALE_MEMORY_ENABLED, true);
     MemoryDistributor dist = new MemoryDistributor(2, 1, conf);
     
     dist.setJvmMemory(10000l);
@@ -135,20 +138,20 @@ public class TestMemoryDistributor {
     
     dist.makeInitialAllocations();
     
-    // Total available: 95% of 10K = 9500
+    // Total available: 70% of 10K = 7000
     // 4 requests - 10K, 10K, 5K, 5K
-    // Scale down to - 3166.66, 3166.66, 1583.33, 1583.33
-    assertTrue(e1Callback.assigned >= 3166 && e1Callback.assigned <= 3177);
-    assertTrue(e2Callback.assigned >= 3166 && e2Callback.assigned <= 3177);
-    assertTrue(e3Callback.assigned >= 1583 && e3Callback.assigned <= 1583);
-    assertTrue(e4Callback.assigned >= 1583 && e4Callback.assigned <= 1583);
+    // Scale down to - 2333.33, 2333.33, 1166.66, 1166.66
+    assertTrue(e1Callback.assigned >= 2333 && e1Callback.assigned <= 2334);
+    assertTrue(e2Callback.assigned >= 2333 && e2Callback.assigned <= 2334);
+    assertTrue(e3Callback.assigned >= 1166 && e3Callback.assigned <= 1167);
+    assertTrue(e4Callback.assigned >= 1166 && e4Callback.assigned <= 1167);
   }
   
   @Test(timeout = 5000)
   public void testScalingDisabled() {
     // Real world values
-    Configuration conf = new Configuration();
-    conf.setBoolean(TezConfiguration.TEZ_TASK_SCALE_MEMORY_ENABLED, false);
+    Configuration conf = new Configuration(this.conf);
+    conf.setBoolean(TezJobConfig.TEZ_TASK_SCALE_MEMORY_ENABLED, false);
     MemoryDistributor dist = new MemoryDistributor(2, 0, conf);
     
     dist.setJvmMemory(207093760l);
@@ -173,9 +176,8 @@ public class TestMemoryDistributor {
   
   @Test(timeout = 5000)
   public void testReserveFractionConfigured() {
-    Configuration conf = new Configuration();
-    conf.setBoolean(TezConfiguration.TEZ_TASK_SCALE_MEMORY_ENABLED, true);
-    conf.setFloat(TezConfiguration.TEZ_TASK_SCALE_MEMORY_RESERVE_FRACTION, 0.5f);
+    Configuration conf = new Configuration(this.conf);
+    conf.setDouble(TezJobConfig.TEZ_TASK_SCALE_MEMORY_RESERVE_FRACTION, 0.5d);
     MemoryDistributor dist = new MemoryDistributor(2, 1, conf);
     
     dist.setJvmMemory(10000l);
@@ -219,39 +221,39 @@ public class TestMemoryDistributor {
     }
   }
 
-  private InputDescriptor createTestInputDescriptor() {
+  protected InputDescriptor createTestInputDescriptor() { // Mock whose getClassName() returns "InputClass".
     InputDescriptor desc = mock(InputDescriptor.class);
     doReturn("InputClass").when(desc).getClassName();
     return desc;
   }
 
-  private OutputDescriptor createTestOutputDescriptor() {
+  protected OutputDescriptor createTestOutputDescriptor() { // Mock whose getClassName() returns "OutputClass"; also used by subclass tests.
     OutputDescriptor desc = mock(OutputDescriptor.class);
     doReturn("OutputClass").when(desc).getClassName();
     return desc;
   }
 
-  private ProcessorDescriptor createTestProcessorDescriptor() {
+  protected ProcessorDescriptor createTestProcessorDescriptor() { // Mock whose getClassName() returns "ProcessorClass".
     ProcessorDescriptor desc = mock(ProcessorDescriptor.class);
     doReturn("ProcessorClass").when(desc).getClassName();
     return desc;
   }
 
-  private TezInputContext createTestInputContext() {
+  protected TezInputContext createTestInputContext() { // Mock: source vertex "input", task vertex "task"; also used by subclass tests.
     TezInputContext context = mock(TezInputContext.class);
     doReturn("input").when(context).getSourceVertexName();
     doReturn("task").when(context).getTaskVertexName();
     return context;
   }
   
-  private TezOutputContext createTestOutputContext() {
+  protected TezOutputContext createTestOutputContext() { // Mock: destination vertex "output", task vertex "task"; also used by subclass tests.
     TezOutputContext context = mock(TezOutputContext.class);
     doReturn("output").when(context).getDestinationVertexName();
     doReturn("task").when(context).getTaskVertexName();
     return context;
   }
   
-  private TezProcessorContext createTestProcessortContext() {
+  protected TezProcessorContext createTestProcessortContext() {
     TezProcessorContext context = mock(TezProcessorContext.class);
     doReturn("task").when(context).getTaskVertexName();
     return context;

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/792cda59/tez-runtime-library/pom.xml
----------------------------------------------------------------------
diff --git a/tez-runtime-library/pom.xml b/tez-runtime-library/pom.xml
index 88becde..c14fc6d 100644
--- a/tez-runtime-library/pom.xml
+++ b/tez-runtime-library/pom.xml
@@ -39,6 +39,12 @@
       <scope>test</scope>
     </dependency>
     <dependency>
+      <groupId>org.apache.tez</groupId>
+      <artifactId>tez-runtime-internals</artifactId>
+      <scope>test</scope>
+      <type>test-jar</type>
+    </dependency>
+    <dependency>
       <groupId>org.mockito</groupId>
       <artifactId>mockito-all</artifactId>
       <scope>test</scope>

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/792cda59/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/resources/WeightedScalingMemoryDistributor.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/resources/WeightedScalingMemoryDistributor.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/resources/WeightedScalingMemoryDistributor.java
new file mode 100644
index 0000000..c295601
--- /dev/null
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/resources/WeightedScalingMemoryDistributor.java
@@ -0,0 +1,365 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.resources;
+
+import java.text.DecimalFormat;
+import java.util.EnumMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.classification.InterfaceAudience.Private;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.common.TezJobConfig;
+import org.apache.tez.runtime.common.resources.InitialMemoryAllocator;
+import org.apache.tez.runtime.common.resources.InitialMemoryRequestContext;
+import org.apache.tez.runtime.library.input.ShuffledMergedInput;
+import org.apache.tez.runtime.library.input.ShuffledMergedInputLegacy;
+import org.apache.tez.runtime.library.input.ShuffledUnorderedKVInput;
+import org.apache.tez.runtime.library.output.OnFileSortedOutput;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+/**
+ * Distributes memory between various requesting components by applying a
+ * weighted scaling function. Overall, ensures that all requestors stay within
+ * the memory made available for allocation.
+ * <p>
+ * Configuration involves specifying a weight for each {@link RequestType} via
+ * {@link TezJobConfig#TEZ_TASK_SCALE_MEMORY_WEIGHTED_RATIOS}, as exactly one
+ * {@code TYPE:weight} entry per type; every type defaults to weight 1 when
+ * nothing is configured. As an example, SortedShuffle : SortedOutput :
+ * UnsortedShuffle could be configured to be 20:10:1. In this case, if both
+ * SortedShuffle and UnsortedShuffle ask for the same amount of initial memory,
+ * SortedShuffle will be given 20 times more; both may be scaled down to fit
+ * within the available memory though. No request is given more than it asked for.
+ * <p>
+ * Before scaling, a fraction of the available memory is held back: a base
+ * reserve fraction plus a per-request addition, the addition capped at a
+ * configured maximum.
+ * <p>
+ * Not thread-safe: instances keep per-call state, which is reset at the start
+ * of every {@link #assignMemory(long, int, int, Iterable)} call.
+ */
+public class WeightedScalingMemoryDistributor implements InitialMemoryAllocator {
+
+  private static final Log LOG = LogFactory.getLog(WeightedScalingMemoryDistributor.class);
+
+  /** Base reserve fraction used when none is configured. */
+  @VisibleForTesting
+  static final double DEFAULT_RESERVE_FRACTION = 0.25d;
+
+  /**
+   * Default cap on the additional reservation summed over all requests.
+   * Despite its name, this is a total, not a per-request value.
+   */
+  static final double MAX_ADDITIONAL_RESERVATION_FRACTION_PER_IO = 0.3d;
+  /** Default additional reserve fraction added for each request. */
+  static final double RESERVATION_FRACTION_PER_IO = 0.025d;
+
+  private Configuration conf;
+
+  public WeightedScalingMemoryDistributor() {
+  }
+
+  /** Categories of memory requestors; each category has its own weight. */
+  @Private
+  @VisibleForTesting
+  public enum RequestType {
+    PARTITIONED_UNSORTED_OUTPUT, UNSORTED_INPUT, SORTED_OUTPUT, SORTED_MERGED_INPUT, PROCESSOR, OTHER
+  }
+
+  // Per-call state, cleared at the start of every assignMemory() call.
+  private final EnumMap<RequestType, Integer> typeScaleMap = Maps.newEnumMap(RequestType.class);
+  private final List<Request> requests = Lists.newArrayList();
+  private int numRequests = 0;
+  // Sum of the weights of all requests (not a request count).
+  private int numRequestsScaled = 0;
+  private long totalRequested = 0;
+
+  /**
+   * Splits {@code availableForAllocation}, minus a reserve, across the requests
+   * in proportion to each request's size times its type weight.
+   *
+   * @param availableForAllocation memory to distribute, before the reserve is taken out
+   * @param numTotalInputs not used by this allocator
+   * @param numTotalOutputs not used by this allocator
+   * @param initialRequests the requests to satisfy
+   * @return one allocation per request, in request order, each no larger than
+   *         the corresponding request
+   * @throws IllegalArgumentException if the configured weights or reservation
+   *         fractions are malformed or out of range
+   * @throws IllegalStateException if the total reserve fraction exceeds 1
+   */
+  @Override
+  public Iterable<Long> assignMemory(long availableForAllocation, int numTotalInputs,
+      int numTotalOutputs, Iterable<InitialMemoryRequestContext> initialRequests) {
+
+    // Start from a clean slate so that repeated calls do not double count.
+    resetState();
+
+    // Read in configuration
+    populateTypeScaleMap();
+
+    for (InitialMemoryRequestContext context : initialRequests) {
+      initialProcessMemoryRequestContext(context);
+    }
+
+    if (numRequestsScaled == 0) {
+      // Fall back to regular scaling. e.g. BROADCAST : SHUFFLE = 0:1.
+      // i.e. if Shuffle present, Broadcast gets nothing, but otherwise it
+      // should get an allocation
+      numRequestsScaled = numRequests;
+      for (Request request : requests) {
+        request.requestWeight = 1;
+      }
+    }
+
+    // Scale down while adding requests - don't want to hit Long limits.
+    double totalScaledRequest = 0d;
+    for (Request request : requests) {
+      totalScaledRequest += scaledRequestSize(request);
+    }
+
+    // Take a certain amount of memory away for general usage.
+    double reserveFraction = computeReservedFraction(numRequests);
+
+    Preconditions.checkState(reserveFraction >= 0.0d && reserveFraction <= 1.0d,
+        "Reserve fraction must be within [0, 1], was: " + reserveFraction);
+    availableForAllocation = (long) (availableForAllocation - (reserveFraction * availableForAllocation));
+
+    long totalJvmMem = Runtime.getRuntime().maxMemory();
+    double ratio = totalRequested / (double) totalJvmMem;
+
+    LOG.info("Scaling Requests. NumRequests: " + numRequests + ", numScaledRequests: "
+        + numRequestsScaled + ", TotalRequested: " + totalRequested + ", TotalRequestedScaled: "
+        + totalScaledRequest + ", TotalJVMHeap: " + totalJvmMem + ", TotalAvailable: "
+        + availableForAllocation + ", TotalRequested/TotalJVMHeap:"
+        + new DecimalFormat("0.00").format(ratio));
+
+    // Actual scaling
+    List<Long> allocations = Lists.newArrayListWithCapacity(numRequests);
+    for (Request request : requests) {
+      if (request.requestSize == 0) {
+        allocations.add(0L);
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Scaling requested " + request.componentClassname + " of type "
+              + request.requestType + " 0 to allocated: 0");
+        }
+      } else {
+        long allocated = scaledAllocation(request, totalScaledRequest, availableForAllocation);
+        // TODO Later - If requestedSize is used, the difference (allocated -
+        // requestedSize) could be allocated to others.
+        allocations.add(allocated);
+        if (LOG.isDebugEnabled()) {
+          LOG.debug("Scaling requested " + request.componentClassname + " of type "
+              + request.requestType + " " + request.requestSize + "  to allocated: " + allocated);
+        }
+      }
+    }
+    return allocations;
+  }
+
+  /** Clears all state accumulated by a previous assignMemory() call. */
+  private void resetState() {
+    typeScaleMap.clear();
+    requests.clear();
+    numRequests = 0;
+    numRequestsScaled = 0;
+    totalRequested = 0;
+  }
+
+  /** Returns the request's size weighted by its share of the total weight. */
+  private double scaledRequestSize(Request request) {
+    return request.requestSize * (request.requestWeight / (double) numRequestsScaled);
+  }
+
+  /**
+   * Returns the request's proportional share of {@code available}, never more
+   * than the request asked for.
+   */
+  private long scaledAllocation(Request request, double totalScaledRequest, long available) {
+    double scaledRequest = scaledRequestSize(request);
+    if (scaledRequest == 0d) {
+      // Zero-weight request: give it nothing explicitly instead of relying on
+      // 0 / 0 (NaN) being cast to 0 when every scaled request is zero.
+      return 0L;
+    }
+    return Math.min((long) ((scaledRequest / totalScaledRequest) * available),
+        request.requestSize);
+  }
+
+  /** Classifies a request, looks up its type weight and records it. */
+  private void initialProcessMemoryRequestContext(InitialMemoryRequestContext context) {
+    numRequests++;
+    totalRequested += context.getRequestedSize();
+    String className = context.getComponentClassName();
+    RequestType requestType = getRequestTypeForClass(className);
+    int typeScaleFactor = getScaleFactorForType(requestType);
+
+    requests.add(new Request(className, context.getRequestedSize(), requestType,
+        typeScaleFactor));
+    LOG.info("ScaleFactor: " + typeScaleFactor + ", for type: " + requestType);
+    numRequestsScaled += typeScaleFactor;
+  }
+
+  /** Returns the configured weight for the type, or 0 if none is configured. */
+  private Integer getScaleFactorForType(RequestType requestType) {
+    Integer typeScaleFactor = typeScaleMap.get(requestType);
+    if (typeScaleFactor == null) {
+      LOG.warn("Bad scale factor for requestType: " + requestType + ", Using factor 0");
+      typeScaleFactor = 0;
+    }
+    return typeScaleFactor;
+  }
+
+  /** Maps a component class name to its RequestType; unknown classes map to OTHER. */
+  private RequestType getRequestTypeForClass(String className) {
+    RequestType requestType;
+    if (className.equals(OnFileSortedOutput.class.getName())) {
+      requestType = RequestType.SORTED_OUTPUT;
+    } else if (className.equals(ShuffledMergedInput.class.getName())
+        || className.equals(ShuffledMergedInputLegacy.class.getName())) {
+      requestType = RequestType.SORTED_MERGED_INPUT;
+    } else if (className.equals(ShuffledUnorderedKVInput.class.getName())) {
+      requestType = RequestType.UNSORTED_INPUT;
+    } else {
+      requestType = RequestType.OTHER;
+      LOG.info("Falling back to RequestType.OTHER for class: " + className);
+    }
+    return requestType;
+  }
+
+  /**
+   * Loads the per-type weights from the configuration. Each entry has the form
+   * {@code TYPE:weight}; whitespace around either part is ignored. When nothing
+   * is configured every type gets weight 1 (linear scaling).
+   *
+   * @throws IllegalArgumentException if the entries are malformed, duplicated,
+   *         negative, or not exactly one per RequestType
+   */
+  private void populateTypeScaleMap() {
+    String[] ratios = conf.getStrings(TezJobConfig.TEZ_TASK_SCALE_MEMORY_WEIGHTED_RATIOS);
+    int numExpectedValues = RequestType.values().length;
+    if (ratios == null) {
+      LOG.info("No ratio specified. Falling back to Linear scaling");
+      ratios = new String[numExpectedValues];
+      int i = 0;
+      for (RequestType requestType : RequestType.values()) {
+        ratios[i] = requestType.name() + ":1"; // Linear scale
+        i++;
+      }
+    } else if (ratios.length != numExpectedValues) {
+      throw new IllegalArgumentException(
+          "Number of entries in the configured ratios should be equal to the number of entries in RequestType: "
+              + numExpectedValues);
+    }
+
+    Set<RequestType> seenTypes = new HashSet<RequestType>();
+
+    for (String ratio : ratios) {
+      String[] parts = ratio.split(":");
+      Preconditions.checkArgument(parts.length == 2,
+          "Expected an entry of the form TYPE:weight, found: " + ratio);
+      // getStrings() does not strip whitespace, so tolerate "TYPE : weight".
+      RequestType requestType = RequestType.valueOf(parts[0].trim());
+      int ratioVal = Integer.parseInt(parts[1].trim());
+      if (!seenTypes.add(requestType)) {
+        throw new IllegalArgumentException("Cannot configure the same RequestType: " + requestType
+            + " multiple times");
+      }
+      Preconditions.checkArgument(ratioVal >= 0, "Ratio must be >= 0");
+      typeScaleMap.put(requestType, ratioVal);
+    }
+  }
+
+  /**
+   * Computes the fraction of available memory to hold back: the configured base
+   * reserve plus {@code numTotalRequests} times a per-request fraction, the
+   * latter capped at a configured maximum.
+   *
+   * @param numTotalRequests number of memory requests being served
+   * @return the total reserve fraction
+   */
+  private double computeReservedFraction(int numTotalRequests) {
+
+    double reserveFractionPerIo = conf.getDouble(
+        TezJobConfig.TEZ_TASK_SCALE_MEMORY_ADDITIONAL_RESERVATION_FRACTION_PER_IO,
+        RESERVATION_FRACTION_PER_IO);
+    double maxAdditionalReserveFraction = conf.getDouble(
+        TezJobConfig.TEZ_TASK_SCALE_MEMORY_ADDITIONAL_RESERVATION_FRACTION_MAX,
+        MAX_ADDITIONAL_RESERVATION_FRACTION_PER_IO);
+    Preconditions.checkArgument(maxAdditionalReserveFraction >= 0d
+        && maxAdditionalReserveFraction <= 1d,
+        "Max additional reserve fraction must be within [0, 1], was: "
+            + maxAdditionalReserveFraction);
+    Preconditions.checkArgument(reserveFractionPerIo <= maxAdditionalReserveFraction
+        && reserveFractionPerIo >= 0d,
+        "Per-IO reserve fraction must be within [0, " + maxAdditionalReserveFraction
+            + "], was: " + reserveFractionPerIo);
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("ReservationFractionPerIO=" + reserveFractionPerIo + ", MaxPerIOReserveFraction="
+          + maxAdditionalReserveFraction);
+    }
+
+    double initialReserveFraction = conf.getDouble(TezJobConfig.TEZ_TASK_SCALE_MEMORY_RESERVE_FRACTION,
+        DEFAULT_RESERVE_FRACTION);
+    double additionalReserveFraction = Math.min(maxAdditionalReserveFraction, numTotalRequests
+        * reserveFractionPerIo);
+
+    double reserveFraction = initialReserveFraction + additionalReserveFraction;
+    Preconditions.checkState(reserveFraction <= 1.0d,
+        "Total reserve fraction must not exceed 1, was: " + reserveFraction);
+    LOG.info("InitialReservationFraction=" + initialReserveFraction
+        + ", AdditionalReservationFractionForIOs=" + additionalReserveFraction
+        + ", finalReserveFractionUsed=" + reserveFraction);
+    return reserveFraction;
+  }
+
+  @Override
+  public void setConf(Configuration conf) {
+    this.conf = conf;
+  }
+
+  @Override
+  public Configuration getConf() {
+    return this.conf;
+  }
+
+  /** A single memory request together with its classification and weight. */
+  private static class Request {
+    final String componentClassname;
+    final long requestSize;
+    final RequestType requestType;
+    // Not final: reset to 1 when every configured weight turns out to be 0.
+    int requestWeight;
+
+    Request(String componentClassname, long requestSize, RequestType requestType, int requestWeight) {
+      this.componentClassname = componentClassname;
+      this.requestSize = requestSize;
+      this.requestType = requestType;
+      this.requestWeight = requestWeight;
+    }
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-tez/blob/792cda59/tez-runtime-library/src/test/java/org/apache/tez/runtime/common/resources/TestWeightedScalingMemoryDistributor.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/common/resources/TestWeightedScalingMemoryDistributor.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/common/resources/TestWeightedScalingMemoryDistributor.java
new file mode 100644
index 0000000..2b6e9c4
--- /dev/null
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/common/resources/TestWeightedScalingMemoryDistributor.java
@@ -0,0 +1,196 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.common.resources;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.common.TezJobConfig;
+import org.apache.tez.dag.api.InputDescriptor;
+import org.apache.tez.dag.api.OutputDescriptor;
+import org.apache.tez.runtime.api.LogicalInput;
+import org.apache.tez.runtime.api.LogicalOutput;
+import org.apache.tez.runtime.api.MemoryUpdateCallback;
+import org.apache.tez.runtime.api.TezInputContext;
+import org.apache.tez.runtime.api.TezOutputContext;
+import org.apache.tez.runtime.library.input.ShuffledMergedInput;
+import org.apache.tez.runtime.library.input.ShuffledUnorderedKVInput;
+import org.apache.tez.runtime.library.output.OnFileSortedOutput;
+import org.apache.tez.runtime.library.resources.WeightedScalingMemoryDistributor;
+import org.apache.tez.runtime.library.resources.WeightedScalingMemoryDistributor.RequestType;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Runs the {@link TestMemoryDistributor} scenarios, plus weight-specific ones,
+ * against {@link WeightedScalingMemoryDistributor}.
+ */
+public class TestWeightedScalingMemoryDistributor extends TestMemoryDistributor {
+
+  /**
+   * Replaces the base-class configuration: selects the weighted allocator with
+   * a 30% base reserve and no additional per-IO reserve.
+   */
+  @Before
+  @Override
+  public void setup() {
+    conf.setBoolean(TezJobConfig.TEZ_TASK_SCALE_MEMORY_ENABLED, true);
+    conf.set(TezJobConfig.TEZ_TASK_SCALE_MEMORY_ALLOCATOR_CLASS,
+        WeightedScalingMemoryDistributor.class.getName());
+    conf.setDouble(TezJobConfig.TEZ_TASK_SCALE_MEMORY_RESERVE_FRACTION, 0.3d);
+    conf.setDouble(TezJobConfig.TEZ_TASK_SCALE_MEMORY_ADDITIONAL_RESERVATION_FRACTION_PER_IO, 0.0d);
+  }
+
+  @Test(timeout = 5000)
+  public void testSimpleWeightedScaling() {
+    Configuration conf = new Configuration(this.conf);
+    conf.setStrings(TezJobConfig.TEZ_TASK_SCALE_MEMORY_WEIGHTED_RATIOS,
+        generateWeightStrings(1, 2, 3, 1, 1));
+
+    MemoryDistributor dist = new MemoryDistributor(2, 2, conf);
+
+    dist.setJvmMemory(10000L);
+
+    // First request - ScatterGatherShuffleInput [weight 3]
+    MemoryUpdateCallbackForTest e1Callback = new MemoryUpdateCallbackForTest();
+    TezInputContext e1InputContext1 = createTestInputContext();
+    InputDescriptor e1InDesc1 = createTestInputDescriptor(ShuffledMergedInput.class);
+    dist.requestMemory(10000, e1Callback, e1InputContext1, e1InDesc1);
+
+    // Second request - BroadcastInput [weight 1]
+    MemoryUpdateCallbackForTest e2Callback = new MemoryUpdateCallbackForTest();
+    TezInputContext e2InputContext2 = createTestInputContext();
+    InputDescriptor e2InDesc2 = createTestInputDescriptor(ShuffledUnorderedKVInput.class);
+    dist.requestMemory(10000, e2Callback, e2InputContext2, e2InDesc2);
+
+    // Third request - randomOutput (simulates MROutput) [weight 1]
+    MemoryUpdateCallbackForTest e3Callback = new MemoryUpdateCallbackForTest();
+    TezOutputContext e3OutputContext1 = createTestOutputContext();
+    OutputDescriptor e3OutDesc1 = createTestOutputDescriptor();
+    dist.requestMemory(10000, e3Callback, e3OutputContext1, e3OutDesc1);
+
+    // Fourth request - OnFileSortedOutput [weight 2]
+    MemoryUpdateCallbackForTest e4Callback = new MemoryUpdateCallbackForTest();
+    TezOutputContext e4OutputContext2 = createTestOutputContext();
+    OutputDescriptor e4OutDesc2 = createTestOutputDescriptor(OnFileSortedOutput.class);
+    dist.requestMemory(10000, e4Callback, e4OutputContext2, e4OutDesc2);
+
+    dist.makeInitialAllocations();
+
+    // Total available: 70% of 10K = 7000 (30% base reserve, no per-IO reserve)
+    // 4 requests (weight) - 10K (3), 10K(1), 10K(1), 10K(2)
+    // Scale down to - 3000, 1000, 1000, 2000
+    assertEquals(3000, e1Callback.assigned);
+    assertEquals(1000, e2Callback.assigned);
+    assertEquals(1000, e3Callback.assigned);
+    assertEquals(2000, e4Callback.assigned);
+  }
+
+  @Test(timeout = 5000)
+  public void testAdditionalReserveFractionWeightedScaling() {
+    Configuration conf = new Configuration(this.conf);
+    conf.setStrings(TezJobConfig.TEZ_TASK_SCALE_MEMORY_WEIGHTED_RATIOS,
+        generateWeightStrings(2, 3, 6, 1, 1));
+    conf.setDouble(TezJobConfig.TEZ_TASK_SCALE_MEMORY_ADDITIONAL_RESERVATION_FRACTION_PER_IO, 0.025d);
+    conf.setDouble(TezJobConfig.TEZ_TASK_SCALE_MEMORY_ADDITIONAL_RESERVATION_FRACTION_MAX, 0.2d);
+
+    MemoryDistributor dist = new MemoryDistributor(2, 2, conf);
+
+    dist.setJvmMemory(10000L);
+
+    // First request - ScatterGatherShuffleInput [weight 6]
+    MemoryUpdateCallbackForTest e1Callback = new MemoryUpdateCallbackForTest();
+    TezInputContext e1InputContext1 = createTestInputContext();
+    InputDescriptor e1InDesc1 = createTestInputDescriptor(ShuffledMergedInput.class);
+    dist.requestMemory(10000, e1Callback, e1InputContext1, e1InDesc1);
+
+    // Second request - BroadcastInput [weight 2]
+    MemoryUpdateCallbackForTest e2Callback = new MemoryUpdateCallbackForTest();
+    TezInputContext e2InputContext2 = createTestInputContext();
+    InputDescriptor e2InDesc2 = createTestInputDescriptor(ShuffledUnorderedKVInput.class);
+    dist.requestMemory(10000, e2Callback, e2InputContext2, e2InDesc2);
+
+    // Third request - randomOutput (simulates MROutput) [weight 1]
+    MemoryUpdateCallbackForTest e3Callback = new MemoryUpdateCallbackForTest();
+    TezOutputContext e3OutputContext1 = createTestOutputContext();
+    OutputDescriptor e3OutDesc1 = createTestOutputDescriptor();
+    dist.requestMemory(10000, e3Callback, e3OutputContext1, e3OutDesc1);
+
+    // Fourth request - OnFileSortedOutput [weight 3]
+    MemoryUpdateCallbackForTest e4Callback = new MemoryUpdateCallbackForTest();
+    TezOutputContext e4OutputContext2 = createTestOutputContext();
+    OutputDescriptor e4OutDesc2 = createTestOutputDescriptor(OnFileSortedOutput.class);
+    dist.requestMemory(10000, e4Callback, e4OutputContext2, e4OutDesc2);
+
+    dist.makeInitialAllocations();
+
+    // Total available: 60% of 10K = 6000
+    // (30% base reserve + min(20% cap, 4 requests x 2.5%) additional reserve)
+    // 4 requests (weight) - 10K (6), 10K(2), 10K(1), 10K(3)
+    // Scale down to - 3000, 1000, 500, 1500
+    assertEquals(3000, e1Callback.assigned);
+    assertEquals(1000, e2Callback.assigned);
+    assertEquals(500, e3Callback.assigned);
+    assertEquals(1500, e4Callback.assigned);
+  }
+
+  /** Records the size passed to memoryAssigned(); -1000 until it is called. */
+  private static class MemoryUpdateCallbackForTest implements MemoryUpdateCallback {
+
+    long assigned = -1000;
+
+    @Override
+    public void memoryAssigned(long assignedSize) {
+      this.assigned = assignedSize;
+    }
+  }
+
+  /** Mock InputDescriptor reporting the given class as its class name. */
+  private InputDescriptor createTestInputDescriptor(Class<? extends LogicalInput> inputClazz) {
+    InputDescriptor desc = mock(InputDescriptor.class);
+    doReturn(inputClazz.getName()).when(desc).getClassName();
+    return desc;
+  }
+
+  /** Mock OutputDescriptor reporting the given class as its class name. */
+  private OutputDescriptor createTestOutputDescriptor(Class<? extends LogicalOutput> outputClazz) {
+    OutputDescriptor desc = mock(OutputDescriptor.class);
+    doReturn(outputClazz.getName()).when(desc).getClassName();
+    return desc;
+  }
+
+  /**
+   * Builds one TYPE:weight entry per RequestType, in the enum's declaration
+   * order. PARTITIONED_UNSORTED_OUTPUT always gets weight 0.
+   */
+  private String[] generateWeightStrings(int broadcastIn, int sortedOut,
+      int scatterGatherShuffleIn, int proc, int other) {
+    String[] weights = new String[RequestType.values().length];
+    weights[0] = RequestType.PARTITIONED_UNSORTED_OUTPUT.name() + ":" + 0;
+    weights[1] = RequestType.UNSORTED_INPUT.name() + ":" + broadcastIn;
+    weights[2] = RequestType.SORTED_OUTPUT.name() + ":" + sortedOut;
+    weights[3] = RequestType.SORTED_MERGED_INPUT.name() + ":" + scatterGatherShuffleIn;
+    weights[4] = RequestType.PROCESSOR.name() + ":" + proc;
+    weights[5] = RequestType.OTHER.name() + ":" + other;
+    return weights;
+  }
+
+}