You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mrunit.apache.org by br...@apache.org on 2012/07/01 21:34:19 UTC

svn commit: r1355989 - in /mrunit/trunk/src: main/java/org/apache/hadoop/mrunit/ main/java/org/apache/hadoop/mrunit/internal/mapreduce/ main/java/org/apache/hadoop/mrunit/mapreduce/ test/java/org/apache/hadoop/mrunit/mapreduce/

Author: brock
Date: Sun Jul  1 19:34:18 2012
New Revision: 1355989

URL: http://svn.apache.org/viewvc?rev=1355989&view=rev
Log:
MRUNIT-122: Context should be mockable

Modified:
    mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/MapDriverBase.java
    mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/AbstractMockContextWrapper.java
    mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/MockMapContextWrapper.java
    mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/MockReduceContextWrapper.java
    mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/MapDriver.java
    mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/MapReduceDriver.java
    mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/ReduceDriver.java
    mrunit/trunk/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestMapDriver.java
    mrunit/trunk/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestReduceDriver.java

Modified: mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/MapDriverBase.java
URL: http://svn.apache.org/viewvc/mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/MapDriverBase.java?rev=1355989&r1=1355988&r2=1355989&view=diff
==============================================================================
--- mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/MapDriverBase.java (original)
+++ mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/MapDriverBase.java Sun Jul  1 19:34:18 2012
@@ -18,6 +18,7 @@
 package org.apache.hadoop.mrunit;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.commons.logging.Log;
@@ -39,6 +40,7 @@ public abstract class MapDriverBase<K1, 
 
   public static final Log LOG = LogFactory.getLog(MapDriverBase.class);
 
+  protected List<Pair<K1, V1>> inputs = new ArrayList<Pair<K1, V1>>();
   protected K1 inputKey;
   protected V1 inputVal;
 

Modified: mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/AbstractMockContextWrapper.java
URL: http://svn.apache.org/viewvc/mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/AbstractMockContextWrapper.java?rev=1355989&r1=1355988&r2=1355989&view=diff
==============================================================================
--- mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/AbstractMockContextWrapper.java (original)
+++ mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/AbstractMockContextWrapper.java Sun Jul  1 19:34:18 2012
@@ -17,48 +17,44 @@
  */
 package org.apache.hadoop.mrunit.internal.mapreduce;
 
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyString;
-import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.when;
+import static org.mockito.Matchers.*;
+import static org.mockito.Mockito.*;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.mapreduce.Counter;
-import org.apache.hadoop.mapreduce.Counters;
 import org.apache.hadoop.mapreduce.TaskInputOutputContext;
+import org.apache.hadoop.mrunit.internal.output.MockOutputCreator;
 import org.apache.hadoop.mrunit.internal.output.OutputCollectable;
+import org.apache.hadoop.mrunit.types.Pair;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
-abstract class AbstractMockContextWrapper<KEYIN, VALUEIN, KEYOUT, VALUEOUT, CONTEXT extends TaskInputOutputContext<KEYIN, VALUEIN, KEYOUT, VALUEOUT>> {
+abstract class AbstractMockContextWrapper<KEYIN, VALUEIN, KEYOUT, VALUEOUT, CONTEXT 
+extends TaskInputOutputContext<KEYIN, VALUEIN, KEYOUT, VALUEOUT>> {
 
-  protected final Counters counters;
-  protected final Configuration conf;
+  protected CONTEXT context;
+  protected final MockOutputCreator<KEYOUT, VALUEOUT> mockOutputCreator;
+  protected OutputCollectable<KEYOUT, VALUEOUT> outputCollectable;
 
-  protected final CONTEXT context;
-  private final OutputCollectable<KEYOUT, VALUEOUT> outputCollectable;
-
-  public AbstractMockContextWrapper(final Counters counters,
-      final Configuration conf,
-      final OutputCollectable<KEYOUT, VALUEOUT> outputCollectable)
-      throws IOException, InterruptedException {
-    this.conf = conf;
-    this.counters = counters;
-    this.outputCollectable = outputCollectable;
-    context = create();
+  public AbstractMockContextWrapper(final MockOutputCreator<KEYOUT, VALUEOUT> mockOutputCreator) {
+    this.mockOutputCreator = mockOutputCreator;
   }
 
   @SuppressWarnings({ "rawtypes", "unchecked" })
   protected void createCommon(
-      final TaskInputOutputContext<KEYIN, VALUEIN, KEYOUT, VALUEOUT> context)
-      throws IOException, InterruptedException {
+      final TaskInputOutputContext context,
+      final ContextDriver contextDriver,
+      final MockOutputCreator mockOutputCreator) {
+        
     when(context.getCounter((Enum) any())).thenAnswer(new Answer<Counter>() {
       @Override
       public Counter answer(final InvocationOnMock invocation) {
         final Object[] args = invocation.getArguments();
-        return counters.findCounter((Enum) args[0]);
+        return contextDriver.getCounters().findCounter((Enum) args[0]);
       }
     });
     when(context.getCounter(anyString(), anyString())).thenAnswer(
@@ -66,31 +62,48 @@ abstract class AbstractMockContextWrappe
           @Override
           public Counter answer(final InvocationOnMock invocation) {
             final Object[] args = invocation.getArguments();
-            return counters.findCounter((String) args[0], (String) args[1]);
+            return contextDriver.getCounters().findCounter((String) args[0], (String) args[1]);
           }
-        });
+    });
     when(context.getConfiguration()).thenAnswer(new Answer<Configuration>() {
       @Override
       public Configuration answer(final InvocationOnMock invocation) {
-        return conf;
+        return contextDriver.getConfiguration();
       }
     });
-    doAnswer(new Answer<Object>() {
-      @Override
-      public Object answer(final InvocationOnMock invocation) {
-        final Object[] args = invocation.getArguments();
-        try {
-          outputCollectable.collect((KEYOUT) args[0], (VALUEOUT) args[1]);
-        } catch (IOException e) {
-          throw new RuntimeException(e);
+    try {
+      doAnswer(new Answer<Object>() {
+        @Override
+        public Object answer(final InvocationOnMock invocation) {
+          final Object[] args = invocation.getArguments();
+          try {
+            if(outputCollectable == null) {
+              outputCollectable = mockOutputCreator.createOutputCollectable(contextDriver.getConfiguration(), 
+                  contextDriver.getOutputCopyingOrInputFormatConfiguration());
+            }
+            outputCollectable.collect((KEYOUT)args[0], (VALUEOUT)args[1]);
+          } catch (IOException e) {
+            throw new RuntimeException(e);
+          }
+          return null;
         }
-        return null;
-      }
-    }).when(context).write((KEYOUT) any(), (VALUEOUT) any());
+      }).when(context).write(any(), any());
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    } catch (InterruptedException e) {
+      throw new RuntimeException(e);
+    }
   }
 
   protected abstract CONTEXT create() throws IOException, InterruptedException;
 
+  public List<Pair<KEYOUT, VALUEOUT>> getOutputs() throws IOException {
+    if(outputCollectable == null) {
+      return new ArrayList<Pair<KEYOUT, VALUEOUT>>();
+    }
+    return outputCollectable.getOutputs();
+  }
+  
   public CONTEXT getMockContext() {
     return context;
   }

Modified: mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/MockMapContextWrapper.java
URL: http://svn.apache.org/viewvc/mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/MockMapContextWrapper.java?rev=1355989&r1=1355988&r2=1355989&view=diff
==============================================================================
--- mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/MockMapContextWrapper.java (original)
+++ mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/MockMapContextWrapper.java Sun Jul  1 19:34:18 2012
@@ -18,20 +18,17 @@
 
 package org.apache.hadoop.mrunit.internal.mapreduce;
 
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.*;
 
 import java.io.IOException;
 import java.util.List;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.mapreduce.Counters;
 import org.apache.hadoop.mapreduce.InputSplit;
 import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.hadoop.mrunit.internal.output.OutputCollectable;
+import org.apache.hadoop.mrunit.internal.output.MockOutputCreator;
+import org.apache.hadoop.mrunit.mapreduce.MapDriver;
 import org.apache.hadoop.mrunit.types.Pair;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
@@ -46,65 +43,71 @@ import org.mockito.stubbing.Answer;
  * This wrapper class exists for that purpose.
  */
 public class MockMapContextWrapper<KEYIN, VALUEIN, KEYOUT, VALUEOUT>
-    extends
-    AbstractMockContextWrapper<KEYIN, VALUEIN, KEYOUT, VALUEOUT, Mapper<KEYIN, VALUEIN, KEYOUT, VALUEOUT>.Context> {
+    extends AbstractMockContextWrapper<KEYIN, VALUEIN, KEYOUT, VALUEOUT, Mapper<KEYIN, VALUEIN, KEYOUT, VALUEOUT>.Context> {
 
   protected static final Log LOG = LogFactory
       .getLog(MockMapContextWrapper.class);
 
   protected final List<Pair<KEYIN, VALUEIN>> inputs;
+  protected final MapDriver<KEYIN, VALUEIN, KEYOUT, VALUEOUT> driver;
+  
   protected Pair<KEYIN, VALUEIN> currentKeyValue;
-  protected InputSplit inputSplit;
-
+  
   public MockMapContextWrapper(final List<Pair<KEYIN, VALUEIN>> inputs,
-      final Counters counters, final Configuration conf,
-      final OutputCollectable<KEYOUT, VALUEOUT> outputCollectable,
-      final Path mapInputPath)
-      throws IOException, InterruptedException {
-    super(counters, conf, outputCollectable);
+      final MockOutputCreator<KEYOUT, VALUEOUT> mockOutputCreator,
+      final MapDriver<KEYIN, VALUEIN, KEYOUT, VALUEOUT> driver) {
+    super(mockOutputCreator);
     this.inputs = inputs;
-    this.inputSplit = new MockInputSplit(mapInputPath);
+    this.driver = driver;
+    context = create();
+
   }
 
+  @Override
   @SuppressWarnings({ "unchecked" })
-  protected Mapper<KEYIN, VALUEIN, KEYOUT, VALUEOUT>.Context create()
-      throws IOException, InterruptedException {
+  protected Mapper<KEYIN, VALUEIN, KEYOUT, VALUEOUT>.Context create() {
     final Mapper<KEYIN, VALUEIN, KEYOUT, VALUEOUT>.Context context = mock(org.apache.hadoop.mapreduce.Mapper.Context.class);
 
-    createCommon(context);
+    createCommon(context, driver, mockOutputCreator);
 
-    /*
-     * In actual context code nextKeyValue() modifies the current state so we
-     * can here as well.
-     */
-    when(context.nextKeyValue()).thenAnswer(new Answer<Boolean>() {
-      @Override
-      public Boolean answer(final InvocationOnMock invocation) {
-        if (inputs.size() > 0) {
-          currentKeyValue = inputs.remove(0);
-          return true;
-        } else {
-          currentKeyValue = null;
-          return false;
+    try {
+      /*
+       * In actual context code nextKeyValue() modifies the current state so we
+       * can here as well.
+       */
+      when(context.nextKeyValue()).thenAnswer(new Answer<Boolean>() {
+        @Override
+        public Boolean answer(final InvocationOnMock invocation) {
+          if (inputs.size() > 0) {
+            currentKeyValue = inputs.remove(0);
+            return true;
+          } else {
+            currentKeyValue = null;
+            return false;
+          }
         }
-      }
-    });
-    when(context.getCurrentKey()).thenAnswer(new Answer<KEYIN>() {
-      @Override
-      public KEYIN answer(final InvocationOnMock invocation) {
-        return currentKeyValue.getFirst();
-      }
-    });
-    when(context.getCurrentValue()).thenAnswer(new Answer<VALUEIN>() {
-      @Override
-      public VALUEIN answer(final InvocationOnMock invocation) {
-        return currentKeyValue.getSecond();
-      }
-    });
+      });
+      when(context.getCurrentKey()).thenAnswer(new Answer<KEYIN>() {
+        @Override
+        public KEYIN answer(final InvocationOnMock invocation) {
+          return currentKeyValue.getFirst();
+        }
+      });
+      when(context.getCurrentValue()).thenAnswer(new Answer<VALUEIN>() {
+        @Override
+        public VALUEIN answer(final InvocationOnMock invocation) {
+          return currentKeyValue.getSecond();
+        }
+      });
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    } catch (InterruptedException e) {
+      throw new RuntimeException(e);
+    }
     when(context.getInputSplit()).thenAnswer(new Answer<InputSplit>() {
       @Override
       public InputSplit answer(InvocationOnMock invocation) throws Throwable {
-        return inputSplit;
+        return new MockInputSplit(driver.getMapInputPath());
       }
     });
     return context;

Modified: mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/MockReduceContextWrapper.java
URL: http://svn.apache.org/viewvc/mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/MockReduceContextWrapper.java?rev=1355989&r1=1355988&r2=1355989&view=diff
==============================================================================
--- mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/MockReduceContextWrapper.java (original)
+++ mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/internal/mapreduce/MockReduceContextWrapper.java Sun Jul  1 19:34:18 2012
@@ -18,8 +18,7 @@
 
 package org.apache.hadoop.mrunit.internal.mapreduce;
 
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.*;
 
 import java.io.IOException;
 import java.util.Iterator;
@@ -27,10 +26,9 @@ import java.util.List;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.mapreduce.Counters;
 import org.apache.hadoop.mapreduce.Reducer;
-import org.apache.hadoop.mrunit.internal.output.OutputCollectable;
+import org.apache.hadoop.mrunit.internal.output.MockOutputCreator;
+import org.apache.hadoop.mrunit.mapreduce.ReduceDriver;
 import org.apache.hadoop.mrunit.types.Pair;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
@@ -51,51 +49,61 @@ public class MockReduceContextWrapper<KE
   protected static final Log LOG = LogFactory
       .getLog(MockReduceContextWrapper.class);
   protected final List<Pair<KEYIN, List<VALUEIN>>> inputs;
+  protected final ReduceDriver<KEYIN, VALUEIN, KEYOUT, VALUEOUT> driver;
+  
   protected Pair<KEYIN, List<VALUEIN>> currentKeyValue;
-
+  
   public MockReduceContextWrapper(
-      final List<Pair<KEYIN, List<VALUEIN>>> inputs, final Counters counters,
-      final Configuration conf,
-      OutputCollectable<KEYOUT, VALUEOUT> outputCollectable)
-      throws IOException, InterruptedException {
-    super(counters, conf, outputCollectable);
+      final List<Pair<KEYIN, List<VALUEIN>>> inputs, 
+      final MockOutputCreator<KEYOUT, VALUEOUT> mockOutputCreator,
+      final ReduceDriver<KEYIN, VALUEIN, KEYOUT, VALUEOUT> driver) {
+    super(mockOutputCreator);
     this.inputs = inputs;
+    this.driver = driver;
+    context = create();
   }
 
+  @Override
   @SuppressWarnings({ "unchecked" })
-  protected Reducer<KEYIN, VALUEIN, KEYOUT, VALUEOUT>.Context create()
-      throws IOException, InterruptedException {
+  protected Reducer<KEYIN, VALUEIN, KEYOUT, VALUEOUT>.Context create() {
 
     final Reducer<KEYIN, VALUEIN, KEYOUT, VALUEOUT>.Context context = mock(org.apache.hadoop.mapreduce.Reducer.Context.class);
-    createCommon(context);
-    /*
-     * In actual context code nextKeyValue() modifies the current state so we
-     * can here as well.
-     */
-    when(context.nextKey()).thenAnswer(new Answer<Boolean>() {
-      @Override
-      public Boolean answer(final InvocationOnMock invocation) {
-        if (inputs.size() > 0) {
-          currentKeyValue = inputs.remove(0);
-          return true;
-        } else {
-          currentKeyValue = null;
-          return false;
+    
+    createCommon(context, driver, mockOutputCreator);
+    try {
+      /*
+       * In actual context code nextKeyValue() modifies the current state so we
+       * can here as well.
+       */
+      when(context.nextKey()).thenAnswer(new Answer<Boolean>() {
+        @Override
+        public Boolean answer(final InvocationOnMock invocation) {
+          if (inputs.size() > 0) {
+            currentKeyValue = inputs.remove(0);
+            return true;
+          } else {
+            currentKeyValue = null;
+            return false;
+          }
         }
-      }
-    });
-    when(context.getCurrentKey()).thenAnswer(new Answer<KEYIN>() {
-      @Override
-      public KEYIN answer(final InvocationOnMock invocation) {
-        return currentKeyValue.getFirst();
-      }
-    });
-    when(context.getValues()).thenAnswer(new Answer<Iterable<VALUEIN>>() {
-      @Override
-      public Iterable<VALUEIN> answer(final InvocationOnMock invocation) {
-        return makeOneUseIterator(currentKeyValue.getSecond().iterator());
-      }
-    });
+      });
+      when(context.getCurrentKey()).thenAnswer(new Answer<KEYIN>() {
+        @Override
+        public KEYIN answer(final InvocationOnMock invocation) {
+          return currentKeyValue.getFirst();
+        }
+      });
+      when(context.getValues()).thenAnswer(new Answer<Iterable<VALUEIN>>() {
+        @Override
+        public Iterable<VALUEIN> answer(final InvocationOnMock invocation) {
+          return makeOneUseIterator(currentKeyValue.getSecond().iterator());
+        }
+      });
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    } catch (InterruptedException e) {
+      throw new RuntimeException(e);
+    }
     return context;
   }
 

Modified: mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/MapDriver.java
URL: http://svn.apache.org/viewvc/mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/MapDriver.java?rev=1355989&r1=1355988&r2=1355989&view=diff
==============================================================================
--- mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/MapDriver.java (original)
+++ mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/MapDriver.java Sun Jul  1 19:34:18 2012
@@ -18,10 +18,9 @@
 
 package org.apache.hadoop.mrunit.mapreduce;
 
-import static org.apache.hadoop.mrunit.internal.util.ArgumentChecker.returnNonNull;
+import static org.apache.hadoop.mrunit.internal.util.ArgumentChecker.*;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.commons.logging.Log;
@@ -33,11 +32,10 @@ import org.apache.hadoop.mapreduce.Input
 import org.apache.hadoop.mapreduce.Mapper;
 import org.apache.hadoop.mapreduce.OutputFormat;
 import org.apache.hadoop.mrunit.MapDriverBase;
-import org.apache.hadoop.mrunit.MapReduceDriver;
 import org.apache.hadoop.mrunit.internal.counters.CounterWrapper;
+import org.apache.hadoop.mrunit.internal.mapreduce.ContextDriver;
 import org.apache.hadoop.mrunit.internal.mapreduce.MockMapContextWrapper;
 import org.apache.hadoop.mrunit.internal.output.MockOutputCreator;
-import org.apache.hadoop.mrunit.internal.output.OutputCollectable;
 import org.apache.hadoop.mrunit.types.Pair;
 
 /**
@@ -49,7 +47,8 @@ import org.apache.hadoop.mrunit.types.Pa
  * (k, v)* case from the Mapper, representing a single unit test. Multiple input
  * (k, v) pairs should go in separate unit tests.
  */
-public class MapDriver<K1, V1, K2, V2> extends MapDriverBase<K1, V1, K2, V2> {
+public class MapDriver<K1, V1, K2, V2> 
+extends MapDriverBase<K1, V1, K2, V2> implements ContextDriver {
 
   public static final Log LOG = LogFactory.getLog(MapDriver.class);
 
@@ -57,6 +56,9 @@ public class MapDriver<K1, V1, K2, V2> e
   private Counters counters;
 
   private final MockOutputCreator<K2, V2> mockOutputCreator = new MockOutputCreator<K2, V2>();
+  private final MockMapContextWrapper<K1, V1, K2, V2> wrapper = new MockMapContextWrapper<K1, V1, K2, V2>(
+      inputs, mockOutputCreator,  this);
+
 
   public MapDriver(final Mapper<K1, V1, K2, V2> m) {
     this();
@@ -91,6 +93,7 @@ public class MapDriver<K1, V1, K2, V2> e
   }
 
   /** @return the counters used in this test */
+  @Override
   public Counters getCounters() {
     return counters;
   }
@@ -227,18 +230,13 @@ public class MapDriver<K1, V1, K2, V2> e
     if (myMapper == null) {
       throw new IllegalStateException("No Mapper class was provided");
     }
-
-    final List<Pair<K1, V1>> inputs = new ArrayList<Pair<K1, V1>>();
+    
+    inputs.clear();
     inputs.add(new Pair<K1, V1>(inputKey, inputVal));
 
     try {
-      final OutputCollectable<K2, V2> outputCollectable = mockOutputCreator
-          .createOutputCollectable(getConfiguration(),
-              getOutputCopyingOrInputFormatConfiguration());
-      final MockMapContextWrapper<K1, V1, K2, V2> wrapper = new MockMapContextWrapper<K1, V1, K2, V2>(
-          inputs, getCounters(), getConfiguration(), outputCollectable, getMapInputPath());
       myMapper.run(wrapper.getMockContext());
-      return outputCollectable.getOutputs();
+      return wrapper.getOutputs();
     } catch (final InterruptedException ie) {
       throw new IOException(ie);
     }
@@ -278,6 +276,34 @@ public class MapDriver<K1, V1, K2, V2> e
     super.withCounter(e, expectedValue);
     return this;
   }
+  
+  /**
+   * <p>Obtain the mocked Context object for further mocking with Mockito.
+   * For example, causing write() to throw an exception:</p>
+   * 
+   * <pre>
+   * import static org.mockito.Matchers.*;
+   * import static org.mockito.Mockito.*;
+   * doThrow(new IOException()).when(context).write(any(), any());
+   * </pre>
+   * 
+   * <p>Or implement other logic:</p>
+   * 
+   * <pre>
+   * import static org.mockito.Matchers.*;
+   * import static org.mockito.Mockito.*;
+   * doAnswer(new Answer&lt;Object&gt;() {
+   *    public Object answer(final InvocationOnMock invocation) {
+   *    ...
+   *     return null;
+   *   }
+   * }).when(context).write(any(), any());
+   * </pre>
+   * @return the mocked context
+   */
+  public Mapper<K1, V1, K2, V2>.Context getContext() {
+    return wrapper.getMockContext();
+  }
 
   @Override
   public MapDriver<K1, V1, K2, V2> withCounter(final String g, final String n,

Modified: mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/MapReduceDriver.java
URL: http://svn.apache.org/viewvc/mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/MapReduceDriver.java?rev=1355989&r1=1355988&r2=1355989&view=diff
==============================================================================
--- mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/MapReduceDriver.java (original)
+++ mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/MapReduceDriver.java Sun Jul  1 19:34:18 2012
@@ -17,7 +17,7 @@
  */
 package org.apache.hadoop.mrunit.mapreduce;
 
-import static org.apache.hadoop.mrunit.internal.util.ArgumentChecker.returnNonNull;
+import static org.apache.hadoop.mrunit.internal.util.ArgumentChecker.*;
 
 import java.io.IOException;
 import java.util.ArrayList;

Modified: mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/ReduceDriver.java
URL: http://svn.apache.org/viewvc/mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/ReduceDriver.java?rev=1355989&r1=1355988&r2=1355989&view=diff
==============================================================================
--- mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/ReduceDriver.java (original)
+++ mrunit/trunk/src/main/java/org/apache/hadoop/mrunit/mapreduce/ReduceDriver.java Sun Jul  1 19:34:18 2012
@@ -18,7 +18,7 @@
 
 package org.apache.hadoop.mrunit.mapreduce;
 
-import static org.apache.hadoop.mrunit.internal.util.ArgumentChecker.returnNonNull;
+import static org.apache.hadoop.mrunit.internal.util.ArgumentChecker.*;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -33,9 +33,9 @@ import org.apache.hadoop.mapreduce.Outpu
 import org.apache.hadoop.mapreduce.Reducer;
 import org.apache.hadoop.mrunit.ReduceDriverBase;
 import org.apache.hadoop.mrunit.internal.counters.CounterWrapper;
+import org.apache.hadoop.mrunit.internal.mapreduce.ContextDriver;
 import org.apache.hadoop.mrunit.internal.mapreduce.MockReduceContextWrapper;
 import org.apache.hadoop.mrunit.internal.output.MockOutputCreator;
-import org.apache.hadoop.mrunit.internal.output.OutputCollectable;
 import org.apache.hadoop.mrunit.types.Pair;
 
 /**
@@ -49,7 +49,7 @@ import org.apache.hadoop.mrunit.types.Pa
  * sets should go in separate unit tests.
  */
 public class ReduceDriver<K1, V1, K2, V2> extends
-    ReduceDriverBase<K1, V1, K2, V2> {
+    ReduceDriverBase<K1, V1, K2, V2> implements ContextDriver {
 
   public static final Log LOG = LogFactory.getLog(ReduceDriver.class);
 
@@ -57,6 +57,10 @@ public class ReduceDriver<K1, V1, K2, V2
   private Counters counters;
 
   private final MockOutputCreator<K2, V2> mockOutputCreator = new MockOutputCreator<K2, V2>();
+  private final List<Pair<K1, List<V1>>> inputs = new ArrayList<Pair<K1, List<V1>>>();
+  private final MockReduceContextWrapper<K1, V1, K2, V2> wrapper = new MockReduceContextWrapper<K1, V1, K2, V2>(
+      inputs, mockOutputCreator, this);
+
 
   public ReduceDriver(final Reducer<K1, V1, K2, V2> r) {
     this();
@@ -95,6 +99,7 @@ public class ReduceDriver<K1, V1, K2, V2
   }
 
   /** @return the counters used in this test */
+  @Override
   public Counters getCounters() {
     return counters;
   }
@@ -240,17 +245,12 @@ public class ReduceDriver<K1, V1, K2, V2
       throw new IllegalStateException("No Reducer class was provided");
     }
 
-    final List<Pair<K1, List<V1>>> inputs = new ArrayList<Pair<K1, List<V1>>>();
+    inputs.clear();
     inputs.add(new Pair<K1, List<V1>>(inputKey, getInputValues()));
 
     try {
-      final OutputCollectable<K2, V2> outputCollectable = mockOutputCreator
-          .createOutputCollectable(getConfiguration(),
-              getOutputCopyingOrInputFormatConfiguration());
-      final MockReduceContextWrapper<K1, V1, K2, V2> wrapper = new MockReduceContextWrapper<K1, V1, K2, V2>(
-          inputs, getCounters(), getConfiguration(), outputCollectable);
       myReducer.run(wrapper.getMockContext());
-      return outputCollectable.getOutputs();
+      return wrapper.getOutputs();
     } catch (final InterruptedException ie) {
       throw new IOException(ie);
     }
@@ -286,6 +286,34 @@ public class ReduceDriver<K1, V1, K2, V2
     super.withCounter(g, n, e);
     return this;
   }
+  
+  /**
+   * <p>Obtain the mocked Context object for further mocking with Mockito.
+   * For example, causing write() to throw an exception:</p>
+   * 
+   * <pre>
+   * import static org.mockito.Matchers.*;
+   * import static org.mockito.Mockito.*;
+   * doThrow(new IOException()).when(context).write(any(), any());
+   * </pre>
+   * 
+   * <p>Or implement other logic:</p>
+   * 
+   * <pre>
+   * import static org.mockito.Matchers.*;
+   * import static org.mockito.Mockito.*;
+   * doAnswer(new Answer&lt;Object&gt;() {
+   *    public Object answer(final InvocationOnMock invocation) {
+   *    ...
+   *     return null;
+   *   }
+   * }).when(context).write(any(), any());
+   * </pre>
+   * @return the mocked context
+   */
+  public Reducer<K1, V1, K2, V2>.Context getContext() {
+    return wrapper.getMockContext();
+  }
 
   /**
    * Returns a new ReduceDriver without having to specify the generic types on

Modified: mrunit/trunk/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestMapDriver.java
URL: http://svn.apache.org/viewvc/mrunit/trunk/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestMapDriver.java?rev=1355989&r1=1355988&r2=1355989&view=diff
==============================================================================
--- mrunit/trunk/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestMapDriver.java (original)
+++ mrunit/trunk/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestMapDriver.java Sun Jul  1 19:34:18 2012
@@ -17,8 +17,11 @@
  */
 package org.apache.hadoop.mrunit.mapreduce;
 
-import static org.apache.hadoop.mrunit.ExtendedAssert.assertListEquals;
+import static org.apache.hadoop.mrunit.ExtendedAssert.*;
 import static org.junit.Assert.*;
+import static org.mockito.Matchers.*;
+import static org.mockito.Mockito.*;
+
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -260,6 +263,7 @@ public class TestMapDriver {
   
   public static class InputSplitDetailMapper
     extends Mapper<NullWritable, NullWritable, Text, LongWritable> {
+    @Override
     protected void map(NullWritable key, NullWritable value, Context context) 
         throws IOException, InterruptedException {
       FileSplit split = (FileSplit)context.getInputSplit();
@@ -354,4 +358,16 @@ public class TestMapDriver {
     assertNotNull(mapper.getMapInputPath());
     assertEquals(mapInputPath.getName(), mapper.getMapInputPath().getName());
   }
+  
+  @Test
+  public void textMockContext() throws IOException, InterruptedException {
+    thrown.expectMessage(RuntimeException.class, "Injected!");
+    Mapper<Text, Text, Text, Text>.Context context = driver.getContext();
+    doThrow(new RuntimeException("Injected!"))
+      .when(context)
+        .write(any(Text.class), any(Text.class));
+    driver.withInput(new Text("a"), new Text("1"));
+    driver.withOutput(new Text("a"), new Text("1"));
+    driver.runTest();
+  }
 }

Modified: mrunit/trunk/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestReduceDriver.java
URL: http://svn.apache.org/viewvc/mrunit/trunk/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestReduceDriver.java?rev=1355989&r1=1355988&r2=1355989&view=diff
==============================================================================
--- mrunit/trunk/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestReduceDriver.java (original)
+++ mrunit/trunk/src/test/java/org/apache/hadoop/mrunit/mapreduce/TestReduceDriver.java Sun Jul  1 19:34:18 2012
@@ -18,8 +18,10 @@
 
 package org.apache.hadoop.mrunit.mapreduce;
 
-import static org.apache.hadoop.mrunit.ExtendedAssert.assertListEquals;
-import static org.junit.Assert.assertEquals;
+import static org.apache.hadoop.mrunit.ExtendedAssert.*;
+import static org.junit.Assert.*;
+import static org.mockito.Matchers.*;
+import static org.mockito.Mockito.*;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -398,4 +400,17 @@ public class TestReduceDriver {
     driver.withOutput(new LongWritable(), new Text("a\t3"));
     driver.runTest();
   }
+  
+  @Test
+  public void textMockContext() throws IOException, InterruptedException {
+    thrown.expectMessage(RuntimeException.class, "Injected!");
+    Reducer<Text, LongWritable, Text, LongWritable>.Context context = driver.getContext();
+    doThrow(new RuntimeException("Injected!"))
+      .when(context)
+        .write(any(Text.class), any(LongWritable.class));
+    driver.withInputKey(new Text("a"));
+    driver.withInputValue(new LongWritable(1)).withInputValue(
+        new LongWritable(2));
+    driver.runTest();
+  }
 }