Posted to commits@hive.apache.org by mm...@apache.org on 2018/07/16 14:27:06 UTC
[2/3] hive git commit: HIVE-20174: Vectorization: Fix NULL / Wrong Results issues in GROUP BY Aggregation Functions (Matt McCline, reviewed by Teddy Choi)
http://git-wip-us.apache.org/repos/asf/hive/blob/0966a383/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64.java
index a503445..7f2a18a 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64.java
@@ -164,15 +164,9 @@ public class VectorUDAFSumDecimal64 extends VectorAggregateExpression {
}
} else {
if (inputVector.isRepeating) {
- if (batch.selectedInUse) {
- iterateHasNullsRepeatingSelectionWithAggregationSelection(
- aggregationBufferSets, aggregateIndex,
- vector[0], batchSize, batch.selected, inputVector.isNull);
- } else {
- iterateHasNullsRepeatingWithAggregationSelection(
- aggregationBufferSets, aggregateIndex,
- vector[0], batchSize, inputVector.isNull);
- }
+ iterateHasNullsRepeatingWithAggregationSelection(
+ aggregationBufferSets, aggregateIndex,
+ vector[0], batchSize, inputVector.isNull);
} else {
if (batch.selectedInUse) {
iterateHasNullsSelectionWithAggregationSelection(
@@ -232,28 +226,6 @@ public class VectorUDAFSumDecimal64 extends VectorAggregateExpression {
}
}
- private void iterateHasNullsRepeatingSelectionWithAggregationSelection(
- VectorAggregationBufferRow[] aggregationBufferSets,
- int aggregateIndex,
- long value,
- int batchSize,
- int[] selection,
- boolean[] isNull) {
-
- if (isNull[0]) {
- return;
- }
-
- for (int i=0; i < batchSize; ++i) {
- Aggregation myagg = getCurrentAggregationBuffer(
- aggregationBufferSets,
- aggregateIndex,
- i);
- myagg.sumValue(value);
- }
-
- }
-
private void iterateHasNullsRepeatingWithAggregationSelection(
VectorAggregationBufferRow[] aggregationBufferSets,
int aggregateIndex,
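
The hunk above (repeated, with the matching value types, in VectorUDAFSumDecimal64ToDecimal and VectorUDAFSumTimestamp below) collapses the repeating-input iteration into a single path. Reconstructed from the deleted method: when inputVector.isRepeating is true, vector[0] is the one value shared by every row, and the GROUP BY aggregation buffer is looked up by the logical batch position i, not by batch.selected, so the selection-aware variant did exactly the same work as the plain repeating variant. A minimal sketch of the surviving path, assuming its body mirrors the removed code minus the unused selection parameter:

    private void iterateHasNullsRepeatingWithAggregationSelection(
        VectorAggregationBufferRow[] aggregationBufferSets,
        int aggregateIndex,
        long value,
        int batchSize,
        boolean[] isNull) {

      // For a repeating vector, isNull[0] speaks for the whole batch.
      if (isNull[0]) {
        return;
      }
      for (int i = 0; i < batchSize; ++i) {
        // Buffers are addressed by logical batch position; the selected[]
        // array never participates in the repeating case.
        Aggregation myagg = getCurrentAggregationBuffer(
            aggregationBufferSets, aggregateIndex, i);
        myagg.sumValue(value);
      }
    }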
http://git-wip-us.apache.org/repos/asf/hive/blob/0966a383/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64ToDecimal.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64ToDecimal.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64ToDecimal.java
index 117611e..a02bdf3 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64ToDecimal.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64ToDecimal.java
@@ -189,15 +189,9 @@ public class VectorUDAFSumDecimal64ToDecimal extends VectorAggregateExpression {
}
} else {
if (inputVector.isRepeating) {
- if (batch.selectedInUse) {
- iterateHasNullsRepeatingSelectionWithAggregationSelection(
- aggregationBufferSets, aggregateIndex,
- vector[0], batchSize, batch.selected, inputVector.isNull);
- } else {
- iterateHasNullsRepeatingWithAggregationSelection(
- aggregationBufferSets, aggregateIndex,
- vector[0], batchSize, inputVector.isNull);
- }
+ iterateHasNullsRepeatingWithAggregationSelection(
+ aggregationBufferSets, aggregateIndex,
+ vector[0], batchSize, inputVector.isNull);
} else {
if (batch.selectedInUse) {
iterateHasNullsSelectionWithAggregationSelection(
@@ -257,28 +251,6 @@ public class VectorUDAFSumDecimal64ToDecimal extends VectorAggregateExpression {
}
}
- private void iterateHasNullsRepeatingSelectionWithAggregationSelection(
- VectorAggregationBufferRow[] aggregationBufferSets,
- int aggregateIndex,
- long value,
- int batchSize,
- int[] selection,
- boolean[] isNull) {
-
- if (isNull[0]) {
- return;
- }
-
- for (int i=0; i < batchSize; ++i) {
- Aggregation myagg = getCurrentAggregationBuffer(
- aggregationBufferSets,
- aggregateIndex,
- i);
- myagg.sumValue(value);
- }
-
- }
-
private void iterateHasNullsRepeatingWithAggregationSelection(
VectorAggregationBufferRow[] aggregationBufferSets,
int aggregateIndex,
http://git-wip-us.apache.org/repos/asf/hive/blob/0966a383/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumTimestamp.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumTimestamp.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumTimestamp.java
index e542033..731a143 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumTimestamp.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumTimestamp.java
@@ -131,15 +131,9 @@ public class VectorUDAFSumTimestamp extends VectorAggregateExpression {
}
} else {
if (inputVector.isRepeating) {
- if (batch.selectedInUse) {
- iterateHasNullsRepeatingSelectionWithAggregationSelection(
- aggregationBufferSets, aggregateIndex,
- inputVector.getDouble(0), batchSize, batch.selected, inputVector.isNull);
- } else {
- iterateHasNullsRepeatingWithAggregationSelection(
- aggregationBufferSets, aggregateIndex,
- inputVector.getDouble(0), batchSize, inputVector.isNull);
- }
+ iterateHasNullsRepeatingWithAggregationSelection(
+ aggregationBufferSets, aggregateIndex,
+ inputVector.getDouble(0), batchSize, inputVector.isNull);
} else {
if (batch.selectedInUse) {
iterateHasNullsSelectionWithAggregationSelection(
@@ -199,28 +193,6 @@ public class VectorUDAFSumTimestamp extends VectorAggregateExpression {
}
}
- private void iterateHasNullsRepeatingSelectionWithAggregationSelection(
- VectorAggregationBufferRow[] aggregationBufferSets,
- int aggregateIndex,
- double value,
- int batchSize,
- int[] selection,
- boolean[] isNull) {
-
- if (isNull[0]) {
- return;
- }
-
- for (int i=0; i < batchSize; ++i) {
- Aggregation myagg = getCurrentAggregationBuffer(
- aggregationBufferSets,
- aggregateIndex,
- i);
- myagg.sumValue(value);
- }
-
- }
-
private void iterateHasNullsRepeatingWithAggregationSelection(
VectorAggregationBufferRow[] aggregationBufferSets,
int aggregateIndex,
http://git-wip-us.apache.org/repos/asf/hive/blob/0966a383/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java
index 7afbf04..7ec80e6 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java
@@ -4183,7 +4183,7 @@ public class Vectorizer implements PhysicalPlanResolver {
AggregationDesc aggrDesc, VectorizationContext vContext) throws HiveException {
String aggregateName = aggrDesc.getGenericUDAFName();
- ArrayList<ExprNodeDesc> parameterList = aggrDesc.getParameters();
+ List<ExprNodeDesc> parameterList = aggrDesc.getParameters();
final int parameterCount = parameterList.size();
final GenericUDAFEvaluator.Mode udafEvaluatorMode = aggrDesc.getMode();
@@ -4192,10 +4192,9 @@ public class Vectorizer implements PhysicalPlanResolver {
*/
GenericUDAFEvaluator evaluator = aggrDesc.getGenericUDAFEvaluator();
- ArrayList<ExprNodeDesc> parameters = aggrDesc.getParameters();
ObjectInspector[] parameterObjectInspectors = new ObjectInspector[parameterCount];
for (int i = 0; i < parameterCount; i++) {
- TypeInfo typeInfo = parameters.get(i).getTypeInfo();
+ TypeInfo typeInfo = parameterList.get(i).getTypeInfo();
parameterObjectInspectors[i] = TypeInfoUtils
.getStandardWritableObjectInspectorFromTypeInfo(typeInfo);
}
@@ -4207,18 +4206,30 @@ public class Vectorizer implements PhysicalPlanResolver {
aggrDesc.getMode(),
parameterObjectInspectors);
+ final TypeInfo outputTypeInfo = TypeInfoUtils.getTypeInfoFromTypeString(returnOI.getTypeName());
+
+ return getVectorAggregationDesc(
+ aggregateName, parameterList, evaluator, outputTypeInfo, udafEvaluatorMode, vContext);
+ }
+
+ public static ImmutablePair<VectorAggregationDesc,String> getVectorAggregationDesc(
+ String aggregationName, List<ExprNodeDesc> parameterList,
+ GenericUDAFEvaluator evaluator, TypeInfo outputTypeInfo,
+ GenericUDAFEvaluator.Mode udafEvaluatorMode,
+ VectorizationContext vContext)
+ throws HiveException {
+
VectorizedUDAFs annotation =
AnnotationUtils.getAnnotation(evaluator.getClass(), VectorizedUDAFs.class);
if (annotation == null) {
String issue =
"Evaluator " + evaluator.getClass().getSimpleName() + " does not have a " +
- "vectorized UDAF annotation (aggregation: \"" + aggregateName + "\"). " +
+ "vectorized UDAF annotation (aggregation: \"" + aggregationName + "\"). " +
"Vectorization not supported";
return new ImmutablePair<VectorAggregationDesc,String>(null, issue);
}
final Class<? extends VectorAggregateExpression>[] vecAggrClasses = annotation.value();
- final TypeInfo outputTypeInfo = TypeInfoUtils.getTypeInfoFromTypeString(returnOI.getTypeName());
// Not final since it may change later due to DECIMAL_64.
ColumnVector.Type outputColVectorType =
@@ -4233,6 +4244,7 @@ public class Vectorizer implements PhysicalPlanResolver {
VectorExpression inputExpression;
ColumnVector.Type inputColVectorType;
+ final int parameterCount = parameterList.size();
if (parameterCount == 0) {
// COUNT(*)
@@ -4246,7 +4258,7 @@ public class Vectorizer implements PhysicalPlanResolver {
inputTypeInfo = exprNodeDesc.getTypeInfo();
if (inputTypeInfo == null) {
String issue ="Aggregations with null parameter type not supported " +
- aggregateName + "(" + parameterList.toString() + ")";
+ aggregationName + "(" + parameterList.toString() + ")";
return new ImmutablePair<VectorAggregationDesc,String>(null, issue);
}
@@ -4260,12 +4272,12 @@ public class Vectorizer implements PhysicalPlanResolver {
exprNodeDesc, VectorExpressionDescriptor.Mode.PROJECTION);
if (inputExpression == null) {
String issue ="Parameter expression " + exprNodeDesc.toString() + " not supported " +
- aggregateName + "(" + parameterList.toString() + ")";
+ aggregationName + "(" + parameterList.toString() + ")";
return new ImmutablePair<VectorAggregationDesc,String>(null, issue);
}
if (inputExpression.getOutputTypeInfo() == null) {
String issue ="Parameter expression " + exprNodeDesc.toString() + " with null type not supported " +
- aggregateName + "(" + parameterList.toString() + ")";
+ aggregationName + "(" + parameterList.toString() + ")";
return new ImmutablePair<VectorAggregationDesc,String>(null, issue);
}
inputColVectorType = inputExpression.getOutputColumnVectorType();
@@ -4273,7 +4285,7 @@ public class Vectorizer implements PhysicalPlanResolver {
// No multi-parameter aggregations supported.
String issue ="Aggregations with > 1 parameter are not supported " +
- aggregateName + "(" + parameterList.toString() + ")";
+ aggregationName + "(" + parameterList.toString() + ")";
return new ImmutablePair<VectorAggregationDesc,String>(null, issue);
}
@@ -4291,12 +4303,13 @@ public class Vectorizer implements PhysicalPlanResolver {
// Try with DECIMAL_64 input and DECIMAL_64 output.
final Class<? extends VectorAggregateExpression> vecAggrClass =
findVecAggrClass(
- vecAggrClasses, aggregateName, inputColVectorType,
+ vecAggrClasses, aggregationName, inputColVectorType,
ColumnVector.Type.DECIMAL_64, udafEvaluatorMode);
if (vecAggrClass != null) {
final VectorAggregationDesc vecAggrDesc =
new VectorAggregationDesc(
- aggrDesc, evaluator, inputTypeInfo, inputColVectorType, inputExpression,
+ aggregationName, evaluator, udafEvaluatorMode,
+ inputTypeInfo, inputColVectorType, inputExpression,
outputTypeInfo, ColumnVector.Type.DECIMAL_64, vecAggrClass);
return new ImmutablePair<VectorAggregationDesc,String>(vecAggrDesc, null);
}
@@ -4305,12 +4318,13 @@ public class Vectorizer implements PhysicalPlanResolver {
// Try with regular DECIMAL output type.
final Class<? extends VectorAggregateExpression> vecAggrClass =
findVecAggrClass(
- vecAggrClasses, aggregateName, inputColVectorType,
+ vecAggrClasses, aggregationName, inputColVectorType,
outputColVectorType, udafEvaluatorMode);
if (vecAggrClass != null) {
final VectorAggregationDesc vecAggrDesc =
new VectorAggregationDesc(
- aggrDesc, evaluator, inputTypeInfo, inputColVectorType, inputExpression,
+ aggregationName, evaluator, udafEvaluatorMode,
+ inputTypeInfo, inputColVectorType, inputExpression,
outputTypeInfo, outputColVectorType, vecAggrClass);
return new ImmutablePair<VectorAggregationDesc,String>(vecAggrDesc, null);
}
@@ -4325,19 +4339,20 @@ public class Vectorizer implements PhysicalPlanResolver {
// Try with DECIMAL_64 input and desired output type.
final Class<? extends VectorAggregateExpression> vecAggrClass =
findVecAggrClass(
- vecAggrClasses, aggregateName, inputColVectorType,
+ vecAggrClasses, aggregationName, inputColVectorType,
outputColVectorType, udafEvaluatorMode);
if (vecAggrClass != null) {
// for now, disable operating on decimal64 column vectors for semijoin reduction as
// we have to make sure same decimal type should be used during bloom filter creation
// and bloom filter probing
- if (aggregateName.equals("bloom_filter")) {
+ if (aggregationName.equals("bloom_filter")) {
inputExpression = vContext.wrapWithDecimal64ToDecimalConversion(inputExpression);
inputColVectorType = ColumnVector.Type.DECIMAL;
}
final VectorAggregationDesc vecAggrDesc =
new VectorAggregationDesc(
- aggrDesc, evaluator, inputTypeInfo, inputColVectorType, inputExpression,
+ aggregationName, evaluator, udafEvaluatorMode,
+ inputTypeInfo, inputColVectorType, inputExpression,
outputTypeInfo, outputColVectorType, vecAggrClass);
return new ImmutablePair<VectorAggregationDesc,String>(vecAggrDesc, null);
}
@@ -4355,19 +4370,20 @@ public class Vectorizer implements PhysicalPlanResolver {
*/
Class<? extends VectorAggregateExpression> vecAggrClass =
findVecAggrClass(
- vecAggrClasses, aggregateName, inputColVectorType,
+ vecAggrClasses, aggregationName, inputColVectorType,
outputColVectorType, udafEvaluatorMode);
if (vecAggrClass != null) {
final VectorAggregationDesc vecAggrDesc =
new VectorAggregationDesc(
- aggrDesc, evaluator, inputTypeInfo, inputColVectorType, inputExpression,
+ aggregationName, evaluator, udafEvaluatorMode,
+ inputTypeInfo, inputColVectorType, inputExpression,
outputTypeInfo, outputColVectorType, vecAggrClass);
return new ImmutablePair<VectorAggregationDesc,String>(vecAggrDesc, null);
}
// No match?
String issue =
- "Vector aggregation : \"" + aggregateName + "\" " +
+ "Vector aggregation : \"" + aggregationName + "\" " +
"for input type: " +
(inputColVectorType == null ? "any" : "\"" + inputColVectorType) + "\" " +
"and output type: \"" + outputColVectorType + "\" " +
http://git-wip-us.apache.org/repos/asf/hive/blob/0966a383/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
index d170d86..5cb7061 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
@@ -56,6 +56,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.util.StringUtils;
@@ -250,6 +251,25 @@ public class GenericUDAFAverage extends AbstractGenericUDAFResolver {
VectorUDAFAvgDecimalPartial2.class, VectorUDAFAvgDecimalFinal.class})
public static class GenericUDAFAverageEvaluatorDecimal extends AbstractGenericUDAFAverageEvaluator<HiveDecimal> {
+ private int resultPrecision = -1;
+ private int resultScale = -1;
+
+ @Override
+ public ObjectInspector init(Mode m, ObjectInspector[] parameters)
+ throws HiveException {
+
+ // Intercept result ObjectInspector so we can extract the DECIMAL precision and scale.
+ ObjectInspector resultOI = super.init(m, parameters);
+ if (m == Mode.COMPLETE || m == Mode.FINAL) {
+ DecimalTypeInfo decimalTypeInfo =
+ (DecimalTypeInfo)
+ TypeInfoUtils.getTypeInfoFromObjectInspector(resultOI);
+ resultPrecision = decimalTypeInfo.getPrecision();
+ resultScale = decimalTypeInfo.getScale();
+ }
+ return resultOI;
+ }
+
@Override
public void doReset(AverageAggregationBuffer<HiveDecimal> aggregation) throws HiveException {
aggregation.count = 0;
@@ -336,6 +356,7 @@ public class GenericUDAFAverage extends AbstractGenericUDAFResolver {
} else {
HiveDecimalWritable result = new HiveDecimalWritable(HiveDecimal.ZERO);
result.set(aggregation.sum.divide(HiveDecimal.create(aggregation.count)));
+ result.mutateEnforcePrecisionScale(resultPrecision, resultScale);
return result;
}
}
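
The GenericUDAFAverage change records the declared DECIMAL precision and scale of the result during init() (COMPLETE and FINAL modes only, where terminate() produces the final average) and then clamps the divided sum to that type. A small hedged illustration of the enforcement step; the literals and the DECIMAL(14, 4) result type are made up for the example:

    import org.apache.hadoop.hive.common.type.HiveDecimal;
    import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;

    // sum / count can carry more scale digits than the declared result type.
    HiveDecimalWritable result =
        new HiveDecimalWritable(HiveDecimal.create("10.666666666666666"));
    result.mutateEnforcePrecisionScale(14, 4);
    // result now holds 10.6667 (rounded to scale 4); a value whose integer
    // digits exceed precision - scale ends up in the not-set (null) state.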
http://git-wip-us.apache.org/repos/asf/hive/blob/0966a383/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFVariance.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFVariance.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFVariance.java
index c9fb3df..bb55d88 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFVariance.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFVariance.java
@@ -132,23 +132,29 @@ public class GenericUDAFVariance extends AbstractGenericUDAFResolver {
/*
* Calculate the variance family {VARIANCE, VARIANCE_SAMPLE, STANDARD_DEVIATION, or
- * STANDARD_DEVIATION_STAMPLE) result when count > 1. Public so vectorization code can
+ * STANDARD_DEVIATION_SAMPLE} result when count > 1. Public so vectorization code can
* use it, etc.
*/
public static double calculateVarianceFamilyResult(double variance, long count,
VarianceKind varianceKind) {
+ final double result;
switch (varianceKind) {
case VARIANCE:
- return GenericUDAFVarianceEvaluator.calculateVarianceResult(variance, count);
+ result = GenericUDAFVarianceEvaluator.calculateVarianceResult(variance, count);
+ break;
case VARIANCE_SAMPLE:
- return GenericUDAFVarianceSampleEvaluator.calculateVarianceSampleResult(variance, count);
+ result = GenericUDAFVarianceSampleEvaluator.calculateVarianceSampleResult(variance, count);
+ break;
case STANDARD_DEVIATION:
- return GenericUDAFStdEvaluator.calculateStdResult(variance, count);
+ result = GenericUDAFStdEvaluator.calculateStdResult(variance, count);
+ break;
case STANDARD_DEVIATION_SAMPLE:
- return GenericUDAFStdSampleEvaluator.calculateStdSampleResult(variance, count);
+ result = GenericUDAFStdSampleEvaluator.calculateStdSampleResult(variance, count);
+ break;
default:
throw new RuntimeException("Unexpected variance kind " + varianceKind);
}
+ return result;
}
@Override
@@ -381,7 +387,8 @@ public class GenericUDAFVariance extends AbstractGenericUDAFResolver {
* Calculate the variance result when count > 1. Public so vectorization code can use it, etc.
*/
public static double calculateVarianceResult(double variance, long count) {
- return variance / count;
+ final double result = variance / count;
+ return result;
}
@Override
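
For reference, the four branches of the rewritten switch reduce to the textbook formulas, where the variance argument is the running sum of squared deviations the evaluators accumulate and count > 1. A hedged Java restatement (m2 as double and count as long assumed in scope):

    double varPop  = m2 / count;                  // VARIANCE (var_pop)
    double varSamp = m2 / (count - 1);            // VARIANCE_SAMPLE (var_samp)
    double stdPop  = Math.sqrt(m2 / count);       // STANDARD_DEVIATION (stddev_pop)
    double stdSamp = Math.sqrt(m2 / (count - 1)); // STANDARD_DEVIATION_SAMPLE (stddev_samp)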
http://git-wip-us.apache.org/repos/asf/hive/blob/0966a383/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
index ffdc410..fe1375b 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java
@@ -216,7 +216,10 @@ public class TestVectorGroupByOperator {
vectorDesc.setVecAggrDescs(
new VectorAggregationDesc[] {
new VectorAggregationDesc(
- agg, new GenericUDAFCount.GenericUDAFCountEvaluator(), null, ColumnVector.Type.NONE, null,
+ agg.getGenericUDAFName(),
+ new GenericUDAFCount.GenericUDAFCountEvaluator(),
+ agg.getMode(),
+ null, ColumnVector.Type.NONE, null,
TypeInfoFactory.longTypeInfo, ColumnVector.Type.LONG, VectorUDAFCountStar.class)});
vectorDesc.setProcessingMode(VectorGroupByDesc.ProcessingMode.HASH);
@@ -1555,7 +1558,7 @@ public class TestVectorGroupByOperator {
"avg",
2,
Arrays.asList(new Long[]{}),
- null);
+ 0.0);
}
@Test
@@ -1564,12 +1567,12 @@ public class TestVectorGroupByOperator {
"avg",
2,
Arrays.asList(new Long[]{null}),
- null);
+ 0.0);
testAggregateLongAggregate(
"avg",
2,
Arrays.asList(new Long[]{null, null, null}),
- null);
+ 0.0);
testAggregateLongAggregate(
"avg",
2,
@@ -1601,7 +1604,7 @@ public class TestVectorGroupByOperator {
null,
4096,
1024,
- null);
+ 0.0);
}
@SuppressWarnings("unchecked")
@@ -1632,7 +1635,7 @@ public class TestVectorGroupByOperator {
"variance",
2,
Arrays.asList(new Long[]{}),
- null);
+ 0.0);
}
@Test
@@ -1650,12 +1653,12 @@ public class TestVectorGroupByOperator {
"variance",
2,
Arrays.asList(new Long[]{null}),
- null);
+ 0.0);
testAggregateLongAggregate(
"variance",
2,
Arrays.asList(new Long[]{null, null, null}),
- null);
+ 0.0);
testAggregateLongAggregate(
"variance",
2,
@@ -1680,7 +1683,7 @@ public class TestVectorGroupByOperator {
null,
4096,
1024,
- null);
+ 0.0);
}
@Test
@@ -1708,7 +1711,7 @@ public class TestVectorGroupByOperator {
"var_samp",
2,
Arrays.asList(new Long[]{}),
- null);
+ 0.0);
}
@@ -1737,7 +1740,7 @@ public class TestVectorGroupByOperator {
"std",
2,
Arrays.asList(new Long[]{}),
- null);
+ 0.0);
}
@@ -1758,7 +1761,7 @@ public class TestVectorGroupByOperator {
null,
4096,
1024,
- null);
+ 0.0);
}
@@ -2236,14 +2239,21 @@ public class TestVectorGroupByOperator {
assertEquals (true, vals[0] instanceof LongWritable);
LongWritable lw = (LongWritable) vals[0];
- assertFalse (lw.get() == 0L);
if (vals[1] instanceof DoubleWritable) {
DoubleWritable dw = (DoubleWritable) vals[1];
- assertEquals (key, expected, dw.get() / lw.get());
+ if (lw.get() != 0L) {
+ assertEquals (key, expected, dw.get() / lw.get());
+ } else {
+ assertEquals(key, expected, 0.0);
+ }
} else if (vals[1] instanceof HiveDecimalWritable) {
HiveDecimalWritable hdw = (HiveDecimalWritable) vals[1];
- assertEquals (key, expected, hdw.getHiveDecimal().divide(HiveDecimal.create(lw.get())));
+ if (lw.get() != 0L) {
+ assertEquals (key, expected, hdw.getHiveDecimal().divide(HiveDecimal.create(lw.get())));
+ } else {
+ assertEquals(key, expected, HiveDecimal.ZERO);
+ }
}
}
}
@@ -2271,10 +2281,14 @@ public class TestVectorGroupByOperator {
assertEquals (true, vals[1] instanceof DoubleWritable);
assertEquals (true, vals[2] instanceof DoubleWritable);
LongWritable cnt = (LongWritable) vals[0];
- DoubleWritable sum = (DoubleWritable) vals[1];
- DoubleWritable var = (DoubleWritable) vals[2];
- assertTrue (1 <= cnt.get());
- validateVariance (key, (Double) expected, cnt.get(), sum.get(), var.get());
+ if (cnt.get() == 0) {
+ assertEquals(key, expected, 0.0);
+ } else {
+ DoubleWritable sum = (DoubleWritable) vals[1];
+ DoubleWritable var = (DoubleWritable) vals[2];
+ assertTrue (1 <= cnt.get());
+ validateVariance (key, (Double) expected, cnt.get(), sum.get(), var.get());
+ }
}
}
}
http://git-wip-us.apache.org/repos/asf/hive/blob/0966a383/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomBatchSource.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomBatchSource.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomBatchSource.java
index 4c2f872..dd2f8e3 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomBatchSource.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomBatchSource.java
@@ -167,6 +167,8 @@ public class VectorRandomBatchSource {
VectorRandomRowSource vectorRandomRowSource,
Object[][] randomRows) {
+ final boolean allowNull = vectorRandomRowSource.getAllowNull();
+
List<VectorBatchPattern> vectorBatchPatternList = new ArrayList<VectorBatchPattern>();
final int rowCount = randomRows.length;
int rowIndex = 0;
@@ -201,35 +203,38 @@ public class VectorRandomBatchSource {
*/
while (true) {
- // Repeated NULL permutations.
long columnPermutation = 1;
- while (true) {
- if (columnPermutation > columnPermutationLimit) {
- break;
- }
- final int maximumRowCount = Math.min(rowCount - rowIndex, VectorizedRowBatch.DEFAULT_SIZE);
- if (maximumRowCount == 0) {
- break;
- }
- int randomRowCount = 1 + random.nextInt(maximumRowCount);
- final int rowLimit = rowIndex + randomRowCount;
+ if (allowNull) {
- BitSet bitSet = BitSet.valueOf(new long[]{columnPermutation});
+ // Repeated NULL permutations.
+ while (true) {
+ if (columnPermutation > columnPermutationLimit) {
+ break;
+ }
+ final int maximumRowCount = Math.min(rowCount - rowIndex, VectorizedRowBatch.DEFAULT_SIZE);
+ if (maximumRowCount == 0) {
+ break;
+ }
+ int randomRowCount = 1 + random.nextInt(maximumRowCount);
+ final int rowLimit = rowIndex + randomRowCount;
- for (int columnNum = bitSet.nextSetBit(0);
- columnNum >= 0;
- columnNum = bitSet.nextSetBit(columnNum + 1)) {
+ BitSet bitSet = BitSet.valueOf(new long[]{columnPermutation});
- // Repeated NULL fill down column.
- for (int r = rowIndex; r < rowLimit; r++) {
- randomRows[r][columnNum] = null;
+ for (int columnNum = bitSet.nextSetBit(0);
+ columnNum >= 0;
+ columnNum = bitSet.nextSetBit(columnNum + 1)) {
+
+ // Repeated NULL fill down column.
+ for (int r = rowIndex; r < rowLimit; r++) {
+ randomRows[r][columnNum] = null;
+ }
}
+ vectorBatchPatternList.add(
+ VectorBatchPattern.createRepeatedBatch(
+ random, randomRowCount, bitSet, asSelected));
+ columnPermutation++;
+ rowIndex = rowLimit;
}
- vectorBatchPatternList.add(
- VectorBatchPattern.createRepeatedBatch(
- random, randomRowCount, bitSet, asSelected));
- columnPermutation++;
- rowIndex = rowLimit;
}
// Repeated non-NULL permutations.
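
The VectorRandomBatchSource hunk is a pure restructuring: the repeated-NULL permutation loop moves, unmodified, under an allowNull guard so that row sources configured without NULLs skip straight to the repeated non-NULL permutations. Reduced to its control flow (a sketch with the loop bodies elided):

    final boolean allowNull = vectorRandomRowSource.getAllowNull();
    long columnPermutation = 1;
    if (allowNull) {
      while (columnPermutation <= columnPermutationLimit
          && rowIndex < rowCount) {
        // Fill the columns named by the BitSet permutation with repeated
        // NULLs and record a repeated-batch pattern, as in the moved code.
        columnPermutation++;
      }
    }
    // Repeated non-NULL permutations follow, unconditionally.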
http://git-wip-us.apache.org/repos/asf/hive/blob/0966a383/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
index 6181ae8..a1cefaa 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
@@ -21,7 +21,6 @@ package org.apache.hadoop.hive.ql.exec.vector;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.text.ParseException;
-
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
@@ -29,7 +28,6 @@ import java.util.Random;
import java.util.Set;
import org.apache.commons.lang.StringUtils;
-
import org.apache.hadoop.hive.common.type.DataTypePhysicalVariation;
import org.apache.hadoop.hive.common.type.Date;
import org.apache.hadoop.hive.common.type.HiveChar;
@@ -86,6 +84,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo;
import org.apache.hive.common.util.DateUtils;
import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.BooleanWritable;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.LongWritable;
@@ -130,6 +129,10 @@ public class VectorRandomRowSource {
private boolean addEscapables;
private String needsEscapeStr;
+ public boolean getAllowNull() {
+ return allowNull;
+ }
+
public static class StringGenerationOption {
private boolean generateSentences;
@@ -1021,43 +1024,141 @@ public class VectorRandomRowSource {
switch (primitiveTypeInfo.getPrimitiveCategory()) {
case BOOLEAN:
- return ((WritableBooleanObjectInspector) objectInspector).create((boolean) object);
+ {
+ WritableBooleanObjectInspector writableOI = (WritableBooleanObjectInspector) objectInspector;
+ if (object instanceof Boolean) {
+ return writableOI.create((boolean) object);
+ } else {
+ return writableOI.copyObject(object);
+ }
+ }
case BYTE:
- return ((WritableByteObjectInspector) objectInspector).create((byte) object);
+ {
+ WritableByteObjectInspector writableOI = (WritableByteObjectInspector) objectInspector;
+ if (object instanceof Byte) {
+ return writableOI.create((byte) object);
+ } else {
+ return writableOI.copyObject(object);
+ }
+ }
case SHORT:
- return ((WritableShortObjectInspector) objectInspector).create((short) object);
+ {
+ WritableShortObjectInspector writableOI = (WritableShortObjectInspector) objectInspector;
+ if (object instanceof Short) {
+ return writableOI.create((short) object);
+ } else {
+ return writableOI.copyObject(object);
+ }
+ }
case INT:
- return ((WritableIntObjectInspector) objectInspector).create((int) object);
+ {
+ WritableIntObjectInspector writableOI = (WritableIntObjectInspector) objectInspector;
+ if (object instanceof Integer) {
+ return writableOI.create((int) object);
+ } else {
+ return writableOI.copyObject(object);
+ }
+ }
case LONG:
- return ((WritableLongObjectInspector) objectInspector).create((long) object);
+ {
+ WritableLongObjectInspector writableOI = (WritableLongObjectInspector) objectInspector;
+ if (object instanceof Long) {
+ return writableOI.create((long) object);
+ } else {
+ return writableOI.copyObject(object);
+ }
+ }
case DATE:
- return ((WritableDateObjectInspector) objectInspector).create((Date) object);
+ {
+ WritableDateObjectInspector writableOI = (WritableDateObjectInspector) objectInspector;
+ if (object instanceof Date) {
+ return writableOI.create((Date) object);
+ } else {
+ return writableOI.copyObject(object);
+ }
+ }
case FLOAT:
- return ((WritableFloatObjectInspector) objectInspector).create((float) object);
+ {
+ WritableFloatObjectInspector writableOI = (WritableFloatObjectInspector) objectInspector;
+ if (object instanceof Float) {
+ return writableOI.create((float) object);
+ } else {
+ return writableOI.copyObject(object);
+ }
+ }
case DOUBLE:
- return ((WritableDoubleObjectInspector) objectInspector).create((double) object);
+ {
+ WritableDoubleObjectInspector writableOI = (WritableDoubleObjectInspector) objectInspector;
+ if (object instanceof Double) {
+ return writableOI.create((double) object);
+ } else {
+ return writableOI.copyObject(object);
+ }
+ }
case STRING:
- return ((WritableStringObjectInspector) objectInspector).create((String) object);
+ {
+ WritableStringObjectInspector writableOI = (WritableStringObjectInspector) objectInspector;
+ if (object instanceof String) {
+ return writableOI.create((String) object);
+ } else {
+ return writableOI.copyObject(object);
+ }
+ }
case CHAR:
{
WritableHiveCharObjectInspector writableCharObjectInspector =
new WritableHiveCharObjectInspector( (CharTypeInfo) primitiveTypeInfo);
- return writableCharObjectInspector.create((HiveChar) object);
+ if (object instanceof HiveChar) {
+ return writableCharObjectInspector.create((HiveChar) object);
+ } else {
+ return writableCharObjectInspector.copyObject(object);
+ }
}
case VARCHAR:
{
WritableHiveVarcharObjectInspector writableVarcharObjectInspector =
new WritableHiveVarcharObjectInspector( (VarcharTypeInfo) primitiveTypeInfo);
- return writableVarcharObjectInspector.create((HiveVarchar) object);
+ if (object instanceof HiveVarchar) {
+ return writableVarcharObjectInspector.create((HiveVarchar) object);
+ } else {
+ return writableVarcharObjectInspector.copyObject(object);
+ }
}
case BINARY:
- return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector.create((byte[]) object);
+ {
+ if (object instanceof byte[]) {
+ return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector.create((byte[]) object);
+ } else {
+ return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector.copyObject(object);
+ }
+ }
case TIMESTAMP:
- return ((WritableTimestampObjectInspector) objectInspector).create((Timestamp) object);
+ {
+ WritableTimestampObjectInspector writableOI = (WritableTimestampObjectInspector) objectInspector;
+ if (object instanceof Timestamp) {
+ return writableOI.create((Timestamp) object);
+ } else {
+ return writableOI.copyObject(object);
+ }
+ }
case INTERVAL_YEAR_MONTH:
- return ((WritableHiveIntervalYearMonthObjectInspector) objectInspector).create((HiveIntervalYearMonth) object);
+ {
+ WritableHiveIntervalYearMonthObjectInspector writableOI = (WritableHiveIntervalYearMonthObjectInspector) objectInspector;
+ if (object instanceof HiveIntervalYearMonth) {
+ return writableOI.create((HiveIntervalYearMonth) object);
+ } else {
+ return writableOI.copyObject(object);
+ }
+ }
case INTERVAL_DAY_TIME:
- return ((WritableHiveIntervalDayTimeObjectInspector) objectInspector).create((HiveIntervalDayTime) object);
+ {
+ WritableHiveIntervalDayTimeObjectInspector writableOI = (WritableHiveIntervalDayTimeObjectInspector) objectInspector;
+ if (object instanceof HiveIntervalDayTime) {
+ return writableOI.create((HiveIntervalDayTime) object);
+ } else {
+ return writableOI.copyObject(object);
+ }
+ }
case DECIMAL:
{
if (dataTypePhysicalVariation == dataTypePhysicalVariation.DECIMAL_64) {
@@ -1071,9 +1172,13 @@ public class VectorRandomRowSource {
}
return ((WritableLongObjectInspector) objectInspector).create(value);
} else {
- WritableHiveDecimalObjectInspector writableDecimalObjectInspector =
+ WritableHiveDecimalObjectInspector writableOI =
new WritableHiveDecimalObjectInspector((DecimalTypeInfo) primitiveTypeInfo);
- return writableDecimalObjectInspector.create((HiveDecimal) object);
+ if (object instanceof HiveDecimal) {
+ return writableOI.create((HiveDecimal) object);
+ } else {
+ return writableOI.copyObject(object);
+ }
}
}
default:
@@ -1081,6 +1186,116 @@ public class VectorRandomRowSource {
}
}
+ public static Object getNonWritablePrimitiveObject(Object object, TypeInfo typeInfo,
+ ObjectInspector objectInspector) {
+
+ PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo;
+ switch (primitiveTypeInfo.getPrimitiveCategory()) {
+ case BOOLEAN:
+ if (object instanceof Boolean) {
+ return object;
+ } else {
+ return ((WritableBooleanObjectInspector) objectInspector).get(object);
+ }
+ case BYTE:
+ if (object instanceof Byte) {
+ return object;
+ } else {
+ return ((WritableByteObjectInspector) objectInspector).get(object);
+ }
+ case SHORT:
+ if (object instanceof Short) {
+ return object;
+ } else {
+ return ((WritableShortObjectInspector) objectInspector).get(object);
+ }
+ case INT:
+ if (object instanceof Integer) {
+ return object;
+ } else {
+ return ((WritableIntObjectInspector) objectInspector).get(object);
+ }
+ case LONG:
+ if (object instanceof Long) {
+ return object;
+ } else {
+ return ((WritableLongObjectInspector) objectInspector).get(object);
+ }
+ case FLOAT:
+ if (object instanceof Float) {
+ return object;
+ } else {
+ return ((WritableFloatObjectInspector) objectInspector).get(object);
+ }
+ case DOUBLE:
+ if (object instanceof Double) {
+ return object;
+ } else {
+ return ((WritableDoubleObjectInspector) objectInspector).get(object);
+ }
+ case STRING:
+ if (object instanceof String) {
+ return object;
+ } else {
+ return ((WritableStringObjectInspector) objectInspector).getPrimitiveJavaObject(object);
+ }
+ case DATE:
+ if (object instanceof Date) {
+ return object;
+ } else {
+ return ((WritableDateObjectInspector) objectInspector).getPrimitiveJavaObject(object);
+ }
+ case TIMESTAMP:
+ if (object instanceof Timestamp) {
+ return object;
+ } else if (object instanceof org.apache.hadoop.hive.common.type.Timestamp) {
+ return object;
+ } else {
+ return ((WritableTimestampObjectInspector) objectInspector).getPrimitiveJavaObject(object);
+ }
+ case DECIMAL:
+ if (object instanceof HiveDecimal) {
+ return object;
+ } else {
+ WritableHiveDecimalObjectInspector writableDecimalObjectInspector =
+ new WritableHiveDecimalObjectInspector((DecimalTypeInfo) primitiveTypeInfo);
+ return writableDecimalObjectInspector.getPrimitiveJavaObject(object);
+ }
+ case VARCHAR:
+ if (object instanceof HiveVarchar) {
+ return object;
+ } else {
+ WritableHiveVarcharObjectInspector writableVarcharObjectInspector =
+ new WritableHiveVarcharObjectInspector( (VarcharTypeInfo) primitiveTypeInfo);
+ return writableVarcharObjectInspector.getPrimitiveJavaObject(object);
+ }
+ case CHAR:
+ if (object instanceof HiveChar) {
+ return object;
+ } else {
+ WritableHiveCharObjectInspector writableCharObjectInspector =
+ new WritableHiveCharObjectInspector( (CharTypeInfo) primitiveTypeInfo);
+ return writableCharObjectInspector.getPrimitiveJavaObject(object);
+ }
+ case INTERVAL_YEAR_MONTH:
+ if (object instanceof HiveIntervalYearMonth) {
+ return object;
+ } else {
+ return ((WritableHiveIntervalYearMonthObjectInspector) objectInspector).getPrimitiveJavaObject(object);
+ }
+ case INTERVAL_DAY_TIME:
+ if (object instanceof HiveIntervalDayTime) {
+ return object;
+ } else {
+ return ((WritableHiveIntervalDayTimeObjectInspector) objectInspector).getPrimitiveJavaObject(object);
+ }
+ case BINARY:
+ default:
+ throw new RuntimeException(
+ "Unexpected primitive category " + primitiveTypeInfo.getPrimitiveCategory());
+ }
+ }
+
public Object randomWritable(int column) {
return randomWritable(
typeInfos[column], objectInspectorList.get(column), dataTypePhysicalVariations[column],
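
VectorRandomRowSource gains two symmetric defenses: getWritablePrimitiveObject now accepts either a plain Java value or an already-writable one (copying the latter via the object inspector), and the new getNonWritablePrimitiveObject goes the other way, unwrapping writables into Java-level values. A hedged usage sketch for the new helper; the IntWritable input is illustrative:

    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
    import org.apache.hadoop.io.IntWritable;

    // Returns the Integer 42 whether the input is an Integer or an IntWritable.
    Object javaValue =
        VectorRandomRowSource.getNonWritablePrimitiveObject(
            new IntWritable(42),
            TypeInfoFactory.intTypeInfo,
            PrimitiveObjectInspectorFactory.writableIntObjectInspector);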
http://git-wip-us.apache.org/repos/asf/hive/blob/0966a383/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/aggregation/AggregationBase.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/aggregation/AggregationBase.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/aggregation/AggregationBase.java
new file mode 100644
index 0000000..583241c
--- /dev/null
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/aggregation/AggregationBase.java
@@ -0,0 +1,473 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.vector.aggregation;
+
+import java.lang.reflect.Constructor;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.hadoop.hive.common.type.DataTypePhysicalVariation;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
+import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow;
+import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationDesc;
+import org.apache.hadoop.hive.ql.exec.vector.VectorExtractRow;
+import org.apache.hadoop.hive.ql.exec.vector.VectorRandomBatchSource;
+import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizationContext;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatchCtx;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.optimizer.physical.Vectorizer;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
+import org.apache.hadoop.hive.serde2.io.ShortWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
+import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
+
+import junit.framework.Assert;
+
+public class AggregationBase {
+
+ public enum AggregationTestMode {
+ ROW_MODE,
+ VECTOR_EXPRESSION;
+
+ static final int count = values().length;
+ }
+
+ public static GenericUDAFEvaluator getEvaluator(String aggregationFunctionName,
+ TypeInfo typeInfo)
+ throws SemanticException {
+
+ GenericUDAFResolver resolver =
+ FunctionRegistry.getGenericUDAFResolver(aggregationFunctionName);
+ TypeInfo[] parameters = new TypeInfo[] { typeInfo };
+ GenericUDAFEvaluator evaluator = resolver.getEvaluator(parameters);
+ return evaluator;
+ }
+
+ protected static boolean doRowTest(TypeInfo typeInfo,
+ GenericUDAFEvaluator evaluator, TypeInfo outputTypeInfo,
+ GenericUDAFEvaluator.Mode udafEvaluatorMode, int maxKeyCount,
+ List<String> columns, List<ExprNodeDesc> children,
+ Object[][] randomRows, ObjectInspector rowInspector,
+ Object[] results)
+ throws Exception {
+
+ /*
+ System.out.println(
+ "*DEBUG* typeInfo " + typeInfo.toString() +
+ " aggregationTestMode ROW_MODE" +
+ " outputTypeInfo " + outputTypeInfo.toString());
+ */
+
+ // Last entry is for a NULL key.
+ AggregationBuffer[] aggregationBuffers = new AggregationBuffer[maxKeyCount + 1];
+
+ ObjectInspector objectInspector = TypeInfoUtils
+ .getStandardWritableObjectInspectorFromTypeInfo(outputTypeInfo);
+
+ Object[] parameterArray = new Object[1];
+ final int rowCount = randomRows.length;
+ for (int i = 0; i < rowCount; i++) {
+ Object[] row = randomRows[i];
+ ShortWritable shortWritable = (ShortWritable) row[0];
+
+ final int key;
+ if (shortWritable == null) {
+ key = maxKeyCount;
+ } else {
+ key = shortWritable.get();
+ }
+ AggregationBuffer aggregationBuffer = aggregationBuffers[key];
+ if (aggregationBuffer == null) {
+ aggregationBuffer = evaluator.getNewAggregationBuffer();
+ aggregationBuffers[key] = aggregationBuffer;
+ }
+ parameterArray[0] = row[1];
+ evaluator.aggregate(aggregationBuffer, parameterArray);
+ }
+
+ final boolean isPrimitive = (outputTypeInfo instanceof PrimitiveTypeInfo);
+ final boolean isPartial =
+ (udafEvaluatorMode == GenericUDAFEvaluator.Mode.PARTIAL1 ||
+ udafEvaluatorMode == GenericUDAFEvaluator.Mode.PARTIAL2);
+
+ for (short key = 0; key < maxKeyCount + 1; key++) {
+ AggregationBuffer aggregationBuffer = aggregationBuffers[key];
+ if (aggregationBuffer != null) {
+ final Object result;
+ if (isPartial) {
+ result = evaluator.terminatePartial(aggregationBuffer);
+ } else {
+ result = evaluator.terminate(aggregationBuffer);
+ }
+ Object copyResult;
+ if (result == null) {
+ copyResult = null;
+ } else if (isPrimitive) {
+ copyResult =
+ VectorRandomRowSource.getWritablePrimitiveObject(
+ (PrimitiveTypeInfo) outputTypeInfo, objectInspector, result);
+ } else {
+ copyResult =
+ ObjectInspectorUtils.copyToStandardObject(
+ result, objectInspector, ObjectInspectorCopyOption.WRITABLE);
+ }
+ results[key] = copyResult;
+ }
+ }
+
+ return true;
+ }
+
+ private static void extractResultObjects(VectorizedRowBatch outputBatch, short[] keys,
+ VectorExtractRow resultVectorExtractRow, TypeInfo outputTypeInfo, Object[] scrqtchRow,
+ Object[] results) {
+
+ final boolean isPrimitive = (outputTypeInfo instanceof PrimitiveTypeInfo);
+ ObjectInspector objectInspector;
+ if (isPrimitive) {
+ objectInspector = TypeInfoUtils
+ .getStandardWritableObjectInspectorFromTypeInfo(outputTypeInfo);
+ } else {
+ objectInspector = null;
+ }
+
+ for (int batchIndex = 0; batchIndex < outputBatch.size; batchIndex++) {
+ resultVectorExtractRow.extractRow(outputBatch, batchIndex, scrqtchRow);
+ if (isPrimitive) {
+ Object copyResult =
+ ObjectInspectorUtils.copyToStandardObject(
+ scrqtchRow[0], objectInspector, ObjectInspectorCopyOption.WRITABLE);
+ results[keys[batchIndex]] = copyResult;
+ } else {
+ results[keys[batchIndex]] = scrqtchRow[0];
+ }
+ }
+ }
+
+ protected static boolean doVectorTest(String aggregationName, TypeInfo typeInfo,
+ GenericUDAFEvaluator evaluator, TypeInfo outputTypeInfo,
+ GenericUDAFEvaluator.Mode udafEvaluatorMode, int maxKeyCount,
+ List<String> columns, String[] columnNames,
+ TypeInfo[] typeInfos, DataTypePhysicalVariation[] dataTypePhysicalVariations,
+ List<ExprNodeDesc> parameterList,
+ VectorRandomBatchSource batchSource,
+ Object[] results)
+ throws Exception {
+
+ HiveConf hiveConf = new HiveConf();
+
+ VectorizationContext vectorizationContext =
+ new VectorizationContext(
+ "name",
+ columns,
+ Arrays.asList(typeInfos),
+ Arrays.asList(dataTypePhysicalVariations),
+ hiveConf);
+
+ ImmutablePair<VectorAggregationDesc,String> pair =
+ Vectorizer.getVectorAggregationDesc(
+ aggregationName,
+ parameterList,
+ evaluator,
+ outputTypeInfo,
+ udafEvaluatorMode,
+ vectorizationContext);
+ VectorAggregationDesc vecAggrDesc = pair.left;
+ if (vecAggrDesc == null) {
+ Assert.fail(
+ "No vector aggregation expression found for aggregationName " + aggregationName +
+ " udafEvaluatorMode " + udafEvaluatorMode +
+ " parameterList " + parameterList +
+ " outputTypeInfo " + outputTypeInfo);
+ }
+
+ Class<? extends VectorAggregateExpression> vecAggrClass = vecAggrDesc.getVecAggrClass();
+
+ Constructor<? extends VectorAggregateExpression> ctor = null;
+ try {
+ ctor = vecAggrClass.getConstructor(VectorAggregationDesc.class);
+ } catch (Exception e) {
+ throw new HiveException("Constructor " + vecAggrClass.getSimpleName() +
+ "(VectorAggregationDesc) not available");
+ }
+ VectorAggregateExpression vecAggrExpr = null;
+ try {
+ vecAggrExpr = ctor.newInstance(vecAggrDesc);
+ } catch (Exception e) {
+
+ throw new HiveException("Failed to create " + vecAggrClass.getSimpleName() +
+ "(VectorAggregationDesc) object ", e);
+ }
+ VectorExpression.doTransientInit(vecAggrExpr.getInputExpression());
+
+ /*
+ System.out.println(
+ "*DEBUG* typeInfo " + typeInfo.toString() +
+ " aggregationTestMode VECTOR_MODE" +
+ " vecAggrExpr " + vecAggrExpr.getClass().getSimpleName());
+ */
+
+ VectorRandomRowSource rowSource = batchSource.getRowSource();
+ VectorizedRowBatchCtx batchContext =
+ new VectorizedRowBatchCtx(
+ columnNames,
+ rowSource.typeInfos(),
+ rowSource.dataTypePhysicalVariations(),
+ /* dataColumnNums */ null,
+ /* partitionColumnCount */ 0,
+ /* virtualColumnCount */ 0,
+ /* neededVirtualColumns */ null,
+ vectorizationContext.getScratchColumnTypeNames(),
+ vectorizationContext.getScratchDataTypePhysicalVariations());
+
+ VectorizedRowBatch batch = batchContext.createVectorizedRowBatch();
+
+ // Last entry is for a NULL key.
+ VectorAggregationBufferRow[] vectorAggregationBufferRows =
+ new VectorAggregationBufferRow[maxKeyCount + 1];
+
+ VectorAggregationBufferRow[] batchBufferRows;
+
+ batchSource.resetBatchIteration();
+ int rowIndex = 0;
+ while (true) {
+ if (!batchSource.fillNextBatch(batch)) {
+ break;
+ }
+ LongColumnVector keyLongColVector = (LongColumnVector) batch.cols[0];
+
+ batchBufferRows =
+ new VectorAggregationBufferRow[VectorizedRowBatch.DEFAULT_SIZE];
+
+ final int size = batch.size;
+ boolean selectedInUse = batch.selectedInUse;
+ int[] selected = batch.selected;
+ for (int logical = 0; logical < size; logical++) {
+ final int batchIndex = (selectedInUse ? selected[logical] : logical);
+ final int keyAdjustedBatchIndex;
+ if (keyLongColVector.isRepeating) {
+ keyAdjustedBatchIndex = 0;
+ } else {
+ keyAdjustedBatchIndex = batchIndex;
+ }
+ final short key;
+ if (keyLongColVector.noNulls || !keyLongColVector.isNull[keyAdjustedBatchIndex]) {
+ key = (short) keyLongColVector.vector[keyAdjustedBatchIndex];
+ } else {
+ key = (short) maxKeyCount;
+ }
+
+ VectorAggregationBufferRow bufferRow = vectorAggregationBufferRows[key];
+ if (bufferRow == null) {
+ VectorAggregateExpression.AggregationBuffer aggregationBuffer =
+ vecAggrExpr.getNewAggregationBuffer();
+ aggregationBuffer.reset();
+ VectorAggregateExpression.AggregationBuffer[] aggregationBuffers =
+ new VectorAggregateExpression.AggregationBuffer[] { aggregationBuffer };
+ bufferRow = new VectorAggregationBufferRow(aggregationBuffers);
+ vectorAggregationBufferRows[key] = bufferRow;
+ }
+ batchBufferRows[logical] = bufferRow;
+ }
+
+ vecAggrExpr.aggregateInputSelection(
+ batchBufferRows,
+ 0,
+ batch);
+
+ rowIndex += batch.size;
+ }
+
+ String[] outputColumnNames = new String[] { "output" };
+
+ TypeInfo[] outputTypeInfos = new TypeInfo[] { outputTypeInfo };
+ VectorizedRowBatchCtx outputBatchContext =
+ new VectorizedRowBatchCtx(
+ outputColumnNames,
+ outputTypeInfos,
+ null,
+ /* dataColumnNums */ null,
+ /* partitionColumnCount */ 0,
+ /* virtualColumnCount */ 0,
+ /* neededVirtualColumns */ null,
+ new String[0],
+ new DataTypePhysicalVariation[0]);
+
+ VectorizedRowBatch outputBatch = outputBatchContext.createVectorizedRowBatch();
+
+ short[] keys = new short[VectorizedRowBatch.DEFAULT_SIZE];
+
+ VectorExtractRow resultVectorExtractRow = new VectorExtractRow();
+ resultVectorExtractRow.init(
+ new TypeInfo[] { outputTypeInfo }, new int[] { 0 });
+ Object[] scrqtchRow = new Object[1];
+
+ for (short key = 0; key < maxKeyCount + 1; key++) {
+ VectorAggregationBufferRow vectorAggregationBufferRow = vectorAggregationBufferRows[key];
+ if (vectorAggregationBufferRow != null) {
+ if (outputBatch.size == VectorizedRowBatch.DEFAULT_SIZE) {
+ extractResultObjects(outputBatch, keys, resultVectorExtractRow, outputTypeInfo,
+ scrqtchRow, results);
+ outputBatch.reset();
+ }
+ keys[outputBatch.size] = key;
+ VectorAggregateExpression.AggregationBuffer aggregationBuffer =
+ vectorAggregationBufferRow.getAggregationBuffer(0);
+ vecAggrExpr.assignRowColumn(outputBatch, outputBatch.size++, 0, aggregationBuffer);
+ }
+ }
+ if (outputBatch.size > 0) {
+ extractResultObjects(outputBatch, keys, resultVectorExtractRow, outputTypeInfo,
+ scrqtchRow, results);
+ }
+
+ return true;
+ }
+
+ private boolean compareObjects(Object object1, Object object2, TypeInfo typeInfo,
+ ObjectInspector objectInspector) {
+ if (typeInfo instanceof PrimitiveTypeInfo) {
+ return
+ VectorRandomRowSource.getWritablePrimitiveObject(
+ (PrimitiveTypeInfo) typeInfo, objectInspector, object1).equals(
+ VectorRandomRowSource.getWritablePrimitiveObject(
+ (PrimitiveTypeInfo) typeInfo, objectInspector, object2));
+ } else {
+ return object1.equals(object2);
+ }
+ }
+
+ protected void executeAggregationTests(String aggregationName, TypeInfo typeInfo,
+ GenericUDAFEvaluator evaluator,
+ TypeInfo outputTypeInfo, GenericUDAFEvaluator.Mode udafEvaluatorMode,
+ int maxKeyCount, List<String> columns, String[] columnNames,
+ List<ExprNodeDesc> parameters, Object[][] randomRows,
+ VectorRandomRowSource rowSource, VectorRandomBatchSource batchSource,
+ Object[] resultsArray)
+ throws Exception {
+
+ for (int i = 0; i < AggregationTestMode.count; i++) {
+
+ // Last entry is for a NULL key.
+ Object[] results = new Object[maxKeyCount + 1];
+ resultsArray[i] = results;
+
+ AggregationTestMode aggregationTestMode = AggregationTestMode.values()[i];
+ switch (aggregationTestMode) {
+ case ROW_MODE:
+ if (!doRowTest(
+ typeInfo,
+ evaluator,
+ outputTypeInfo,
+ udafEvaluatorMode,
+ maxKeyCount,
+ columns,
+ parameters,
+ randomRows,
+ rowSource.rowStructObjectInspector(),
+ results)) {
+ return;
+ }
+ break;
+ case VECTOR_EXPRESSION:
+ if (!doVectorTest(
+ aggregationName,
+ typeInfo,
+ evaluator,
+ outputTypeInfo,
+ udafEvaluatorMode,
+ maxKeyCount,
+ columns,
+ columnNames,
+ rowSource.typeInfos(),
+ rowSource.dataTypePhysicalVariations(),
+ parameters,
+ batchSource,
+ results)) {
+ return;
+ }
+ break;
+ default:
+ throw new RuntimeException(
+ "Unexpected Hash Aggregation test mode " + aggregationTestMode);
+ }
+ }
+ }
+
+ protected void verifyAggregationResults(TypeInfo typeInfo, TypeInfo outputTypeInfo,
+ int maxKeyCount, GenericUDAFEvaluator.Mode udafEvaluatorMode,
+ Object[] resultsArray) {
+
+ // Row-mode is the expected results.
+ Object[] expectedResults = (Object[]) resultsArray[0];
+
+ ObjectInspector objectInspector = TypeInfoUtils
+ .getStandardWritableObjectInspectorFromTypeInfo(outputTypeInfo);
+
+ for (int v = 1; v < AggregationTestMode.count; v++) {
+ Object[] vectorResults = (Object[]) resultsArray[v];
+
+ for (short key = 0; key < maxKeyCount + 1; key++) {
+ Object expectedResult = expectedResults[key];
+ Object vectorResult = vectorResults[key];
+ if (expectedResult == null || vectorResult == null) {
+ if (expectedResult != null || vectorResult != null) {
+ Assert.fail(
+ "Key " + key +
+ " typeName " + typeInfo.getTypeName() +
+ " outputTypeName " + outputTypeInfo.getTypeName() +
+ " " + AggregationTestMode.values()[v] +
+ " result is NULL " + (vectorResult == null ? "YES" : "NO result " + vectorResult.toString()) +
+ " does not match row-mode expected result is NULL " +
+ (expectedResult == null ? "YES" : "NO result " + expectedResult.toString()) +
+ " udafEvaluatorMode " + udafEvaluatorMode);
+ }
+ } else {
+ if (!compareObjects(expectedResult, vectorResult, outputTypeInfo, objectInspector)) {
+ Assert.fail(
+ "Key " + key +
+ " typeName " + typeInfo.getTypeName() +
+ " outputTypeName " + outputTypeInfo.getTypeName() +
+ " " + AggregationTestMode.values()[v] +
+ " result " + vectorResult.toString() +
+ " (" + vectorResult.getClass().getSimpleName() + ")" +
+ " does not match row-mode expected result " + expectedResult.toString() +
+ " (" + expectedResult.getClass().getSimpleName() + ")" +
+ " udafEvaluatorMode " + udafEvaluatorMode);
+ }
+ }
+ }
+ }
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/hive/blob/0966a383/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/aggregation/TestVectorAggregation.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/aggregation/TestVectorAggregation.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/aggregation/TestVectorAggregation.java
new file mode 100644
index 0000000..c5f0483
--- /dev/null
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/aggregation/TestVectorAggregation.java
@@ -0,0 +1,664 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.vector.aggregation;
+
+import java.lang.reflect.Constructor;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.sql.Timestamp;
+
+import org.apache.hadoop.hive.common.type.DataTypePhysicalVariation;
+import org.apache.hadoop.hive.ql.exec.vector.VectorRandomBatchSource;
+import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource;
+import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource.GenerationSpec;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
+import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFVariance;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableShortObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo;
+import org.apache.hadoop.hive.serde2.io.ShortWritable;
+
+import junit.framework.Assert;
+
+import org.junit.Ignore;
+import org.junit.Test;
+
+public class TestVectorAggregation extends AggregationBase {
+
+ @Test
+ public void testAvgIntegers() throws Exception {
+ Random random = new Random(7743);
+
+ doIntegerTests("avg", random);
+ }
+
+ @Test
+ public void testAvgFloating() throws Exception {
+ Random random = new Random(7743);
+
+ doFloatingTests("avg", random);
+ }
+
+ @Test
+ public void testAvgDecimal() throws Exception {
+ Random random = new Random(7743);
+
+ doDecimalTests("avg", random);
+ }
+
+ @Test
+ public void testAvgTimestamp() throws Exception {
+ Random random = new Random(7743);
+
+ doTests(
+ random, "avg", TypeInfoFactory.timestampTypeInfo);
+ }
+
+ @Test
+ public void testCount() throws Exception {
+ Random random = new Random(7743);
+
+ doTests(
+ random, "count", TypeInfoFactory.shortTypeInfo);
+ doTests(
+ random, "count", TypeInfoFactory.longTypeInfo);
+ doTests(
+ random, "count", TypeInfoFactory.doubleTypeInfo);
+ doTests(
+ random, "count", new DecimalTypeInfo(18, 10));
+ doTests(
+ random, "count", TypeInfoFactory.stringTypeInfo);
+ }
+
+ @Test
+ public void testMax() throws Exception {
+ Random random = new Random(7743);
+
+ doIntegerTests("max", random);
+ doFloatingTests("max", random);
+ doDecimalTests("max", random);
+
+ doTests(
+ random, "max", TypeInfoFactory.timestampTypeInfo);
+ doTests(
+ random, "max", TypeInfoFactory.intervalDayTimeTypeInfo);
+
+ doStringFamilyTests("max", random);
+ }
+
+ @Test
+ public void testMin() throws Exception {
+ Random random = new Random(7743);
+
+ doIntegerTests("min", random);
+ doFloatingTests("min", random);
+ doDecimalTests("min", random);
+
+ doTests(
+ random, "min", TypeInfoFactory.timestampTypeInfo);
+ doTests(
+ random, "min", TypeInfoFactory.intervalDayTimeTypeInfo);
+
+ doStringFamilyTests("min", random);
+ }
+
+ @Test
+ public void testSum() throws Exception {
+ Random random = new Random(7743);
+
+ doTests(
+ random, "sum", TypeInfoFactory.shortTypeInfo);
+ doTests(
+ random, "sum", TypeInfoFactory.longTypeInfo);
+ doTests(
+ random, "sum", TypeInfoFactory.doubleTypeInfo);
+
+ doDecimalTests("sum", random);
+ }
+
+ private static final Set<String> varianceNames =
+ GenericUDAFVariance.VarianceKind.nameMap.keySet();
+
+ @Test
+ public void testVarianceIntegers() throws Exception {
+ Random random = new Random(7743);
+
+ for (String aggregationName : varianceNames) {
+ doIntegerTests(aggregationName, random);
+ }
+ }
+
+ @Test
+ public void testVarianceFloating() throws Exception {
+ Random random = new Random(7743);
+
+ for (String aggregationName : varianceNames) {
+ doFloatingTests(aggregationName, random);
+ }
+ }
+
+ @Test
+ public void testVarianceDecimal() throws Exception {
+ Random random = new Random(7743);
+
+ for (String aggregationName : varianceNames) {
+ doDecimalTests(aggregationName, random);
+ }
+ }
+
+ private static TypeInfo[] integerTypeInfos = new TypeInfo[] {
+ TypeInfoFactory.byteTypeInfo,
+ TypeInfoFactory.shortTypeInfo,
+ TypeInfoFactory.intTypeInfo,
+ TypeInfoFactory.longTypeInfo
+ };
+
+ // We have test failures with FLOAT. Ignoring this issue for now.
+ private static TypeInfo[] floatingTypeInfos = new TypeInfo[] {
+ // TypeInfoFactory.floatTypeInfo,
+ TypeInfoFactory.doubleTypeInfo
+ };
+
+ private void doIntegerTests(String aggregationName, Random random)
+ throws Exception {
+ for (TypeInfo typeInfo : integerTypeInfos) {
+ doTests(
+ random, aggregationName, typeInfo);
+ }
+ }
+
+ private void doFloatingTests(String aggregationName, Random random)
+ throws Exception {
+ for (TypeInfo typeInfo : floatingTypeInfos) {
+ doTests(
+ random, aggregationName, typeInfo);
+ }
+ }
+
+ private static TypeInfo[] decimalTypeInfos = new TypeInfo[] {
+ new DecimalTypeInfo(38, 18),
+ new DecimalTypeInfo(25, 2),
+ new DecimalTypeInfo(19, 4),
+ new DecimalTypeInfo(18, 10),
+ new DecimalTypeInfo(17, 3),
+ new DecimalTypeInfo(12, 2),
+ new DecimalTypeInfo(7, 1)
+ };
+
+ private void doDecimalTests(String aggregationName, Random random)
+ throws Exception {
+ for (TypeInfo typeInfo : decimalTypeInfos) {
+ doTests(
+ random, aggregationName, typeInfo);
+ }
+ }
+
+ private static TypeInfo[] stringFamilyTypeInfos = new TypeInfo[] {
+ TypeInfoFactory.stringTypeInfo,
+ new CharTypeInfo(25),
+ new CharTypeInfo(10),
+ new VarcharTypeInfo(20),
+ new VarcharTypeInfo(15)
+ };
+
+ private void doStringFamilyTests(String aggregationName, Random random)
+ throws Exception {
+ for (TypeInfo typeInfo : stringFamilyTypeInfos) {
+ doTests(
+ random, aggregationName, typeInfo);
+ }
+ }
+
+ public static int getLinearRandomNumber(Random random, int maxSize) {
+ // Get a linearly weighted random number in [1, maxSize]: value k is drawn
+ // with probability proportional to (maxSize - k + 1), so small keys dominate.
+ int randomMultiplier = maxSize * (maxSize + 1) / 2;
+ int randomInt = random.nextInt(randomMultiplier);
+
+ // Linearly walk the weights to locate the drawn value.
+ int linearRandomNumber = 0;
+ for (int i = maxSize; randomInt >= 0; i--) {
+ randomInt -= i;
+ linearRandomNumber++;
+ }
+
+ return linearRandomNumber;
+ }
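
As an aside, the helper above samples a triangular distribution. A quick, hypothetical sanity harness (not part of the patch) illustrating the 4:3:2:1 weighting expected for maxSize = 4:

    Random random = new Random(7743);
    int[] counts = new int[5];
    for (int n = 0; n < 100000; n++) {
      counts[getLinearRandomNumber(random, 4)]++;
    }
    // Roughly: counts[1] ~ 40000, counts[2] ~ 30000,
    //          counts[3] ~ 20000, counts[4] ~ 10000.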
+
+ private static final int TEST_ROW_COUNT = 100000;
+
+ private void doMerge(GenericUDAFEvaluator.Mode mergeUdafEvaluatorMode,
+ Random random,
+ String aggregationName,
+ TypeInfo typeInfo,
+ GenerationSpec keyGenerationSpec,
+ List<String> columns, String[] columnNames,
+ int dataAggrMaxKeyCount, int reductionFactor,
+ TypeInfo partial1OutputTypeInfo,
+ Object[] partial1ResultsArray)
+ throws Exception {
+
+ List<GenerationSpec> mergeAggrGenerationSpecList = new ArrayList<GenerationSpec>();
+ List<DataTypePhysicalVariation> mergeDataTypePhysicalVariationList =
+ new ArrayList<DataTypePhysicalVariation>();
+
+ mergeAggrGenerationSpecList.add(keyGenerationSpec);
+ mergeDataTypePhysicalVariationList.add(DataTypePhysicalVariation.NONE);
+
+ // Use OMIT generation for both columns; the data is filled in below from the
+ // PARTIAL1 results (see the sketch after this method).
+ GenerationSpec mergeGenerationSpec =
+ GenerationSpec.createOmitGeneration(partial1OutputTypeInfo);
+ mergeAggrGenerationSpecList.add(mergeGenerationSpec);
+ mergeDataTypePhysicalVariationList.add(DataTypePhysicalVariation.NONE);
+
+ ExprNodeColumnDesc mergeCol1Expr =
+ new ExprNodeColumnDesc(partial1OutputTypeInfo, "col1", "table", false);
+ List<ExprNodeDesc> mergeParameters = new ArrayList<ExprNodeDesc>();
+ mergeParameters.add(mergeCol1Expr);
+ final int mergeParameterCount = mergeParameters.size();
+ ObjectInspector[] mergeParameterObjectInspectors =
+ new ObjectInspector[mergeParameterCount];
+ for (int i = 0; i < mergeParameterCount; i++) {
+ TypeInfo paramTypeInfo = mergeParameters.get(i).getTypeInfo();
+ mergeParameterObjectInspectors[i] = TypeInfoUtils
+ .getStandardWritableObjectInspectorFromTypeInfo(paramTypeInfo);
+ }
+
+ VectorRandomRowSource mergeRowSource = new VectorRandomRowSource();
+
+ mergeRowSource.initGenerationSpecSchema(
+ random, mergeAggrGenerationSpecList, /* maxComplexDepth */ 0, /* allowNull */ false,
+ mergeDataTypePhysicalVariationList);
+
+ Object[][] mergeRandomRows = mergeRowSource.randomRows(TEST_ROW_COUNT);
+
+ // Reduce the key range so each merge key (PARTIAL2 or FINAL) receives several partial results to combine.
+ final int mergeMaxKeyCount = dataAggrMaxKeyCount / reductionFactor;
+
+ Object[] partial1Results = (Object[]) partial1ResultsArray[0];
+
+ short partial1Key = 0;
+ for (int i = 0; i < mergeRandomRows.length; i++) {
+ // Find a non-NULL entry...
+ while (true) {
+ if (partial1Key >= dataAggrMaxKeyCount) {
+ partial1Key = 0;
+ }
+ if (partial1Results[partial1Key] != null) {
+ break;
+ }
+ partial1Key++;
+ }
+ final short mergeKey = (short) (partial1Key % mergeMaxKeyCount);
+ mergeRandomRows[i][0] = new ShortWritable(mergeKey);
+ mergeRandomRows[i][1] = partial1Results[partial1Key];
+ partial1Key++;
+ }
+
+ VectorRandomBatchSource mergeBatchSource =
+ VectorRandomBatchSource.createInterestingBatches(
+ random,
+ mergeRowSource,
+ mergeRandomRows,
+ null);
+
+ // We need to pass the original TypeInfo in for initializing the evaluator.
+ GenericUDAFEvaluator mergeEvaluator =
+ getEvaluator(aggregationName, typeInfo);
+
+ /*
+ System.out.println(
+ "*DEBUG* GenericUDAFEvaluator for " + aggregationName + ", " + typeInfo.getTypeName() + ": " +
+ mergeEvaluator.getClass().getSimpleName());
+ */
+
+ // The only way to get the return object inspector (and its return type) is to
+ // initialize it...
+
+ ObjectInspector mergeReturnOI =
+ mergeEvaluator.init(
+ mergeUdafEvaluatorMode,
+ mergeParameterObjectInspectors);
+ TypeInfo mergeOutputTypeInfo =
+ TypeInfoUtils.getTypeInfoFromObjectInspector(mergeReturnOI);
+
+ Object[] mergeResultsArray = new Object[AggregationTestMode.count];
+
+ executeAggregationTests(
+ aggregationName,
+ partial1OutputTypeInfo,
+ mergeEvaluator,
+ mergeOutputTypeInfo,
+ mergeUdafEvaluatorMode,
+ mergeMaxKeyCount,
+ columns,
+ columnNames,
+ mergeParameters,
+ mergeRandomRows,
+ mergeRowSource,
+ mergeBatchSource,
+ mergeResultsArray);
+
+ verifyAggregationResults(
+ partial1OutputTypeInfo,
+ mergeOutputTypeInfo,
+ mergeMaxKeyCount,
+ mergeUdafEvaluatorMode,
+ mergeResultsArray);
+ }
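
To make the merge-row construction above concrete: each merge row pairs a reduced key with one PARTIAL1 aggregation result, so several first-stage results fold onto each merge key. A minimal sketch using the defaults from doTests below (dataAggrMaxKeyCount = 20000, reductionFactor = 16):

    // mergeMaxKeyCount = 20000 / 16 = 1250, so about 16 distinct PARTIAL1
    // keys land on each merge key, guaranteeing real combining work per key.
    short partial1Key = 4321;
    short mergeKey = (short) (partial1Key % 1250);  // 571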
+
+ private void doTests(Random random, String aggregationName, TypeInfo typeInfo)
+ throws Exception {
+
+ List<GenerationSpec> dataAggrGenerationSpecList = new ArrayList<GenerationSpec>();
+ List<DataTypePhysicalVariation> explicitDataTypePhysicalVariationList =
+ new ArrayList<DataTypePhysicalVariation>();
+
+ TypeInfo keyTypeInfo = TypeInfoFactory.shortTypeInfo;
+ GenerationSpec keyGenerationSpec = GenerationSpec.createOmitGeneration(keyTypeInfo);
+ dataAggrGenerationSpecList.add(keyGenerationSpec);
+ explicitDataTypePhysicalVariationList.add(DataTypePhysicalVariation.NONE);
+
+ GenerationSpec generationSpec = GenerationSpec.createSameType(typeInfo);
+ dataAggrGenerationSpecList.add(generationSpec);
+ explicitDataTypePhysicalVariationList.add(DataTypePhysicalVariation.NONE);
+
+ List<String> columns = new ArrayList<String>();
+ columns.add("col0");
+ columns.add("col1");
+
+ ExprNodeColumnDesc dataAggrCol1Expr = new ExprNodeColumnDesc(typeInfo, "col1", "table", false);
+ List<ExprNodeDesc> dataAggrParameters = new ArrayList<ExprNodeDesc>();
+ dataAggrParameters.add(dataAggrCol1Expr);
+ final int dataAggrParameterCount = dataAggrParameters.size();
+ ObjectInspector[] dataAggrParameterObjectInspectors = new ObjectInspector[dataAggrParameterCount];
+ for (int i = 0; i < dataAggrParameterCount; i++) {
+ TypeInfo paramTypeInfo = dataAggrParameters.get(i).getTypeInfo();
+ dataAggrParameterObjectInspectors[i] = TypeInfoUtils
+ .getStandardWritableObjectInspectorFromTypeInfo(paramTypeInfo);
+ }
+
+ String[] columnNames = columns.toArray(new String[0]);
+
+ final int dataAggrMaxKeyCount = 20000;
+ final int reductionFactor = 16;
+
+ ObjectInspector keyObjectInspector = VectorRandomRowSource.getObjectInspector(keyTypeInfo);
+
+ /*
+ * PARTIAL1.
+ */
+
+ VectorRandomRowSource partial1RowSource = new VectorRandomRowSource();
+
+ partial1RowSource.initGenerationSpecSchema(
+ random, dataAggrGenerationSpecList, /* maxComplexDepth */ 0, /* allowNull */ true,
+ explicitDataTypePhysicalVariationList);
+
+ Object[][] partial1RandomRows = partial1RowSource.randomRows(TEST_ROW_COUNT);
+
+ final int partial1RowCount = partial1RandomRows.length;
+ for (int i = 0; i < partial1RowCount; i++) {
+ final short shortKey = (short) getLinearRandomNumber(random, dataAggrMaxKeyCount);
+ partial1RandomRows[i][0] =
+ ((WritableShortObjectInspector) keyObjectInspector).create(shortKey);
+ }
+
+ VectorRandomBatchSource partial1BatchSource =
+ VectorRandomBatchSource.createInterestingBatches(
+ random,
+ partial1RowSource,
+ partial1RandomRows,
+ null);
+
+ GenericUDAFEvaluator partial1Evaluator = getEvaluator(aggregationName, typeInfo);
+
+ /*
+ System.out.println(
+ "*DEBUG* GenericUDAFEvaluator for " + aggregationName + ", " + typeInfo.getTypeName() + ": " +
+ partial1Evaluator.getClass().getSimpleName());
+ */
+
+ // The only way to get the return object inspector (and its return type) is to
+ // initialize it...
+ final GenericUDAFEvaluator.Mode partial1UdafEvaluatorMode = GenericUDAFEvaluator.Mode.PARTIAL1;
+ ObjectInspector partial1ReturnOI =
+ partial1Evaluator.init(
+ partial1UdafEvaluatorMode,
+ dataAggrParameterObjectInspectors);
+ TypeInfo partial1OutputTypeInfo =
+ TypeInfoUtils.getTypeInfoFromObjectInspector(partial1ReturnOI);
+
+ Object[] partial1ResultsArray = new Object[AggregationTestMode.count];
+
+ executeAggregationTests(
+ aggregationName,
+ typeInfo,
+ partial1Evaluator,
+ partial1OutputTypeInfo,
+ partial1UdafEvaluatorMode,
+ dataAggrMaxKeyCount,
+ columns,
+ columnNames,
+ dataAggrParameters,
+ partial1RandomRows,
+ partial1RowSource,
+ partial1BatchSource,
+ partial1ResultsArray);
+
+ verifyAggregationResults(
+ typeInfo,
+ partial1OutputTypeInfo,
+ dataAggrMaxKeyCount,
+ partial1UdafEvaluatorMode,
+ partial1ResultsArray);
+
+ final boolean hasDifferentCompleteExpr;
+ if (varianceNames.contains(aggregationName)) {
+ hasDifferentCompleteExpr = true;
+ } else {
+ switch (aggregationName) {
+ case "avg":
+ /*
+ if (typeInfo instanceof DecimalTypeInfo) {
+ // UNDONE: Row-mode GenericUDAFAverage does not call enforcePrecisionScale...
+ hasDifferentCompleteExpr = false;
+ } else {
+ hasDifferentCompleteExpr = true;
+ }
+ */
+ hasDifferentCompleteExpr = true;
+ break;
+ case "count":
+ case "max":
+ case "min":
+ case "sum":
+ hasDifferentCompleteExpr = false;
+ break;
+ default:
+ throw new RuntimeException("Unexpected aggregation name " + aggregationName);
+ }
+ }
+
+ if (hasDifferentCompleteExpr) {
+
+ /*
+ * COMPLETE.
+ */
+
+ VectorRandomRowSource completeRowSource = new VectorRandomRowSource();
+
+ completeRowSource.initGenerationSpecSchema(
+ random, dataAggrGenerationSpecList, /* maxComplexDepth */ 0, /* allowNull */ true,
+ explicitDataTypePhysicalVariationList);
+
+ Object[][] completeRandomRows = completeRowSource.randomRows(TEST_ROW_COUNT);
+
+ final int completeRowCount = completeRandomRows.length;
+ for (int i = 0; i < completeRowCount; i++) {
+ final short shortKey = (short) getLinearRandomNumber(random, dataAggrMaxKeyCount);
+ completeRandomRows[i][0] =
+ ((WritableShortObjectInspector) keyObjectInspector).create(shortKey);
+ }
+
+ VectorRandomBatchSource completeBatchSource =
+ VectorRandomBatchSource.createInterestingBatches(
+ random,
+ completeRowSource,
+ completeRandomRows,
+ null);
+
+ GenericUDAFEvaluator completeEvaluator = getEvaluator(aggregationName, typeInfo);
+
+ /*
+ System.out.println(
+ "*DEBUG* GenericUDAFEvaluator for " + aggregationName + ", " + typeInfo.getTypeName() + ": " +
+ completeEvaluator.getClass().getSimpleName());
+ */
+
+ // The only way to get the return object inspector (and its return type) is to
+ // initialize it...
+ final GenericUDAFEvaluator.Mode completeUdafEvaluatorMode = GenericUDAFEvaluator.Mode.COMPLETE;
+ ObjectInspector completeReturnOI =
+ completeEvaluator.init(
+ completeUdafEvaluatorMode,
+ dataAggrParameterObjectInspectors);
+ TypeInfo completeOutputTypeInfo =
+ TypeInfoUtils.getTypeInfoFromObjectInspector(completeReturnOI);
+
+ Object[] completeResultsArray = new Object[AggregationTestMode.count];
+
+ executeAggregationTests(
+ aggregationName,
+ typeInfo,
+ completeEvaluator,
+ completeOutputTypeInfo,
+ completeUdafEvaluatorMode,
+ dataAggrMaxKeyCount,
+ columns,
+ columnNames,
+ dataAggrParameters,
+ completeRandomRows,
+ completeRowSource,
+ completeBatchSource,
+ completeResultsArray);
+
+ verifyAggregationResults(
+ typeInfo,
+ completeOutputTypeInfo,
+ dataAggrMaxKeyCount,
+ completeUdafEvaluatorMode,
+ completeResultsArray);
+ }
+
+ final boolean hasDifferentPartial2Expr;
+ if (varianceNames.contains(aggregationName)) {
+ hasDifferentPartial2Expr = true;
+ } else {
+ switch (aggregationName) {
+ case "avg":
+ hasDifferentPartial2Expr = true;
+ break;
+ case "count":
+ case "max":
+ case "min":
+ case "sum":
+ hasDifferentPartial2Expr = false;
+ break;
+ default:
+ throw new RuntimeException("Unexpected aggregation name " + aggregationName);
+ }
+ }
+
+ // The PARTIAL2 pass is intentionally short-circuited for now ("&& false").
+ if (hasDifferentPartial2Expr && false) {
+
+ /*
+ * PARTIAL2.
+ */
+
+ final GenericUDAFEvaluator.Mode mergeUdafEvaluatorMode = GenericUDAFEvaluator.Mode.PARTIAL2;
+
+ doMerge(
+ mergeUdafEvaluatorMode,
+ random,
+ aggregationName,
+ typeInfo,
+ keyGenerationSpec,
+ columns, columnNames,
+ dataAggrMaxKeyCount,
+ reductionFactor,
+ partial1OutputTypeInfo,
+ partial1ResultsArray);
+ }
+
+ final boolean hasDifferentFinalExpr;
+ if (varianceNames.contains(aggregationName)) {
+ hasDifferentFinalExpr = true;
+ } else {
+ switch (aggregationName) {
+ case "avg":
+ hasDifferentFinalExpr = true;
+ break;
+ case "count":
+ hasDifferentFinalExpr = true;
+ break;
+ case "max":
+ case "min":
+ case "sum":
+ hasDifferentFinalExpr = false;
+ break;
+ default:
+ throw new RuntimeException("Unexpected aggregation name " + aggregationName);
+ }
+ }
+ if (hasDifferentFinalExpr) {
+
+ /*
+ * FINAL.
+ */
+
+ final GenericUDAFEvaluator.Mode mergeUdafEvaluatorMode = GenericUDAFEvaluator.Mode.FINAL;
+
+ doMerge(
+ mergeUdafEvaluatorMode,
+ random,
+ aggregationName,
+ typeInfo,
+ keyGenerationSpec,
+ columns, columnNames,
+ dataAggrMaxKeyCount,
+ reductionFactor,
+ partial1OutputTypeInfo,
+ partial1ResultsArray);
+ }
+ }
+}
\ No newline at end of file
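
One idiom worth noting across the PARTIAL1, COMPLETE, and merge setups above: as the in-line comments say, a GenericUDAFEvaluator only reveals its return ObjectInspector (and hence its output type) once init() has been called for a concrete mode. Condensed from the code above (parameterObjectInspectors assumed built as in doTests):

    GenericUDAFEvaluator evaluator = getEvaluator("sum", TypeInfoFactory.longTypeInfo);
    ObjectInspector returnOI =
        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, parameterObjectInspectors);
    TypeInfo outputTypeInfo = TypeInfoUtils.getTypeInfoFromObjectInspector(returnOI);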
http://git-wip-us.apache.org/repos/asf/hive/blob/0966a383/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateAddSub.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateAddSub.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateAddSub.java
index f5deca5..c4146be 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateAddSub.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateAddSub.java
@@ -370,6 +370,7 @@ public class TestVectorDateAddSub {
Object[][] randomRows, ColumnScalarMode columnScalarMode,
ObjectInspector rowInspector, Object[] resultObjects) throws Exception {
+ /*
System.out.println(
"*DEBUG* dateTimeStringTypeInfo " + dateTimeStringTypeInfo.toString() +
" integerTypeInfo " + integerTypeInfo +
@@ -377,6 +378,7 @@ public class TestVectorDateAddSub {
" dateAddSubTestMode ROW_MODE" +
" columnScalarMode " + columnScalarMode +
" exprDesc " + exprDesc.toString());
+ */
HiveConf hiveConf = new HiveConf();
ExprNodeEvaluator evaluator =
http://git-wip-us.apache.org/repos/asf/hive/blob/0966a383/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateDiff.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateDiff.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateDiff.java
index dce7ccf..b382c2a 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateDiff.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateDiff.java
@@ -362,12 +362,14 @@ public class TestVectorDateDiff {
Object[][] randomRows, ColumnScalarMode columnScalarMode,
ObjectInspector rowInspector, Object[] resultObjects) throws Exception {
+ /*
System.out.println(
"*DEBUG* dateTimeStringTypeInfo " + dateTimeStringTypeInfo1.toString() +
" dateTimeStringTypeInfo2 " + dateTimeStringTypeInfo2 +
" dateDiffTestMode ROW_MODE" +
" columnScalarMode " + columnScalarMode +
" exprDesc " + exprDesc.toString());
+ */
HiveConf hiveConf = new HiveConf();
ExprNodeEvaluator evaluator =