You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hive.apache.org by om...@apache.org on 2013/05/30 17:15:31 UTC

svn commit: r1487892 - in /hive/branches/vectorization/ql/src: java/org/apache/hadoop/hive/ql/exec/vector/ java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/ test/org/apache/hadoop/hive/ql/exec/vector/

Author: omalley
Date: Thu May 30 15:15:31 2013
New Revision: 1487892

URL: http://svn.apache.org/r1487892
Log:
HIVE-4452 Add support for COUNT(*) in vector aggregates (Remus Rusanu via
omalley)

Added:
    hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFCountStar.java
Modified:
    hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
    hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java

Modified: hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
URL: http://svn.apache.org/viewvc/hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java?rev=1487892&r1=1487891&r2=1487892&view=diff
==============================================================================
--- hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java (original)
+++ hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java Thu May 30 15:15:31 2013
@@ -41,6 +41,7 @@ import org.apache.hadoop.hive.ql.exec.ve
 import org.apache.hadoop.hive.ql.exec.vector.expressions.SelectColumnIsTrue;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFCountStar;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFAvgDouble;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFAvgLong;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFCountDouble;
@@ -942,6 +943,7 @@ public class VectorizationContext {
     {"min",       "Double", VectorUDAFMinDouble.class},
     {"max",       "Long",   VectorUDAFMaxLong.class},
     {"max",       "Double", VectorUDAFMaxDouble.class},
+    {"count",     null,     VectorUDAFCountStar.class},
     {"count",     "Long",   VectorUDAFCountLong.class},
     {"count",     "Double", VectorUDAFCountDouble.class},
     {"sum",       "Long",   VectorUDAFSumLong.class},
@@ -966,6 +968,7 @@ public class VectorizationContext {
 
   public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc)
       throws HiveException {
+
     ArrayList<ExprNodeDesc> paramDescList = desc.getParameters();
     VectorExpression[] vectorParams = new VectorExpression[paramDescList.size()];
 
@@ -975,22 +978,25 @@ public class VectorizationContext {
     }
 
     String aggregateName = desc.getGenericUDAFName();
-    List<ExprNodeDesc> params = desc.getParameters();
-    //TODO: handle length != 1
-    assert (params.size() == 1);
-    ExprNodeDesc inputExpr = params.get(0);
-    String inputType = getNormalizedTypeName(inputExpr.getTypeString());
+    String inputType = null;
+
+    if (paramDescList.size() > 0) {
+      ExprNodeDesc inputExpr = paramDescList.get(0);
+      inputType = getNormalizedTypeName(inputExpr.getTypeString());
+    }
 
     for (Object[] aggDef : aggregatesDefinition) {
-      if (aggDef[0].equals (aggregateName) &&
-          aggDef[1].equals(inputType)) {
+      if (aggregateName.equalsIgnoreCase((String) aggDef[0]) &&
+          ((aggDef[1] == null && inputType == null) ||
+          (aggDef[1] != null && aggDef[1].equals(inputType)))) {
         Class<? extends VectorAggregateExpression> aggClass =
             (Class<? extends VectorAggregateExpression>) (aggDef[2]);
         try
         {
           Constructor<? extends VectorAggregateExpression> ctor =
               aggClass.getConstructor(VectorExpression.class);
-          VectorAggregateExpression aggExpr = ctor.newInstance(vectorParams[0]);
+          VectorAggregateExpression aggExpr = ctor.newInstance(
+              vectorParams.length > 0 ? vectorParams[0] : null);
           return aggExpr;
         }
         // TODO: change to 1.7 syntax when possible

Added: hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFCountStar.java
URL: http://svn.apache.org/viewvc/hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFCountStar.java?rev=1487892&view=auto
==============================================================================
--- hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFCountStar.java (added)
+++ hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFCountStar.java Thu May 30 15:15:31 2013
@@ -0,0 +1,124 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.LongWritable;
+
+/**
+* VectorUDAFCountStar. Vectorized implementation for COUNT(*) aggregates.
+*/
+@Description(name = "count", value = "_FUNC_(expr) - Returns count(*) (vectorized)")
+public class VectorUDAFCountStar extends VectorAggregateExpression {
+
+    /**
+    /* class for storing the current aggregate value.
+    */
+    static class Aggregation implements AggregationBuffer {
+      long value;
+      boolean isNull;
+    }
+
+    private final LongWritable result;
+
+    public VectorUDAFCountStar(VectorExpression inputExpression) {
+      super();
+      result = new LongWritable(0);
+    }
+
+    private Aggregation getCurrentAggregationBuffer(
+        VectorAggregationBufferRow[] aggregationBufferSets,
+        int aggregateIndex,
+        int row) {
+      VectorAggregationBufferRow mySet = aggregationBufferSets[row];
+      Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregateIndex);
+      return myagg;
+    }
+
+    @Override
+    public void aggregateInputSelection(
+      VectorAggregationBufferRow[] aggregationBufferSets,
+      int aggregateIndex,
+      VectorizedRowBatch batch) throws HiveException {
+
+      int batchSize = batch.size;
+
+      if (batchSize == 0) {
+        return;
+      }
+
+      // count(*) cares not about NULLs nor selection
+      for (int i=0; i < batchSize; ++i) {
+        Aggregation myAgg = getCurrentAggregationBuffer(
+            aggregationBufferSets, aggregateIndex, i);
+        myAgg.isNull = false;
+        ++myAgg.value;
+      }
+    }
+
+    @Override
+    public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch)
+    throws HiveException {
+
+      int batchSize = batch.size;
+
+      if (batchSize == 0) {
+        return;
+      }
+
+      Aggregation myagg = (Aggregation)agg;
+      myagg.isNull = false;
+      myagg.value += batchSize;
+    }
+
+    @Override
+    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
+      return new Aggregation();
+    }
+
+    @Override
+    public void reset(AggregationBuffer agg) throws HiveException {
+      Aggregation myAgg = (Aggregation) agg;
+      myAgg.isNull = true;
+    }
+
+    @Override
+    public Object evaluateOutput(AggregationBuffer agg) throws HiveException {
+    Aggregation myagg = (Aggregation) agg;
+      if (myagg.isNull) {
+        return null;
+      }
+      else {
+        result.set (myagg.value);
+      return result;
+      }
+    }
+
+    @Override
+    public ObjectInspector getOutputObjectInspector() {
+      return PrimitiveObjectInspectorFactory.writableLongObjectInspector;
+    }
+}
+

Modified: hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
URL: http://svn.apache.org/viewvc/hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java?rev=1487892&r1=1487891&r2=1487892&view=diff
==============================================================================
--- hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java (original)
+++ hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java Thu May 30 15:15:31 2013
@@ -79,6 +79,14 @@ public class TestVectorGroupByOperator {
 
     return agg;
   }
+  /** Builds a COUNT(*) descriptor: upper-case UDAF name and an empty parameter list. */
+  private static AggregationDesc buildAggregationDescCountStar(VectorizationContext ctx) {
+    AggregationDesc countStar = new AggregationDesc();
+    countStar.setGenericUDAFName("COUNT");
+    countStar.setParameters(new ArrayList<ExprNodeDesc>());
+    return countStar;
+  }
+
 
   private static GroupByDesc buildGroupByDesc(
       VectorizationContext ctx,
@@ -98,6 +106,23 @@ public class TestVectorGroupByOperator {
 
     return desc;
   }
+  /**
+   * Builds a GROUP BY descriptor that sets no grouping keys and holds a
+   * single COUNT(*) aggregate whose output column is named "_col0".
+   */
+  private static GroupByDesc buildGroupByDescCountStar(VectorizationContext ctx) {
+    ArrayList<AggregationDesc> aggregators = new ArrayList<AggregationDesc>();
+    aggregators.add(buildAggregationDescCountStar(ctx));
+
+    ArrayList<String> outputNames = new ArrayList<String>();
+    outputNames.add("_col0");
+
+    GroupByDesc groupBy = new GroupByDesc();
+    groupBy.setOutputColumnNames(outputNames);
+    groupBy.setAggregators(aggregators);
+    return groupBy;
+  }
+
 
   private static GroupByDesc buildKeyGroupByDesc(
       VectorizationContext ctx,
@@ -117,6 +142,14 @@ public class TestVectorGroupByOperator {
   }
 
   @Test
+  public void testCountStar () throws HiveException {
+    testAggregateCountStar(
+        2,
+        Arrays.asList(new Long[]{13L,null,7L,19L}),
+        4L);
+  }
+
+  @Test
   public void testMinLongNullStringKeys() throws HiveException {
     testAggregateStringKeyAggregate(
         "min",
@@ -947,6 +980,17 @@ public class TestVectorGroupByOperator {
     testAggregateLongIterable (aggregateName, fdr, expected);
   }
 
+  /** Runs COUNT(*) over values wrapped in FakeVectorRowBatchFromLongIterables. */
+  public void testAggregateCountStar(
+      int batchSize, Iterable<Long> values, Object expected) throws HiveException {
+    @SuppressWarnings("unchecked")
+    FakeVectorRowBatchFromLongIterables batches =
+        new FakeVectorRowBatchFromLongIterables(batchSize, values);
+
+    testAggregateCountStarIterable(batches, expected);
+  }
+
+
   public static interface Validator {
     void validate (Object expected, Object result);
   };
@@ -1086,6 +1130,35 @@ public class TestVectorGroupByOperator {
     throw new HiveException("Missing validator for aggregate: " + aggregate);
   }
 
+  /**
+   * Pushes every batch in data through a VectorGroupByOperator computing
+   * COUNT(*) (context maps column "A" to index 0), then checks that exactly
+   * one row was emitted and that it matches expected per the "count" validator.
+   */
+  public void testAggregateCountStarIterable(
+      Iterable<VectorizedRowBatch> data, Object expected) throws HiveException {
+    Map<String, Integer> columnMap = new HashMap<String, Integer>();
+    columnMap.put("A", 0);
+    VectorizationContext ctx = new VectorizationContext(columnMap, 1);
+
+    VectorGroupByOperator groupBy =
+        new VectorGroupByOperator(ctx, buildGroupByDescCountStar(ctx));
+    FakeCaptureOutputOperator capture =
+        FakeCaptureOutputOperator.addCaptureOutputChild(groupBy);
+    groupBy.initialize(null, null);
+
+    for (VectorizedRowBatch batch : data) {
+      groupBy.process(batch, 0);
+    }
+    groupBy.close(false);
+
+    List<Object> rows = capture.getCapturedRows();
+    assertNotNull(rows);
+    assertEquals(1, rows.size());
+
+    getValidator("count").validate(expected, rows.get(0));
+  }
+
   public void testAggregateLongIterable (
       String aggregateName,
       Iterable<VectorizedRowBatch> data,