You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hive.apache.org by ji...@apache.org on 2014/02/19 18:43:21 UTC
svn commit: r1569850 [2/3] - in /hive/trunk:
ant/src/org/apache/hadoop/hive/ant/
common/src/java/org/apache/hadoop/hive/common/type/
common/src/java/org/apache/hive/common/util/
ql/src/gen/vectorization/UDAFTemplates/ ql/src/java/org/apache/hadoop/hive...
Added: hive/trunk/ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt?rev=1569850&view=auto
==============================================================================
--- hive/trunk/ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt (added)
+++ hive/trunk/ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt Wed Feb 19 17:43:20 2014
@@ -0,0 +1,472 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.hive.common.type.Decimal128;
+import org.apache.hadoop.hive.common.type.HiveDecimal;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
+import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.plan.AggregationDesc;
+import org.apache.hadoop.hive.ql.util.JavaDataModel;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+/**
+* <ClassName>. Vectorized implementation for VARIANCE aggregates.
+*/
+@Description(name = "<DescriptionName>",
+ value = "<DescriptionValue>")
+public class <ClassName> extends VectorAggregateExpression {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+  * Per-group running state of the variance aggregate: the decimal sum of the
+  * values seen so far, their count, and the accumulated variance term.
+  */
+ private static final class Aggregation implements AggregationBuffer {
+
+   private static final long serialVersionUID = 1L;
+
+   transient private final Decimal128 sum;
+   transient private long count;
+   transient private double variance;
+
+   /**
+    * True while no value has been accumulated; the enclosing reset() sets it
+    * back to true and init() clears it.
+    */
+   transient private boolean isNull = true;
+
+   public Aggregation() {
+     sum = new Decimal128();
+   }
+
+   /** Marks the buffer as non-null and zeroes sum, count and variance. */
+   public void init() {
+     variance = 0.0;
+     count = 0L;
+     sum.zeroClear();
+     isNull = false;
+   }
+
+   @Override
+   public int getVariableSize() {
+     throw new UnsupportedOperationException();
+   }
+
+   /**
+    * Adds one value, initializing the buffer first when it is still null.
+    * The variance term is only updated from the second value onwards.
+    *
+    * @param scratch caller-owned temporary, overwritten by this call
+    * @param value the value to accumulate
+    */
+   public void updateValueWithCheckAndInit(Decimal128 scratch, Decimal128 value) {
+     if (isNull) {
+       init();
+     }
+     sum.addDestructive(value, value.getScale());
+     count += 1;
+     if (count <= 1) {
+       return;
+     }
+     accumulateVariance(scratch, value);
+   }
+
+   /**
+    * Adds one value without the null check and without the count check; the
+    * caller must guarantee that at least one value was already accumulated.
+    *
+    * @param scratch caller-owned temporary, overwritten by this call
+    * @param value the value to accumulate
+    */
+   public void updateValueNoCheck(Decimal128 scratch, Decimal128 value) {
+     sum.addDestructive(value, value.getScale());
+     count += 1;
+     accumulateVariance(scratch, value);
+   }
+
+   /** Folds (count * value - sum)^2 / (count * (count - 1)) into the variance. */
+   private void accumulateVariance(Decimal128 scratch, Decimal128 value) {
+     scratch.update(count);
+     scratch.multiplyDestructive(value, value.getScale());
+     scratch.subtractDestructive(sum, sum.getScale());
+     double t = scratch.doubleValue();
+     variance += (t * t) / ((double) count * (count - 1));
+   }
+
+ }
+
+ private VectorExpression inputExpression;
+ transient private LongWritable resultCount;
+ transient private DoubleWritable resultSum;
+ transient private DoubleWritable resultVariance;
+ transient private Object[] partialResult;
+
+ transient private ObjectInspector soi;
+
+ transient private final Decimal128 scratchDecimal;
+
+
+ /**
+  * Creates the aggregate over the given input expression.
+  *
+  * @param inputExpression expression whose output column supplies the values to aggregate
+  */
+ public <ClassName>(VectorExpression inputExpression) {
+   this();
+   this.inputExpression = inputExpression;
+ }
+
+ /**
+  * Allocates the reusable output writables, the scratch decimal and the
+  * struct inspector describing the (count, sum, variance) partial result.
+  */
+ public <ClassName>() {
+   super();
+   scratchDecimal = new Decimal128();
+   resultCount = new LongWritable();
+   resultSum = new DoubleWritable();
+   resultVariance = new DoubleWritable();
+   // Slot order must match the field order declared in initPartialResultInspector().
+   partialResult = new Object[] {resultCount, resultSum, resultVariance};
+   initPartialResultInspector();
+ }
+
+ /**
+  * Builds the struct object inspector for the partial result:
+  * count (long), sum (double), variance (double).
+  */
+ private void initPartialResultInspector() {
+   List<String> fieldNames = new ArrayList<String>();
+   List<ObjectInspector> fieldInspectors = new ArrayList<ObjectInspector>();
+
+   fieldNames.add("count");
+   fieldInspectors.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+
+   fieldNames.add("sum");
+   fieldInspectors.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+
+   fieldNames.add("variance");
+   fieldInspectors.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+
+   soi = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldInspectors);
+ }
+
+ /**
+  * Returns this aggregate's buffer for one row of the batch.
+  *
+  * @param aggregationBufferSets one buffer row per batch position
+  * @param aggregateIndex position of this aggregate inside each buffer row
+  * @param row batch position to look up
+  */
+ private Aggregation getCurrentAggregationBuffer(
+     VectorAggregationBufferRow[] aggregationBufferSets,
+     int aggregateIndex,
+     int row) {
+   return (Aggregation) aggregationBufferSets[row].getAggregationBuffer(aggregateIndex);
+ }
+
+
+ /**
+  * Evaluates the input expression on the batch and feeds every row's value
+  * into that row's own aggregation buffer, dispatching on the
+  * repeating / selection / nulls shape of the input column.
+  */
+ @Override
+ public void aggregateInputSelection(
+     VectorAggregationBufferRow[] aggregationBufferSets,
+     int aggregateIndex,
+     VectorizedRowBatch batch) throws HiveException {
+
+   inputExpression.evaluate(batch);
+
+   DecimalColumnVector inputVector =
+       (DecimalColumnVector) batch.cols[this.inputExpression.getOutputColumn()];
+
+   final int batchSize = batch.size;
+   if (batchSize == 0) {
+     return;
+   }
+
+   final Decimal128[] vector = inputVector.vector;
+
+   if (inputVector.isRepeating) {
+     // A repeating column carries its single value (and null flag) in slot 0.
+     if (inputVector.noNulls || !inputVector.isNull[0]) {
+       iterateRepeatingNoNullsWithAggregationSelection(
+           aggregationBufferSets, aggregateIndex, vector[0], batchSize);
+     }
+     return;
+   }
+
+   if (batch.selectedInUse) {
+     if (inputVector.noNulls) {
+       iterateSelectionNoNullsWithAggregationSelection(
+           aggregationBufferSets, aggregateIndex, vector, batchSize, batch.selected);
+     } else {
+       iterateSelectionHasNullsWithAggregationSelection(
+           aggregationBufferSets, aggregateIndex, vector, batchSize,
+           inputVector.isNull, batch.selected);
+     }
+   } else {
+     if (inputVector.noNulls) {
+       iterateNoSelectionNoNullsWithAggregationSelection(
+           aggregationBufferSets, aggregateIndex, vector, batchSize);
+     } else {
+       iterateNoSelectionHasNullsWithAggregationSelection(
+           aggregationBufferSets, aggregateIndex, vector, batchSize, inputVector.isNull);
+     }
+   }
+ }
+
+ /**
+  * Adds the same (repeated) value once to the buffer of every batch row.
+  */
+ private void iterateRepeatingNoNullsWithAggregationSelection(
+     VectorAggregationBufferRow[] aggregationBufferSets,
+     int aggregateIndex,
+     Decimal128 value,
+     int batchSize) {
+
+   int row = 0;
+   while (row < batchSize) {
+     Aggregation rowAgg =
+         getCurrentAggregationBuffer(aggregationBufferSets, aggregateIndex, row);
+     rowAgg.updateValueWithCheckAndInit(scratchDecimal, value);
+     ++row;
+   }
+ }
+
+ /**
+  * Per-row aggregation when a selection vector is in use and the column may
+  * contain nulls. Buffers are indexed by batch position, values by the
+  * selected row index; null values are skipped.
+  */
+ private void iterateSelectionHasNullsWithAggregationSelection(
+     VectorAggregationBufferRow[] aggregationBufferSets,
+     int aggregateIndex,
+     Decimal128[] vector,
+     int batchSize,
+     boolean[] isNull,
+     int[] selected) {
+
+   for (int pos = 0; pos < batchSize; ++pos) {
+     Aggregation rowAgg =
+         getCurrentAggregationBuffer(aggregationBufferSets, aggregateIndex, pos);
+     int row = selected[pos];
+     if (isNull[row]) {
+       continue;
+     }
+     rowAgg.updateValueWithCheckAndInit(scratchDecimal, vector[row]);
+   }
+ }
+
+ /**
+  * Per-row aggregation when a selection vector is in use and the column has
+  * no nulls. Buffers are indexed by batch position, values by the selected
+  * row index.
+  */
+ private void iterateSelectionNoNullsWithAggregationSelection(
+     VectorAggregationBufferRow[] aggregationBufferSets,
+     int aggregateIndex,
+     Decimal128[] vector,
+     int batchSize,
+     int[] selected) {
+
+   for (int pos = 0; pos < batchSize; ++pos) {
+     Aggregation rowAgg =
+         getCurrentAggregationBuffer(aggregationBufferSets, aggregateIndex, pos);
+     rowAgg.updateValueWithCheckAndInit(scratchDecimal, vector[selected[pos]]);
+   }
+ }
+
+ /**
+  * Per-row aggregation without a selection vector when the column may
+  * contain nulls; rows whose value is null are skipped entirely.
+  */
+ private void iterateNoSelectionHasNullsWithAggregationSelection(
+     VectorAggregationBufferRow[] aggregationBufferSets,
+     int aggregateIndex,
+     Decimal128[] vector,
+     int batchSize,
+     boolean[] isNull) {
+
+   for (int row = 0; row < batchSize; ++row) {
+     if (isNull[row]) {
+       continue;
+     }
+     Aggregation rowAgg =
+         getCurrentAggregationBuffer(aggregationBufferSets, aggregateIndex, row);
+     rowAgg.updateValueWithCheckAndInit(scratchDecimal, vector[row]);
+   }
+ }
+
+ /**
+  * Per-row aggregation without a selection vector when the column has no
+  * nulls: row i's value goes into row i's buffer.
+  */
+ private void iterateNoSelectionNoNullsWithAggregationSelection(
+     VectorAggregationBufferRow[] aggregationBufferSets,
+     int aggregateIndex,
+     Decimal128[] vector,
+     int batchSize) {
+
+   for (int row = 0; row < batchSize; ++row) {
+     Aggregation rowAgg =
+         getCurrentAggregationBuffer(aggregationBufferSets, aggregateIndex, row);
+     rowAgg.updateValueWithCheckAndInit(scratchDecimal, vector[row]);
+   }
+ }
+
+ /**
+  * Evaluates the input expression on the batch and folds every non-null
+  * value of the batch into the single given aggregation buffer.
+  *
+  * @param agg the buffer to update; must be an Aggregation created by this class
+  * @param batch the batch to aggregate
+  * @throws HiveException if evaluating the input expression fails
+  */
+ @Override
+ public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch)
+     throws HiveException {
+
+   inputExpression.evaluate(batch);
+
+   DecimalColumnVector inputVector = (DecimalColumnVector)batch.
+       cols[this.inputExpression.getOutputColumn()];
+
+   int batchSize = batch.size;
+
+   if (batchSize == 0) {
+     return;
+   }
+
+   Aggregation myagg = (Aggregation)agg;
+
+   Decimal128[] vector = inputVector.vector;
+
+   if (inputVector.isRepeating) {
+     // noNulls == false only means the column may hold nulls; the repeated
+     // value is null only if isNull[0] says so. Checking noNulls alone would
+     // silently drop a repeated non-null value (aggregateInputSelection
+     // applies the same test).
+     if (inputVector.noNulls || !inputVector.isNull[0]) {
+       iterateRepeatingNoNulls(myagg, vector[0], batchSize);
+     }
+   }
+   else if (!batch.selectedInUse && inputVector.noNulls) {
+     iterateNoSelectionNoNulls(myagg, vector, batchSize);
+   }
+   else if (!batch.selectedInUse) {
+     iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull);
+   }
+   else if (inputVector.noNulls){
+     iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected);
+   }
+   else {
+     iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected);
+   }
+ }
+
+ /**
+  * Adds the same repeated value batchSize times to one buffer.
+  */
+ private void iterateRepeatingNoNulls(
+     Aggregation myagg,
+     Decimal128 value,
+     int batchSize) {
+
+   // TODO: conjure a formula w/o iterating
+
+   // The first addition goes through the checked path (it may have to
+   // initialize the buffer); every later one is guaranteed count > 1.
+   myagg.updateValueWithCheckAndInit(scratchDecimal, value);
+
+   int remaining = batchSize - 1;
+   while (remaining > 0) {
+     myagg.updateValueNoCheck(scratchDecimal, value);
+     --remaining;
+   }
+ }
+
+ /**
+  * Folds the selected rows into one buffer, skipping null values.
+  */
+ private void iterateSelectionHasNulls(
+     Aggregation myagg,
+     Decimal128[] vector,
+     int batchSize,
+     boolean[] isNull,
+     int[] selected) {
+
+   for (int pos = 0; pos < batchSize; ++pos) {
+     int row = selected[pos];
+     if (isNull[row]) {
+       continue;
+     }
+     myagg.updateValueWithCheckAndInit(scratchDecimal, vector[row]);
+   }
+ }
+
+ /**
+  * Folds the selected rows into one buffer when the column has no nulls.
+  */
+ private void iterateSelectionNoNulls(
+     Aggregation myagg,
+     Decimal128[] vector,
+     int batchSize,
+     int[] selected) {
+
+   // The checked update initializes a still-null buffer itself, so the first
+   // selected row goes through it; from then on count > 1 is guaranteed and
+   // the unchecked update can be used.
+   myagg.updateValueWithCheckAndInit(scratchDecimal, vector[selected[0]]);
+
+   for (int pos = 1; pos < batchSize; ++pos) {
+     myagg.updateValueNoCheck(scratchDecimal, vector[selected[pos]]);
+   }
+ }
+
+ /**
+  * Folds rows 0..batchSize-1 into one buffer, skipping null values.
+  */
+ private void iterateNoSelectionHasNulls(
+     Aggregation myagg,
+     Decimal128[] vector,
+     int batchSize,
+     boolean[] isNull) {
+
+   for (int row = 0; row < batchSize; ++row) {
+     if (isNull[row]) {
+       continue;
+     }
+     myagg.updateValueWithCheckAndInit(scratchDecimal, vector[row]);
+   }
+ }
+
+ /**
+  * Folds rows 0..batchSize-1 into one buffer when the column has no nulls.
+  */
+ private void iterateNoSelectionNoNulls(
+     Aggregation myagg,
+     Decimal128[] vector,
+     int batchSize) {
+
+   // Row 0 goes through the checked update, which initializes a still-null
+   // buffer; afterwards count > 1 always holds, so the check is not needed.
+   myagg.updateValueWithCheckAndInit(scratchDecimal, vector[0]);
+
+   for (int row = 1; row < batchSize; ++row) {
+     myagg.updateValueNoCheck(scratchDecimal, vector[row]);
+   }
+ }
+
+ @Override
+ public AggregationBuffer getNewAggregationBuffer() throws HiveException {
+ return new Aggregation();
+ }
+
+ /**
+  * Puts the buffer back into the null state; sum, count and variance are
+  * re-zeroed lazily by the next checked update.
+  */
+ @Override
+ public void reset(AggregationBuffer agg) throws HiveException {
+   ((Aggregation) agg).isNull = true;
+ }
+
+ @Override
+ public Object evaluateOutput(
+ AggregationBuffer agg) throws HiveException {
+ Aggregation myagg = (Aggregation) agg;
+ if (myagg.isNull) {
+ return null;
+ }
+ else {
+ assert(0 < myagg.count);
+ resultCount.set (myagg.count);
+ resultSum.set (myagg.sum.doubleValue());
+ resultVariance.set (myagg.variance);
+ return partialResult;
+ }
+ }
+ /** Returns the struct inspector built by initPartialResultInspector(). */
+ @Override
+ public ObjectInspector getOutputObjectInspector() {
+   return this.soi;
+ }
+
+ @Override
+ public int getAggregationBufferFixedSize() {
+ JavaDataModel model = JavaDataModel.get();
+ return JavaDataModel.alignUp(
+ model.object() +
+ model.primitive2()*3+
+ model.primitive1(),
+ model.memoryAlign());
+ }
+
+ /** Nothing to configure from the aggregation descriptor. */
+ @Override
+ public void init(AggregationDesc desc) throws HiveException {
+   // intentionally empty
+ }
+
+ /** Returns the expression whose output column is aggregated. */
+ public VectorExpression getInputExpression() {
+   return this.inputExpression;
+ }
+
+ /**
+  * Replaces the expression whose output column is aggregated.
+  *
+  * @param inputExpression the new input expression
+  */
+ public void setInputExpression(VectorExpression inputExpression) {
+   this.inputExpression = inputExpression;
+ }
+}
\ No newline at end of file
Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorColumnAssignFactory.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorColumnAssignFactory.java?rev=1569850&r1=1569849&r2=1569850&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorColumnAssignFactory.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorColumnAssignFactory.java Wed Feb 19 17:43:20 2014
@@ -23,9 +23,12 @@ import java.util.Arrays;
import java.util.List;
import java.util.Map;
+import org.apache.hadoop.hive.common.type.Decimal128;
+import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.io.ByteWritable;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.io.ShortWritable;
import org.apache.hadoop.hive.serde2.io.TimestampWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -141,6 +144,23 @@ public class VectorColumnAssignFactory {
}
}
+ /**
+  * Base assigner for decimal output columns, with one assignDecimal overload
+  * per source representation of the decimal value.
+  */
+ private static abstract class VectorDecimalColumnAssign
+     extends VectorColumnAssignVectorBase<DecimalColumnVector> {
+
+   /** Copies a HiveDecimal (unscaled value + scale) into the output slot. */
+   protected void assignDecimal(HiveDecimal value, int index) {
+     Decimal128 target = outCol.vector[index];
+     target.update(value.unscaledValue(), (byte) value.scale());
+   }
+
+   /** Copies another Decimal128 into the output slot. */
+   protected void assignDecimal(Decimal128 value, int index) {
+     Decimal128 target = outCol.vector[index];
+     target.update(value);
+   }
+
+   /** Copies a writable decimal into the output slot from its internal byte storage. */
+   protected void assignDecimal(HiveDecimalWritable hdw, int index) {
+     byte[] unscaledBytes = hdw.getInternalStorage();
+     short scale = (short) hdw.getScale();
+
+     outCol.vector[index].fastUpdateFromInternalStorage(unscaledBytes, scale);
+   }
+ }
+
public static VectorColumnAssign[] buildAssigners(VectorizedRowBatch outputBatch)
throws HiveException {
@@ -175,6 +195,14 @@ public class VectorColumnAssignFactory {
}
}.init(outputBatch, (BytesColumnVector) cv);
}
+ else if (cv instanceof DecimalColumnVector) {
+   // Like the sibling branches (e.g. the BytesColumnVector one just above),
+   // the assigner has to be bound to its batch and output column via init();
+   // otherwise copyValue() would run against an unbound outCol.
+   vca[i] = new VectorDecimalColumnAssign() {
+     @Override
+     protected void copyValue(DecimalColumnVector src, int srcIndex, int destIndex) {
+       assignDecimal(src.vector[srcIndex], destIndex);
+     }
+   }.init(outputBatch, (DecimalColumnVector) cv);
+ }
else {
throw new HiveException("Unimplemented vector column type: " + cv.getClass().getName());
}
@@ -336,6 +364,27 @@ public class VectorColumnAssignFactory {
poi.getPrimitiveCategory());
}
}
+ else if (destCol instanceof DecimalColumnVector) {
+ switch(poi.getPrimitiveCategory()) {
+ case DECIMAL:
+ outVCA = new VectorDecimalColumnAssign() {
+ @Override
+ public void assignObjectValue(Object val, int destIndex) throws HiveException {
+ if (val == null) {
+ assignNull(destIndex);
+ }
+ else {
+ HiveDecimalWritable hdw = (HiveDecimalWritable) val;
+ assignDecimal(hdw, destIndex);
+ }
+ }
+ }.init(outputBatch, (DecimalColumnVector) destCol);
+ break;
+ default:
+ throw new HiveException("Incompatible Decimal vector column and primitive category " +
+ poi.getPrimitiveCategory());
+ }
+ }
else {
throw new HiveException("Unknown vector column type " + destCol.getClass().getName());
}
Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExpressionDescriptor.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExpressionDescriptor.java?rev=1569850&r1=1569849&r2=1569850&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExpressionDescriptor.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExpressionDescriptor.java Wed Feb 19 17:43:20 2014
@@ -49,7 +49,7 @@ public class VectorExpressionDescriptor
public static ArgumentType getType(String inType) {
String type = VectorizationContext.getNormalizedTypeName(inType);
- if (VectorizationContext.decimalTypePattern.matcher(type.toLowerCase()).matches()) {
+ if (VectorizationContext.decimalTypePattern.matcher(type).matches()) {
type = "decimal";
}
return valueOf(type.toUpperCase());
Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapper.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapper.java?rev=1569850&r1=1569849&r2=1569850&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapper.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapper.java Wed Feb 19 17:43:20 2014
@@ -20,6 +20,7 @@ package org.apache.hadoop.hive.ql.exec.v
import java.util.Arrays;
+import org.apache.hadoop.hive.common.type.Decimal128;
import org.apache.hadoop.hive.ql.exec.KeyWrapper;
import org.apache.hadoop.hive.ql.exec.vector.expressions.StringExpr;
import org.apache.hadoop.hive.ql.metadata.HiveException;
@@ -42,16 +43,23 @@ public class VectorHashKeyWrapper extend
private int[] byteStarts;
private int[] byteLengths;
+ private Decimal128[] decimalValues;
+
private boolean[] isNull;
private int hashcode;
- public VectorHashKeyWrapper(int longValuesCount, int doubleValuesCount, int byteValuesCount) {
+ public VectorHashKeyWrapper(int longValuesCount, int doubleValuesCount,
+ int byteValuesCount, int decimalValuesCount) {
longValues = new long[longValuesCount];
doubleValues = new double[doubleValuesCount];
+ decimalValues = new Decimal128[decimalValuesCount];
+ for(int i = 0; i < decimalValuesCount; ++i) {
+ decimalValues[i] = new Decimal128();
+ }
byteValues = new byte[byteValuesCount][];
byteStarts = new int[byteValuesCount];
byteLengths = new int[byteValuesCount];
- isNull = new boolean[longValuesCount + doubleValuesCount + byteValuesCount];
+ isNull = new boolean[longValuesCount + doubleValuesCount + byteValuesCount + decimalValuesCount];
}
private VectorHashKeyWrapper() {
@@ -66,6 +74,7 @@ public class VectorHashKeyWrapper extend
public void setHashKey() {
hashcode = Arrays.hashCode(longValues) ^
Arrays.hashCode(doubleValues) ^
+ Arrays.hashCode(decimalValues) ^
Arrays.hashCode(isNull);
// This code, with branches and all, is not executed if there are no string keys
@@ -104,6 +113,7 @@ public class VectorHashKeyWrapper extend
return hashcode == keyThat.hashcode &&
Arrays.equals(longValues, keyThat.longValues) &&
Arrays.equals(doubleValues, keyThat.doubleValues) &&
+ Arrays.equals(decimalValues, keyThat.decimalValues) &&
Arrays.equals(isNull, keyThat.isNull) &&
byteValues.length == keyThat.byteValues.length &&
(0 == byteValues.length || bytesEquals(keyThat));
@@ -137,6 +147,12 @@ public class VectorHashKeyWrapper extend
clone.doubleValues = doubleValues.clone();
clone.isNull = isNull.clone();
+ // Decimal128 requires deep clone
+ clone.decimalValues = new Decimal128[decimalValues.length];
+ for(int i = 0; i < decimalValues.length; ++i) {
+ clone.decimalValues[i] = new Decimal128().update(decimalValues[i]);
+ }
+
clone.byteValues = new byte[byteValues.length][];
clone.byteStarts = new int[byteValues.length];
clone.byteLengths = byteLengths.clone();
@@ -201,13 +217,22 @@ public class VectorHashKeyWrapper extend
isNull[longValues.length + doubleValues.length + index] = true;
}
+ /**
+  * Stores a non-null decimal key value in the given decimal slot.
+  *
+  * @param index position among the decimal keys
+  * @param value value to copy into the pre-allocated Decimal128 of that slot
+  */
+ public void assignDecimal(int index, Decimal128 value) {
+   decimalValues[index].update(value);
+   // This wrapper is mutable and reused, so a null flag left over from an
+   // earlier assignNullDecimal() on the same slot must be cleared here.
+   isNull[longValues.length + doubleValues.length + byteValues.length + index] = false;
+ }
+
+ /**
+  * Marks the given decimal key slot as NULL.
+  *
+  * @param index position among the decimal keys
+  */
+ public void assignNullDecimal(int index) {
+   // decimalValues takes part in setHashKey() and equals(); zero the slot so
+   // that a stale value from a previous use cannot make two NULL keys differ.
+   decimalValues[index].zeroClear();
+   isNull[longValues.length + doubleValues.length + byteValues.length + index] = true;
+ }
+
@Override
public String toString()
{
- return String.format("%d[%s] %d[%s] %d[%s]",
+ return String.format("%d[%s] %d[%s] %d[%s] %d[%s]",
longValues.length, Arrays.toString(longValues),
doubleValues.length, Arrays.toString(doubleValues),
- byteValues.length, Arrays.toString(byteValues));
+ byteValues.length, Arrays.toString(byteValues),
+ decimalValues.length, Arrays.toString(decimalValues));
}
public boolean getIsLongNull(int i) {
@@ -222,7 +247,7 @@ public class VectorHashKeyWrapper extend
return isNull[longValues.length + doubleValues.length + i];
}
-
+
public long getLongValue(int i) {
return longValues[i];
}
@@ -252,6 +277,12 @@ public class VectorHashKeyWrapper extend
return variableSize;
}
+ /**
+  * Tells whether the i-th decimal key is NULL. Decimal null flags are stored
+  * after those of the long, double and byte keys.
+  */
+ public boolean getIsDecimalNull(int i) {
+   final int decimalBase = longValues.length + doubleValues.length + byteValues.length;
+   return isNull[decimalBase + i];
+ }
+ /** Returns the i-th decimal key (the internal instance, not a copy). */
+ public Decimal128 getDecimal(int i) {
+   return this.decimalValues[i];
+ }
}
Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java?rev=1569850&r1=1569849&r2=1569850&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java Wed Feb 19 17:43:20 2014
@@ -40,9 +40,40 @@ public class VectorHashKeyWrapperBatch {
private int longIndex;
private int doubleIndex;
private int stringIndex;
+ private int decimalIndex;
+
+ private static final int INDEX_UNUSED = -1;
+
+ /** Marks all four per-type indices as unused. */
+ private void resetIndices() {
+   this.decimalIndex = INDEX_UNUSED;
+   this.stringIndex = INDEX_UNUSED;
+   this.doubleIndex = INDEX_UNUSED;
+   this.longIndex = INDEX_UNUSED;
+ }
+ /** Makes this helper point at a long key; every other index becomes unused. */
+ public void setLong(int index) {
+   resetIndices();
+   longIndex = index;
+ }
+
+ /** Makes this helper point at a double key; every other index becomes unused. */
+ public void setDouble(int index) {
+   resetIndices();
+   doubleIndex = index;
+ }
+
+ /** Makes this helper point at a string key; every other index becomes unused. */
+ public void setString(int index) {
+   resetIndices();
+   stringIndex = index;
+ }
+
+ /** Makes this helper point at a decimal key; every other index becomes unused. */
+ public void setDecimal(int index) {
+   resetIndices();
+   decimalIndex = index;
+ }
}
/**
+ * Number of object references in 'this' (for size computation)
+ */
+ private static final int MODEL_REFERENCES_COUNT = 7;
+
+ /**
* The key expressions that require evaluation and output the primitive values for each key.
*/
private VectorExpression[] keyExpressions;
@@ -63,6 +94,11 @@ public class VectorHashKeyWrapperBatch {
private int[] stringIndices;
/**
+ * indices of decimal primitive keys.
+ */
+ private int[] decimalIndices;
+
+ /**
* Pre-allocated batch size vector of keys wrappers.
* N.B. these keys are **mutable** and should never be used in a HashMap.
* Always clone the key wrapper to obtain an immutable keywrapper suitable
@@ -175,6 +211,28 @@ public class VectorHashKeyWrapperBatch {
columnVector.noNulls, columnVector.isRepeating, batch.selectedInUse));
}
}
+ for(int i=0;i<decimalIndices.length; ++i) {
+ int keyIndex = decimalIndices[i];
+ int columnIndex = keyExpressions[keyIndex].getOutputColumn();
+ DecimalColumnVector columnVector = (DecimalColumnVector) batch.cols[columnIndex];
+ if (columnVector.noNulls && !columnVector.isRepeating && !batch.selectedInUse) {
+ assignDecimalNoNullsNoRepeatingNoSelection(i, batch.size, columnVector);
+ } else if (columnVector.noNulls && !columnVector.isRepeating && batch.selectedInUse) {
+ assignDecimalNoNullsNoRepeatingSelection(i, batch.size, columnVector, batch.selected);
+ } else if (columnVector.noNulls && columnVector.isRepeating) {
+ assignDecimalNoNullsRepeating(i, batch.size, columnVector);
+ } else if (!columnVector.noNulls && !columnVector.isRepeating && !batch.selectedInUse) {
+ assignDecimalNullsNoRepeatingNoSelection(i, batch.size, columnVector);
+ } else if (!columnVector.noNulls && columnVector.isRepeating) {
+ assignDecimalNullsRepeating(i, batch.size, columnVector);
+ } else if (!columnVector.noNulls && !columnVector.isRepeating && batch.selectedInUse) {
+ assignDecimalNullsNoRepeatingSelection (i, batch.size, columnVector, batch.selected);
+ } else {
+ throw new HiveException (String.format(
+ "Unimplemented Decimal null/repeat/selected combination %b/%b/%b",
+ columnVector.noNulls, columnVector.isRepeating, batch.selectedInUse));
+ }
+ }
for(int i=0;i<batch.size;++i) {
vectorHashKeyWrappers[i].setHashKey();
}
@@ -427,6 +485,80 @@ public class VectorHashKeyWrapperBatch {
}
/**
+ * Helper method to assign values from a vector column into the key wrapper.
+ * Optimized for Decimal type, possible nulls, no repeat values, batch selection vector.
+ */
+ private void assignDecimalNullsNoRepeatingSelection(int index, int size,
+ DecimalColumnVector columnVector, int[] selected) {
+ for(int i = 0; i < size; ++i) {
+ int row = selected[i];
+ if (!columnVector.isNull[row]) {
+ vectorHashKeyWrappers[i].assignDecimal(index, columnVector.vector[row]);
+ } else {
+ vectorHashKeyWrappers[i].assignNullDecimal(index);
+ }
+ }
+ }
+
+ /**
+  * Helper method to assign values from a vector column into the key wrapper.
+  * Optimized for Decimal type, possible nulls, repeat values.
+  * The single repeated entry lives in slot 0: when isNull[0] is set every key
+  * becomes NULL, otherwise every key receives vector[0].
+  */
+ private void assignDecimalNullsRepeating(int index, int size,
+     DecimalColumnVector columnVector) {
+   // noNulls == false only says nulls MAY be present; the repeated value is
+   // NULL only when isNull[0] is actually set.
+   if (columnVector.isNull[0]) {
+     for(int r = 0; r < size; ++r) {
+       vectorHashKeyWrappers[r].assignNullDecimal(index);
+     }
+   } else {
+     for(int r = 0; r < size; ++r) {
+       vectorHashKeyWrappers[r].assignDecimal(index, columnVector.vector[0]);
+     }
+   }
+ }
+
+ /**
+  * Helper method to assign values from a vector column into the key wrapper.
+  * Optimized for Decimal type, possible nulls, no repeat values, no selection vector.
+  */
+ private void assignDecimalNullsNoRepeatingNoSelection(int index, int size,
+     DecimalColumnVector columnVector) {
+   for (int row = 0; row < size; ++row) {
+     if (columnVector.isNull[row]) {
+       vectorHashKeyWrappers[row].assignNullDecimal(index);
+     } else {
+       vectorHashKeyWrappers[row].assignDecimal(index, columnVector.vector[row]);
+     }
+   }
+ }
+
+ /**
+  * Helper method to assign values from a vector column into the key wrapper.
+  * Optimized for Decimal type, no nulls, repeat values: every key wrapper
+  * receives the value stored in slot 0 of the column.
+  */
+ private void assignDecimalNoNullsRepeating(int index, int size, DecimalColumnVector columnVector) {
+   int row = 0;
+   while (row < size) {
+     vectorHashKeyWrappers[row].assignDecimal(index, columnVector.vector[0]);
+     ++row;
+   }
+ }
+
+ /**
+ * Helper method to assign values from a vector column into the key wrapper.
+ * Optimized for Decimal type, no nulls, no repeat values, batch selection vector.
+ */
+ private void assignDecimalNoNullsNoRepeatingSelection(int index, int size,
+ DecimalColumnVector columnVector, int[] selected) {
+ for(int r = 0; r < size; ++r) {
+ vectorHashKeyWrappers[r].assignDecimal(index, columnVector.vector[selected[r]]);
+ }
+ }
+
+ /**
+  * Helper method to assign values from a vector column into the key wrapper.
+  * Optimized for Decimal type, no nulls, no repeat values, no selection vector:
+  * the key at position r takes the value of row r.
+  */
+ private void assignDecimalNoNullsNoRepeatingNoSelection(int index, int size,
+     DecimalColumnVector columnVector) {
+   int row = 0;
+   while (row < size) {
+     vectorHashKeyWrappers[row].assignDecimal(index, columnVector.vector[row]);
+     ++row;
+   }
+ }
+
+ /**
* Prepares a VectorHashKeyWrapperBatch to work for a specific set of keys.
* Computes the fast access lookup indices, preallocates all needed internal arrays.
* This step is done only once per query, not once per batch. The information computed now
@@ -446,6 +578,8 @@ public class VectorHashKeyWrapperBatch {
int doubleIndicesIndex = 0;
int[] stringIndices = new int[keyExpressions.length];
int stringIndicesIndex = 0;
+ int[] decimalIndices = new int[keyExpressions.length];
+ int decimalIndicesIndex = 0;
KeyLookupHelper[] indexLookup = new KeyLookupHelper[keyExpressions.length];
// Inspect the output type of each key expression.
@@ -455,22 +589,20 @@ public class VectorHashKeyWrapperBatch {
if (VectorizationContext.isIntFamily(outputType) ||
VectorizationContext.isDatetimeFamily(outputType)) {
longIndices[longIndicesIndex] = i;
- indexLookup[i].longIndex = longIndicesIndex;
- indexLookup[i].doubleIndex = -1;
- indexLookup[i].stringIndex = -1;
+ indexLookup[i].setLong(longIndicesIndex);
++longIndicesIndex;
} else if (VectorizationContext.isFloatFamily(outputType)) {
doubleIndices[doubleIndicesIndex] = i;
- indexLookup[i].longIndex = -1;
- indexLookup[i].doubleIndex = doubleIndicesIndex;
- indexLookup[i].stringIndex = -1;
+ indexLookup[i].setDouble(doubleIndicesIndex);
++doubleIndicesIndex;
} else if (VectorizationContext.isStringFamily(outputType)) {
stringIndices[stringIndicesIndex]= i;
- indexLookup[i].longIndex = -1;
- indexLookup[i].doubleIndex = -1;
- indexLookup[i].stringIndex = stringIndicesIndex;
+ indexLookup[i].setString(stringIndicesIndex);
++stringIndicesIndex;
+ } else if (VectorizationContext.isDecimalFamily(outputType)) {
+ decimalIndices[decimalIndicesIndex]= i;
+ indexLookup[i].setDecimal(decimalIndicesIndex);
+ ++decimalIndicesIndex;
}
else {
throw new HiveException("Unsuported vector output type: " + outputType);
@@ -480,11 +612,13 @@ public class VectorHashKeyWrapperBatch {
compiledKeyWrapperBatch.longIndices = Arrays.copyOf(longIndices, longIndicesIndex);
compiledKeyWrapperBatch.doubleIndices = Arrays.copyOf(doubleIndices, doubleIndicesIndex);
compiledKeyWrapperBatch.stringIndices = Arrays.copyOf(stringIndices, stringIndicesIndex);
+ compiledKeyWrapperBatch.decimalIndices = Arrays.copyOf(decimalIndices, decimalIndicesIndex);
compiledKeyWrapperBatch.vectorHashKeyWrappers =
new VectorHashKeyWrapper[VectorizedRowBatch.DEFAULT_SIZE];
for(int i=0;i<VectorizedRowBatch.DEFAULT_SIZE; ++i) {
compiledKeyWrapperBatch.vectorHashKeyWrappers[i] =
- new VectorHashKeyWrapper(longIndicesIndex, doubleIndicesIndex, stringIndicesIndex);
+ new VectorHashKeyWrapper(longIndicesIndex, doubleIndicesIndex,
+ stringIndicesIndex, decimalIndicesIndex);
}
JavaDataModel model = JavaDataModel.get();
@@ -493,7 +627,7 @@ public class VectorHashKeyWrapperBatch {
// start with the keywrapper itself
compiledKeyWrapperBatch.keysFixedSize += JavaDataModel.alignUp(
model.object() +
- model.ref() * 6+
+ model.ref() * MODEL_REFERENCES_COUNT +
model.primitive1(),
model.memoryAlign());
@@ -501,6 +635,7 @@ public class VectorHashKeyWrapperBatch {
compiledKeyWrapperBatch.keysFixedSize += model.lengthForLongArrayOfSize(longIndicesIndex);
compiledKeyWrapperBatch.keysFixedSize += model.lengthForDoubleArrayOfSize(doubleIndicesIndex);
compiledKeyWrapperBatch.keysFixedSize += model.lengthForObjectArrayOfSize(stringIndicesIndex);
+ compiledKeyWrapperBatch.keysFixedSize += model.lengthForObjectArrayOfSize(decimalIndicesIndex);
compiledKeyWrapperBatch.keysFixedSize += model.lengthForIntArrayOfSize(longIndicesIndex) * 2;
compiledKeyWrapperBatch.keysFixedSize +=
model.lengthForBooleanArrayOfSize(keyExpressions.length);
@@ -529,7 +664,12 @@ public class VectorHashKeyWrapperBatch {
kw.getBytes(klh.stringIndex),
kw.getByteStart(klh.stringIndex),
kw.getByteLength(klh.stringIndex));
- } else {
+ } else if (klh.decimalIndex >= 0) {
+ return kw.getIsDecimalNull(klh.decimalIndex)? null :
+ keyOutputWriter.writeValue(
+ kw.getDecimal(klh.decimalIndex));
+ }
+ else {
throw new HiveException(String.format(
"Internal inconsistent KeyLookupHelper at index [%d]:%d %d %d",
i, klh.longIndex, klh.doubleIndex, klh.stringIndex));
Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorMapJoinOperator.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorMapJoinOperator.java?rev=1569850&r1=1569849&r2=1569850&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorMapJoinOperator.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorMapJoinOperator.java Wed Feb 19 17:43:20 2014
@@ -63,8 +63,8 @@ public class VectorMapJoinOperator exten
private int tagLen;
private VectorExpression[] keyExpressions;
- private VectorHashKeyWrapperBatch keyWrapperBatch;
- private VectorExpressionWriter[] keyOutputWriters;
+ private transient VectorHashKeyWrapperBatch keyWrapperBatch;
+ private transient VectorExpressionWriter[] keyOutputWriters;
private VectorExpression[] bigTableFilterExpressions;
private VectorExpression[] bigTableValueExpressions;
@@ -111,7 +111,6 @@ public class VectorMapJoinOperator exten
List<ExprNodeDesc> keyDesc = desc.getKeys().get(posBigTable);
keyExpressions = vContext.getVectorExpressions(keyDesc);
- keyOutputWriters = VectorExpressionWriterFactory.getExpressionWriters(keyDesc);
// We're only going to evaluate the big table vectorized expressions,
Map<Byte, List<ExprNodeDesc>> exprs = desc.getExprs();
@@ -135,6 +134,8 @@ public class VectorMapJoinOperator exten
public void initializeOp(Configuration hconf) throws HiveException {
super.initializeOp(hconf);
+ List<ExprNodeDesc> keyDesc = conf.getKeys().get(posBigTable);
+ keyOutputWriters = VectorExpressionWriterFactory.getExpressionWriters(keyDesc);
vrbCtx = new VectorizedRowBatchCtx();
vrbCtx.init(hconf, this.fileKey, (StructObjectInspector) this.outputObjInspector);
Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java?rev=1569850&r1=1569849&r2=1569850&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java Wed Feb 19 17:43:20 2014
@@ -40,28 +40,37 @@ import org.apache.hadoop.hive.ql.exec.Fu
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.exec.vector.TimestampUtils;
+import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor.ArgumentType;
import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor.InputExpressionType;
import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor.Mode;
import org.apache.hadoop.hive.ql.exec.vector.expressions.*;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFAvgDecimal;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFCount;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFCountStar;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFSumDecimal;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFAvgDouble;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFAvgLong;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMaxDecimal;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMaxDouble;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMaxLong;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMaxString;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMinDecimal;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMinDouble;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMinLong;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMinString;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdPopDecimal;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdPopDouble;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdPopLong;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdSampDecimal;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdSampDouble;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdSampLong;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFSumDouble;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFSumLong;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarPopDecimal;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarPopDouble;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarPopLong;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarSampDecimal;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarSampDouble;
import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarSampLong;
import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.CastLongToBooleanViaLongToLong;
@@ -114,7 +123,8 @@ public class VectorizationContext {
private final Map<String, Integer> columnMap;
private final int firstOutputColumnIndex;
- public static final Pattern decimalTypePattern = Pattern.compile("decimal.*");
+ public static final Pattern decimalTypePattern = Pattern.compile("decimal.*",
+ Pattern.CASE_INSENSITIVE);
//Map column number to type
private final OutputColumnManager ocm;
@@ -869,7 +879,7 @@ public class VectorizationContext {
ExprNodeDesc colExpr = childExpr.get(0);
TypeInfo colTypeInfo = colExpr.getTypeInfo();
String colType = colExpr.getTypeString();
-
+
// prepare arguments for createVectorExpression
List<ExprNodeDesc> childrenForInList =
foldConstantsForUnaryExprs(childExpr.subList(1, childExpr.size()));
@@ -1111,7 +1121,7 @@ public class VectorizationContext {
String colType = colExpr.getTypeString();
// prepare arguments for createVectorExpression
- List<ExprNodeDesc> childrenAfterNot = foldConstantsForUnaryExprs(childExpr.subList(1, 4));
+ List<ExprNodeDesc> childrenAfterNot = foldConstantsForUnaryExprs(childExpr.subList(1, 4));;
// determine class
Class<?> cl = null;
@@ -1241,6 +1251,10 @@ public class VectorizationContext {
|| resultType.equalsIgnoreCase("long");
}
+ /**
+  * Tests whether a type name denotes a decimal type.
+  *
+  * @param colType type name to test; it must match {@code decimalTypePattern} in full
+  * @return {@code true} when the whole name matches the decimal pattern
+  */
+ public static boolean isDecimalFamily(String colType) {
+   final boolean matchesDecimal = decimalTypePattern.matcher(colType).matches();
+   return matchesDecimal;
+ }
+
private Object getScalarValue(ExprNodeConstantDesc constDesc)
throws HiveException {
if (constDesc.getTypeString().equalsIgnoreCase("String")) {
@@ -1353,14 +1367,13 @@ public class VectorizationContext {
}
}
- static String getNormalizedTypeName(String colType) {
+ static String getNormalizedTypeName(String colType){
String normalizedType = null;
if (colType.equalsIgnoreCase("Double") || colType.equalsIgnoreCase("Float")) {
normalizedType = "Double";
} else if (colType.equalsIgnoreCase("String")) {
normalizedType = "String";
- } else if (decimalTypePattern.matcher(colType.toLowerCase()).matches()) {
-
+ } else if (decimalTypePattern.matcher(colType).matches()) {
//Return the decimal type as is, it includes scale and precision.
normalizedType = colType;
} else {
@@ -1373,31 +1386,43 @@ public class VectorizationContext {
{"min", "Long", VectorUDAFMinLong.class},
{"min", "Double", VectorUDAFMinDouble.class},
{"min", "String", VectorUDAFMinString.class},
+ {"min", "Decimal",VectorUDAFMinDecimal.class},
{"max", "Long", VectorUDAFMaxLong.class},
{"max", "Double", VectorUDAFMaxDouble.class},
{"max", "String", VectorUDAFMaxString.class},
+ {"max", "Decimal",VectorUDAFMaxDecimal.class},
{"count", null, VectorUDAFCountStar.class},
{"count", "Long", VectorUDAFCount.class},
{"count", "Double", VectorUDAFCount.class},
{"count", "String", VectorUDAFCount.class},
+ {"count", "Decimal",VectorUDAFCount.class},
{"sum", "Long", VectorUDAFSumLong.class},
{"sum", "Double", VectorUDAFSumDouble.class},
+ {"sum", "Decimal",VectorUDAFSumDecimal.class},
{"avg", "Long", VectorUDAFAvgLong.class},
{"avg", "Double", VectorUDAFAvgDouble.class},
+ {"avg", "Decimal",VectorUDAFAvgDecimal.class},
{"variance", "Long", VectorUDAFVarPopLong.class},
{"var_pop", "Long", VectorUDAFVarPopLong.class},
{"variance", "Double", VectorUDAFVarPopDouble.class},
{"var_pop", "Double", VectorUDAFVarPopDouble.class},
+ {"variance", "Decimal",VectorUDAFVarPopDecimal.class},
+ {"var_pop", "Decimal",VectorUDAFVarPopDecimal.class},
{"var_samp", "Long", VectorUDAFVarSampLong.class},
{"var_samp" , "Double", VectorUDAFVarSampDouble.class},
+ {"var_samp" , "Decimal",VectorUDAFVarSampDecimal.class},
{"std", "Long", VectorUDAFStdPopLong.class},
{"stddev", "Long", VectorUDAFStdPopLong.class},
{"stddev_pop","Long", VectorUDAFStdPopLong.class},
{"std", "Double", VectorUDAFStdPopDouble.class},
{"stddev", "Double", VectorUDAFStdPopDouble.class},
{"stddev_pop","Double", VectorUDAFStdPopDouble.class},
+ {"std", "Decimal",VectorUDAFStdPopDecimal.class},
+ {"stddev", "Decimal",VectorUDAFStdPopDecimal.class},
+ {"stddev_pop","Decimal",VectorUDAFStdPopDecimal.class},
{"stddev_samp","Long", VectorUDAFStdSampLong.class},
{"stddev_samp","Double",VectorUDAFStdSampDouble.class},
+ {"stddev_samp","Decimal",VectorUDAFStdSampDecimal.class},
};
public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc)
@@ -1417,6 +1442,9 @@ public class VectorizationContext {
if (paramDescList.size() > 0) {
ExprNodeDesc inputExpr = paramDescList.get(0);
inputType = getNormalizedTypeName(inputExpr.getTypeString());
+ if (decimalTypePattern.matcher(inputType).matches()) {
+ inputType = "Decimal";
+ }
}
for (Object[] aggDef : aggregatesDefinition) {
Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java?rev=1569850&r1=1569849&r2=1569850&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java Wed Feb 19 17:43:20 2014
@@ -394,7 +394,7 @@ public class VectorizedRowBatchCtx {
return new DoubleColumnVector(defaultSize);
} else if (type.equalsIgnoreCase("string")) {
return new BytesColumnVector(defaultSize);
- } else if (VectorizationContext.decimalTypePattern.matcher(type.toLowerCase()).matches()){
+ } else if (VectorizationContext.decimalTypePattern.matcher(type).matches()){
int [] precisionScale = getScalePrecisionFromDecimalType(type);
return new DecimalColumnVector(defaultSize, precisionScale[0], precisionScale[1]);
} else {
Added: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFAvgDecimal.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFAvgDecimal.java?rev=1569850&view=auto
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFAvgDecimal.java (added)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFAvgDecimal.java Wed Feb 19 17:43:20 2014
@@ -0,0 +1,516 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.hive.common.type.Decimal128;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
+import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.plan.AggregationDesc;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage;
+import org.apache.hadoop.hive.ql.util.JavaDataModel;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
+import org.apache.hive.common.util.Decimal128FastBuffer;
+
+/**
+ * Generated from template VectorUDAFAvg.txt.
+ */
+@Description(name = "avg",
+ value = "_FUNC_(AVG) - Returns the average value of expr (vectorized, type: decimal)")
+public class VectorUDAFAvgDecimal extends VectorAggregateExpression {
+
+ private static final long serialVersionUID = 1L;
+
+ /**
+  * Running state of one AVG aggregate: the accumulated sum, the number of values
+  * folded into it, and a flag marking a buffer that holds no value yet.
+  */
+ static class Aggregation implements AggregationBuffer {
+
+   private static final long serialVersionUID = 1L;
+
+   transient private final Decimal128 sum = new Decimal128();
+   transient private long count;
+   transient private boolean isNull;
+
+   /**
+    * Folds {@code value} into the buffer. When the buffer is still flagged null the
+    * sum is overwritten with the value and the count restarts at one.
+    *
+    * @param value the decimal to accumulate
+    * @param scale the scale applied to the running sum
+    */
+   public void sumValueWithCheck(Decimal128 value, short scale) {
+     if (!isNull) {
+       sum.addDestructive(value, scale);
+       count++;
+       return;
+     }
+     // First value seen since the buffer was marked null: seed instead of adding.
+     sum.update(value);
+     sum.changeScaleDestructive(scale);
+     count = 1;
+     isNull = false;
+   }
+
+   /**
+    * Adds {@code value} to the sum and bumps the count without looking at the null
+    * flag; callers clear the flag themselves before using this.
+    *
+    * @param value the decimal to accumulate
+    * @param scale the scale applied to the running sum
+    */
+   public void sumValueNoCheck(Decimal128 value, short scale) {
+     sum.addDestructive(value, scale);
+     count++;
+   }
+
+   /**
+    * Not supported by this buffer.
+    *
+    * @throws UnsupportedOperationException always
+    */
+   @Override
+   public int getVariableSize() {
+     throw new UnsupportedOperationException();
+   }
+ }
+
+ // Expression evaluated on each batch; its output column is the column being averaged.
+ private VectorExpression inputExpression;
+ // Reused two-slot output row: slot 0 is resultCount, slot 1 is resultSum.
+ transient private Object[] partialResult;
+ // Reused writable receiving the count in evaluateOutput.
+ transient private LongWritable resultCount;
+ // Reused writable receiving the serialized sum in evaluateOutput.
+ transient private HiveDecimalWritable resultSum;
+ // Struct inspector for the partial result; built by initPartialResultInspector.
+ transient private StructObjectInspector soi;
+
+ // Scratch buffer the sum is serialized into before it is copied to resultSum.
+ transient private final Decimal128FastBuffer scratch;
+
+ /**
+ * The scale of the SUM in the partial output
+ */
+ private short sumScale;
+
+ /**
+ * The precision of the SUM in the partial output
+ */
+ private short sumPrecision;
+
+ /**
+ * the scale of the input expression
+ */
+ private short inputScale;
+
+ /**
+ * the precision of the input expression
+ */
+ private short inputPrecision;
+
+ /**
+ * A value used as scratch to avoid allocating at runtime.
+ * Needed by computations like vector[0] * batchSize
+ */
+ transient private Decimal128 scratchDecimal = new Decimal128();
+
+ /**
+  * Creates the aggregate over the given input expression, after the no-arg
+  * constructor has allocated the reusable output objects.
+  *
+  * @param inputExpression expression whose output column is aggregated
+  */
+ public VectorUDAFAvgDecimal(VectorExpression inputExpression) {
+ this();
+ this.inputExpression = inputExpression;
+ }
+
+ /**
+  * Allocates the reusable objects behind the partial result: the count writable
+  * (slot 0) and the sum writable (slot 1) of the output row, plus the scratch
+  * buffer used when serializing the sum.
+  */
+ public VectorUDAFAvgDecimal() {
+   super();
+   scratch = new Decimal128FastBuffer();
+   resultCount = new LongWritable();
+   resultSum = new HiveDecimalWritable();
+   partialResult = new Object[] {resultCount, resultSum};
+ }
+
+ /**
+  * Builds the struct inspector describing the partial result as the fields
+  * {@code count} (writable long) and {@code sum} (writable decimal), and records
+  * the scale and precision of the sum type.
+  *
+  * <p>The sum type is taken from {@code GenericUDAFAverage.deriveSumTypeInfo}
+  * because the vectorized partial output must match the type expected by the
+  * row-mode aggregation.
+  */
+ private void initPartialResultInspector() {
+   DecimalTypeInfo sumTypeInfo = GenericUDAFAverage.deriveSumTypeInfo(inputScale, inputPrecision);
+   this.sumScale = (short) sumTypeInfo.scale();
+   this.sumPrecision = (short) sumTypeInfo.precision();
+
+   List<String> fieldNames = new ArrayList<String>();
+   fieldNames.add("count");
+   fieldNames.add("sum");
+
+   List<ObjectInspector> fieldInspectors = new ArrayList<ObjectInspector>();
+   fieldInspectors.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+   fieldInspectors.add(
+       PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(sumTypeInfo));
+
+   soi = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldInspectors);
+ }
+
+ /**
+  * Looks up this aggregate's buffer for one batch position.
+  *
+  * @param aggregationBufferSets buffer rows, indexed by batch position
+  * @param bufferIndex index of this aggregate's buffer inside each row
+  * @param row batch position whose buffer is wanted
+  * @return the buffer, cast to {@link Aggregation}
+  */
+ private Aggregation getCurrentAggregationBuffer(
+     VectorAggregationBufferRow[] aggregationBufferSets,
+     int bufferIndex,
+     int row) {
+   return (Aggregation) aggregationBufferSets[row].getAggregationBuffer(bufferIndex);
+ }
+
+ /**
+  * Aggregates a batch where every batch position has its own aggregation buffer.
+  * Evaluates the input expression, then dispatches to the loop specialized for
+  * the vector's null / repeating / selection state. An empty batch is a no-op.
+  *
+  * @param aggregationBufferSets buffer rows, indexed by batch position
+  * @param bufferIndex index of this aggregate's buffer inside each row
+  * @param batch the batch to aggregate
+  * @throws HiveException declared by the overridden method
+  */
+ @Override
+ public void aggregateInputSelection(
+     VectorAggregationBufferRow[] aggregationBufferSets,
+     int bufferIndex,
+     VectorizedRowBatch batch) throws HiveException {
+
+   final int batchSize = batch.size;
+   if (batchSize == 0) {
+     return;
+   }
+
+   inputExpression.evaluate(batch);
+
+   DecimalColumnVector inputVector =
+       (DecimalColumnVector) batch.cols[this.inputExpression.getOutputColumn()];
+   Decimal128[] vector = inputVector.vector;
+   final boolean repeating = inputVector.isRepeating;
+   final boolean selectedInUse = batch.selectedInUse;
+
+   if (inputVector.noNulls) {
+     if (repeating) {
+       iterateNoNullsRepeatingWithAggregationSelection(
+           aggregationBufferSets, bufferIndex,
+           vector[0], batchSize);
+     } else if (selectedInUse) {
+       iterateNoNullsSelectionWithAggregationSelection(
+           aggregationBufferSets, bufferIndex,
+           vector, batch.selected, batchSize);
+     } else {
+       iterateNoNullsWithAggregationSelection(
+           aggregationBufferSets, bufferIndex,
+           vector, batchSize);
+     }
+     return;
+   }
+
+   // From here on the vector may contain nulls.
+   if (repeating) {
+     if (selectedInUse) {
+       iterateHasNullsRepeatingSelectionWithAggregationSelection(
+           aggregationBufferSets, bufferIndex,
+           vector[0], batchSize, batch.selected, inputVector.isNull);
+     } else {
+       iterateHasNullsRepeatingWithAggregationSelection(
+           aggregationBufferSets, bufferIndex,
+           vector[0], batchSize, inputVector.isNull);
+     }
+     return;
+   }
+
+   if (selectedInUse) {
+     iterateHasNullsSelectionWithAggregationSelection(
+         aggregationBufferSets, bufferIndex,
+         vector, batchSize, batch.selected, inputVector.isNull);
+   } else {
+     iterateHasNullsWithAggregationSelection(
+         aggregationBufferSets, bufferIndex,
+         vector, batchSize, inputVector.isNull);
+   }
+ }
+
+ private void iterateNoNullsRepeatingWithAggregationSelection(
+ VectorAggregationBufferRow[] aggregationBufferSets,
+ int bufferIndex,
+ Decimal128 value,
+ int batchSize) {
+
+ for (int i=0; i < batchSize; ++i) {
+ Aggregation myagg = getCurrentAggregationBuffer(
+ aggregationBufferSets,
+ bufferIndex,
+ i);
+ myagg.sumValueWithCheck(value, this.sumScale);
+ }
+ }
+
+ private void iterateNoNullsSelectionWithAggregationSelection(
+ VectorAggregationBufferRow[] aggregationBufferSets,
+ int bufferIndex,
+ Decimal128[] values,
+ int[] selection,
+ int batchSize) {
+
+ for (int i=0; i < batchSize; ++i) {
+ Aggregation myagg = getCurrentAggregationBuffer(
+ aggregationBufferSets,
+ bufferIndex,
+ i);
+ myagg.sumValueWithCheck(values[selection[i]], this.sumScale);
+ }
+ }
+
+ private void iterateNoNullsWithAggregationSelection(
+ VectorAggregationBufferRow[] aggregationBufferSets,
+ int bufferIndex,
+ Decimal128[] values,
+ int batchSize) {
+ for (int i=0; i < batchSize; ++i) {
+ Aggregation myagg = getCurrentAggregationBuffer(
+ aggregationBufferSets,
+ bufferIndex,
+ i);
+ myagg.sumValueWithCheck(values[i], this.sumScale);
+ }
+ }
+
+ /**
+  * Handles a repeating input vector that may be null, with a selection vector in
+  * use: adds the repeated {@code value} to the buffer of every batch position
+  * unless the repeated entry is null.
+  *
+  * <p>For a repeating vector only entry 0 carries the value and its null flag, so
+  * the null test reads {@code isNull[0]}. Indexing {@code isNull} through the
+  * selection vector would read entries that are not maintained for repeating
+  * vectors and may be stale.
+  *
+  * @param aggregationBufferSets buffer rows, indexed by batch position
+  * @param bufferIndex index of this aggregate's buffer inside each row
+  * @param value the repeated input value (entry 0 of the vector)
+  * @param batchSize number of positions in the batch
+  * @param selection selection vector of the batch; not needed for a repeating
+  *     input, kept so the signature stays unchanged
+  * @param isNull null flags of the input vector
+  */
+ private void iterateHasNullsRepeatingSelectionWithAggregationSelection(
+     VectorAggregationBufferRow[] aggregationBufferSets,
+     int bufferIndex,
+     Decimal128 value,
+     int batchSize,
+     int[] selection,
+     boolean[] isNull) {
+
+   if (isNull[0]) {
+     // The whole batch repeats a null: nothing to aggregate.
+     return;
+   }
+
+   for (int i = 0; i < batchSize; ++i) {
+     Aggregation myagg = getCurrentAggregationBuffer(
+         aggregationBufferSets,
+         bufferIndex,
+         i);
+     myagg.sumValueWithCheck(value, this.sumScale);
+   }
+ }
+
+ /**
+  * Handles a repeating input vector that may be null, without a selection
+  * vector: adds the repeated {@code value} to the buffer of every batch position
+  * unless the repeated entry is null.
+  *
+  * <p>For a repeating vector only entry 0 carries the value and its null flag, so
+  * the null test reads {@code isNull[0]} rather than {@code isNull[i]}, whose
+  * entries past 0 are not maintained for repeating vectors and may be stale.
+  *
+  * @param aggregationBufferSets buffer rows, indexed by batch position
+  * @param bufferIndex index of this aggregate's buffer inside each row
+  * @param value the repeated input value (entry 0 of the vector)
+  * @param batchSize number of positions in the batch
+  * @param isNull null flags of the input vector
+  */
+ private void iterateHasNullsRepeatingWithAggregationSelection(
+     VectorAggregationBufferRow[] aggregationBufferSets,
+     int bufferIndex,
+     Decimal128 value,
+     int batchSize,
+     boolean[] isNull) {
+
+   if (isNull[0]) {
+     // The whole batch repeats a null: nothing to aggregate.
+     return;
+   }
+
+   for (int i = 0; i < batchSize; ++i) {
+     Aggregation myagg = getCurrentAggregationBuffer(
+         aggregationBufferSets,
+         bufferIndex,
+         i);
+     myagg.sumValueWithCheck(value, this.sumScale);
+   }
+ }
+
+ private void iterateHasNullsSelectionWithAggregationSelection(
+ VectorAggregationBufferRow[] aggregationBufferSets,
+ int bufferIndex,
+ Decimal128[] values,
+ int batchSize,
+ int[] selection,
+ boolean[] isNull) {
+
+ for (int j=0; j < batchSize; ++j) {
+ int i = selection[j];
+ if (!isNull[i]) {
+ Aggregation myagg = getCurrentAggregationBuffer(
+ aggregationBufferSets,
+ bufferIndex,
+ j);
+ myagg.sumValueWithCheck(values[i], this.sumScale);
+ }
+ }
+ }
+
+ /**
+  * Adds every non-null {@code values[row]} to the buffer of batch position
+  * {@code row}.
+  *
+  * @param aggregationBufferSets buffer rows, indexed by batch position
+  * @param bufferIndex index of this aggregate's buffer inside each row
+  * @param values input values, indexed by batch position
+  * @param batchSize number of positions in the batch
+  * @param isNull null flags, indexed like {@code values}
+  */
+ private void iterateHasNullsWithAggregationSelection(
+     VectorAggregationBufferRow[] aggregationBufferSets,
+     int bufferIndex,
+     Decimal128[] values,
+     int batchSize,
+     boolean[] isNull) {
+
+   for (int row = 0; row < batchSize; ++row) {
+     if (isNull[row]) {
+       continue;
+     }
+     getCurrentAggregationBuffer(aggregationBufferSets, bufferIndex, row)
+         .sumValueWithCheck(values[row], this.sumScale);
+   }
+ }
+
+
+ @Override
+ public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch)
+ throws HiveException {
+
+ inputExpression.evaluate(batch);
+
+ DecimalColumnVector inputVector =
+ (DecimalColumnVector)batch.cols[this.inputExpression.getOutputColumn()];
+
+ int batchSize = batch.size;
+
+ if (batchSize == 0) {
+ return;
+ }
+
+ Aggregation myagg = (Aggregation)agg;
+
+ Decimal128[] vector = inputVector.vector;
+
+ if (inputVector.isRepeating) {
+ if (inputVector.noNulls) {
+ if (myagg.isNull) {
+ myagg.isNull = false;
+ myagg.sum.zeroClear();
+ myagg.count = 0;
+ }
+ scratchDecimal.update(batchSize);
+ scratchDecimal.multiplyDestructive(vector[0], vector[0].getScale());
+ myagg.sum.update(scratchDecimal);
+ myagg.count += batchSize;
+ }
+ return;
+ }
+
+ if (!batch.selectedInUse && inputVector.noNulls) {
+ iterateNoSelectionNoNulls(myagg, vector, batchSize);
+ }
+ else if (!batch.selectedInUse) {
+ iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull);
+ }
+ else if (inputVector.noNulls){
+ iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected);
+ }
+ else {
+ iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected);
+ }
+ }
+
+ private void iterateSelectionHasNulls(
+ Aggregation myagg,
+ Decimal128[] vector,
+ int batchSize,
+ boolean[] isNull,
+ int[] selected) {
+
+ for (int j=0; j< batchSize; ++j) {
+ int i = selected[j];
+ if (!isNull[i]) {
+ Decimal128 value = vector[i];
+ myagg.sumValueWithCheck(value, this.sumScale);
+ }
+ }
+ }
+
+ private void iterateSelectionNoNulls(
+ Aggregation myagg,
+ Decimal128[] vector,
+ int batchSize,
+ int[] selected) {
+
+ if (myagg.isNull) {
+ myagg.isNull = false;
+ myagg.sum.zeroClear();
+ myagg.count = 0;
+ }
+
+ for (int i=0; i< batchSize; ++i) {
+ Decimal128 value = vector[selected[i]];
+ myagg.sumValueNoCheck(value, this.sumScale);
+ }
+ }
+
+ /**
+  * Adds every non-null entry among the first {@code batchSize} entries of
+  * {@code vector} to {@code myagg}.
+  *
+  * @param myagg the buffer to aggregate into
+  * @param vector input values
+  * @param batchSize number of leading entries to visit
+  * @param isNull null flags, indexed like {@code vector}
+  */
+ private void iterateNoSelectionHasNulls(
+     Aggregation myagg,
+     Decimal128[] vector,
+     int batchSize,
+     boolean[] isNull) {
+
+   for (int pos = 0; pos < batchSize; ++pos) {
+     if (isNull[pos]) {
+       continue;
+     }
+     myagg.sumValueWithCheck(vector[pos], this.sumScale);
+   }
+ }
+
+ private void iterateNoSelectionNoNulls(
+ Aggregation myagg,
+ Decimal128[] vector,
+ int batchSize) {
+ if (myagg.isNull) {
+ myagg.isNull = false;
+ myagg.sum.zeroClear();
+ myagg.count = 0;
+ }
+
+ for (int i=0;i<batchSize;++i) {
+ Decimal128 value = vector[i];
+ myagg.sumValueNoCheck(value, this.sumScale);
+ }
+ }
+
+ /**
+  * Returns a freshly constructed {@link Aggregation}.
+  *
+  * NOTE(review): the new buffer's isNull flag keeps its Java default (false) here;
+  * only reset() sets it to true. Presumably callers reset a buffer before first
+  * use - confirm against the operator that allocates buffers.
+  */
+ @Override
+ public AggregationBuffer getNewAggregationBuffer() throws HiveException {
+ return new Aggregation();
+ }
+
+ /**
+  * Marks the buffer as holding no value. Sum and count are left untouched; they
+  * are re-seeded by the next aggregation into a buffer flagged null.
+  *
+  * @param agg the buffer to reset; must be an {@link Aggregation}
+  */
+ @Override
+ public void reset(AggregationBuffer agg) throws HiveException {
+   ((Aggregation) agg).isNull = true;
+ }
+
+ @Override
+ public Object evaluateOutput(
+ AggregationBuffer agg) throws HiveException {
+ Aggregation myagg = (Aggregation) agg;
+ if (myagg.isNull) {
+ return null;
+ }
+ else {
+ assert(0 < myagg.count);
+ resultCount.set (myagg.count);
+ int bufferIndex = myagg.sum.fastSerializeForHiveDecimal(scratch);
+ resultSum.set(scratch.getBytes(bufferIndex), (int) sumScale);
+ return partialResult;
+ }
+ }
+
+ /**
+  * Returns the struct inspector for the partial result. The inspector is built
+  * by initPartialResultInspector(), which init() calls.
+  */
+ @Override
+ public ObjectInspector getOutputObjectInspector() {
+ return soi;
+ }
+
+ @Override
+ public int getAggregationBufferFixedSize() {
+ JavaDataModel model = JavaDataModel.get();
+ return JavaDataModel.alignUp(
+ model.object() +
+ model.primitive2() * 2,
+ model.memoryAlign());
+ }
+
+ /**
+  * Reads the scale and precision of the first aggregation parameter and builds
+  * the partial-result inspector from them.
+  *
+  * @param desc aggregation descriptor; its first parameter must carry a
+  *     {@link DecimalTypeInfo}
+  * @throws HiveException declared by the overridden method
+  */
+ @Override
+ public void init(AggregationDesc desc) throws HiveException {
+   ExprNodeDesc firstParameter = desc.getParameters().get(0);
+   DecimalTypeInfo inputTypeInfo = (DecimalTypeInfo) firstParameter.getTypeInfo();
+
+   this.inputScale = (short) inputTypeInfo.scale();
+   this.inputPrecision = (short) inputTypeInfo.precision();
+
+   initPartialResultInspector();
+ }
+
+ /** Returns the expression whose output column this aggregate reads. */
+ public VectorExpression getInputExpression() {
+ return inputExpression;
+ }
+
+ /** Replaces the expression whose output column this aggregate reads. */
+ public void setInputExpression(VectorExpression inputExpression) {
+ this.inputExpression = inputExpression;
+ }
+}
+
Added: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java?rev=1569850&view=auto
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java (added)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java Wed Feb 19 17:43:20 2014
@@ -0,0 +1,436 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates;
+
+import org.apache.hadoop.hive.common.type.Decimal128;
+import org.apache.hadoop.hive.common.type.HiveDecimal;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression;
+import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.plan.AggregationDesc;
+import org.apache.hadoop.hive.ql.util.JavaDataModel;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+/**
+ * VectorUDAFSumDecimal. Vectorized implementation for SUM aggregates over
+ * decimal columns.
+ *
+ * Input is consumed one batch at a time, either into a single aggregation
+ * buffer ({@code aggregateInput}) or into one buffer per batch row
+ * ({@code aggregateInputSelection}). NULL inputs are skipped. A buffer that has
+ * been reset and has not accumulated any value since evaluates to null.
+ */
+@Description(name = "sum",
+    value = "_FUNC_(expr) - Returns the sum value of expr (vectorized, type: decimal)")
+public class VectorUDAFSumDecimal extends VectorAggregateExpression {
+
+  private static final long serialVersionUID = 1L;
+
+  /**
+   * class for storing the current aggregate value.
+   */
+  private static final class Aggregation implements AggregationBuffer {
+
+    private static final long serialVersionUID = 1L;
+
+    /** Running total; only meaningful while isNull is false. */
+    transient private Decimal128 sum = new Decimal128();
+    /** Set by reset(); cleared as soon as a value is accumulated. */
+    transient private boolean isNull;
+
+    /**
+     * Folds one value into the running sum. The first value after a reset
+     * replaces whatever the buffer held before instead of being added to it.
+     *
+     * @param value the non-null value to accumulate; it is not modified
+     */
+    public void sumValue(Decimal128 value) {
+      if (isNull) {
+        sum.update(value);
+        isNull = false;
+      } else {
+        sum.addDestructive(value, value.getScale());
+      }
+    }
+
+    @Override
+    public int getVariableSize() {
+      throw new UnsupportedOperationException();
+    }
+  }
+
+  private VectorExpression inputExpression;
+  /** Reused temporary for the batchSize * value product of repeating batches. */
+  transient private final Decimal128 scratchDecimal;
+
+  public VectorUDAFSumDecimal(VectorExpression inputExpression) {
+    this();
+    this.inputExpression = inputExpression;
+  }
+
+  public VectorUDAFSumDecimal() {
+    super();
+    scratchDecimal = new Decimal128();
+  }
+
+  /**
+   * Looks up this aggregate's buffer for one row of the batch.
+   *
+   * @param aggregationBufferSets per-row buffer sets, indexed by batch position
+   * @param aggregateIndex position of this aggregate inside each buffer set
+   * @param row batch position whose buffer is wanted
+   * @return the buffer, cast to this class's Aggregation type
+   */
+  private Aggregation getCurrentAggregationBuffer(
+      VectorAggregationBufferRow[] aggregationBufferSets,
+      int aggregateIndex,
+      int row) {
+    VectorAggregationBufferRow mySet = aggregationBufferSets[row];
+    Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregateIndex);
+    return myagg;
+  }
+
+  /**
+   * Accumulates a batch where every row has its own aggregation buffer.
+   *
+   * @param aggregationBufferSets per-row buffer sets, indexed by batch position
+   * @param aggregateIndex position of this aggregate inside each buffer set
+   * @param batch the batch to consume; an empty batch is ignored
+   */
+  @Override
+  public void aggregateInputSelection(
+      VectorAggregationBufferRow[] aggregationBufferSets,
+      int aggregateIndex,
+      VectorizedRowBatch batch) throws HiveException {
+
+    int batchSize = batch.size;
+
+    if (batchSize == 0) {
+      return;
+    }
+
+    inputExpression.evaluate(batch);
+
+    DecimalColumnVector inputVector = (DecimalColumnVector)batch.
+        cols[this.inputExpression.getOutputColumn()];
+    Decimal128[] vector = inputVector.vector;
+
+    if (inputVector.isRepeating) {
+      // Per the ColumnVector contract a repeating vector keeps its single value
+      // and its null flag in entry 0 only. The null test therefore must not be
+      // indexed by row or by selection: those entries may hold stale flags.
+      if (inputVector.noNulls || !inputVector.isNull[0]) {
+        iterateNoNullsRepeatingWithAggregationSelection(
+            aggregationBufferSets, aggregateIndex,
+            vector[0], batchSize);
+      }
+      return;
+    }
+
+    if (inputVector.noNulls) {
+      if (batch.selectedInUse) {
+        iterateNoNullsSelectionWithAggregationSelection(
+            aggregationBufferSets, aggregateIndex,
+            vector, batch.selected, batchSize);
+      } else {
+        iterateNoNullsWithAggregationSelection(
+            aggregationBufferSets, aggregateIndex,
+            vector, batchSize);
+      }
+    } else {
+      if (batch.selectedInUse) {
+        iterateHasNullsSelectionWithAggregationSelection(
+            aggregationBufferSets, aggregateIndex,
+            vector, batchSize, batch.selected, inputVector.isNull);
+      } else {
+        iterateHasNullsWithAggregationSelection(
+            aggregationBufferSets, aggregateIndex,
+            vector, batchSize, inputVector.isNull);
+      }
+    }
+  }
+
+  /**
+   * Adds the same non-null value to the buffer of every row in the batch.
+   */
+  private void iterateNoNullsRepeatingWithAggregationSelection(
+      VectorAggregationBufferRow[] aggregationBufferSets,
+      int aggregateIndex,
+      Decimal128 value,
+      int batchSize) {
+
+    for (int i=0; i < batchSize; ++i) {
+      Aggregation myagg = getCurrentAggregationBuffer(
+          aggregationBufferSets,
+          aggregateIndex,
+          i);
+      myagg.sumValue(value);
+    }
+  }
+
+  /**
+   * Null-free input with a selection vector: batch row i takes its value from
+   * values[selection[i]] and accumulates into buffer set i.
+   */
+  private void iterateNoNullsSelectionWithAggregationSelection(
+      VectorAggregationBufferRow[] aggregationBufferSets,
+      int aggregateIndex,
+      Decimal128[] values,
+      int[] selection,
+      int batchSize) {
+
+    for (int i=0; i < batchSize; ++i) {
+      Aggregation myagg = getCurrentAggregationBuffer(
+          aggregationBufferSets,
+          aggregateIndex,
+          i);
+      myagg.sumValue(values[selection[i]]);
+    }
+  }
+
+  /**
+   * Null-free input without a selection vector: values[i] accumulates into
+   * buffer set i.
+   */
+  private void iterateNoNullsWithAggregationSelection(
+      VectorAggregationBufferRow[] aggregationBufferSets,
+      int aggregateIndex,
+      Decimal128[] values,
+      int batchSize) {
+    for (int i=0; i < batchSize; ++i) {
+      Aggregation myagg = getCurrentAggregationBuffer(
+          aggregationBufferSets,
+          aggregateIndex,
+          i);
+      myagg.sumValue(values[i]);
+    }
+  }
+
+  /**
+   * Input with possible nulls and a selection vector: batch row j reads
+   * physical row selection[j]; non-null values accumulate into buffer set j.
+   */
+  private void iterateHasNullsSelectionWithAggregationSelection(
+      VectorAggregationBufferRow[] aggregationBufferSets,
+      int aggregateIndex,
+      Decimal128[] values,
+      int batchSize,
+      int[] selection,
+      boolean[] isNull) {
+
+    for (int j=0; j < batchSize; ++j) {
+      int i = selection[j];
+      if (!isNull[i]) {
+        Aggregation myagg = getCurrentAggregationBuffer(
+            aggregationBufferSets,
+            aggregateIndex,
+            j);
+        myagg.sumValue(values[i]);
+      }
+    }
+  }
+
+  /**
+   * Input with possible nulls and no selection vector: each non-null values[i]
+   * accumulates into buffer set i.
+   */
+  private void iterateHasNullsWithAggregationSelection(
+      VectorAggregationBufferRow[] aggregationBufferSets,
+      int aggregateIndex,
+      Decimal128[] values,
+      int batchSize,
+      boolean[] isNull) {
+
+    for (int i=0; i < batchSize; ++i) {
+      if (!isNull[i]) {
+        Aggregation myagg = getCurrentAggregationBuffer(
+            aggregationBufferSets,
+            aggregateIndex,
+            i);
+        myagg.sumValue(values[i]);
+      }
+    }
+  }
+
+
+  /**
+   * Accumulates a whole batch into a single aggregation buffer.
+   *
+   * @param agg the buffer to update; must be this class's Aggregation type
+   * @param batch the batch to consume; an empty batch leaves agg untouched
+   */
+  @Override
+  public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch)
+      throws HiveException {
+
+    inputExpression.evaluate(batch);
+
+    DecimalColumnVector inputVector = (DecimalColumnVector)batch.
+        cols[this.inputExpression.getOutputColumn()];
+
+    int batchSize = batch.size;
+
+    if (batchSize == 0) {
+      return;
+    }
+
+    Aggregation myagg = (Aggregation)agg;
+
+    Decimal128[] vector = inputVector.vector;
+
+    if (inputVector.isRepeating) {
+      // Only entry 0 of a repeating vector is meaningful, for the value and for
+      // the null flag alike. A repeated NULL contributes nothing.
+      if (inputVector.noNulls || !inputVector.isNull[0]) {
+        if (myagg.isNull) {
+          myagg.isNull = false;
+          myagg.sum.zeroClear();
+        }
+        // The batch contributes batchSize copies of the same value.
+        scratchDecimal.update(batchSize);
+        scratchDecimal.multiplyDestructive(vector[0], vector[0].getScale());
+        // Add to the running total. Assigning the product here would discard
+        // everything accumulated from earlier batches.
+        myagg.sum.addDestructive(scratchDecimal, vector[0].getScale());
+      }
+      return;
+    }
+
+    if (!batch.selectedInUse && inputVector.noNulls) {
+      iterateNoSelectionNoNulls(myagg, vector, batchSize);
+    }
+    else if (!batch.selectedInUse) {
+      iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull);
+    }
+    else if (inputVector.noNulls){
+      iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected);
+    }
+    else {
+      iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected);
+    }
+  }
+
+  /**
+   * Adds every selected, non-null entry of vector to myagg. The buffer leaves
+   * its null state only when the first non-null entry is found.
+   */
+  private void iterateSelectionHasNulls(
+      Aggregation myagg,
+      Decimal128[] vector,
+      int batchSize,
+      boolean[] isNull,
+      int[] selected) {
+
+    for (int j=0; j< batchSize; ++j) {
+      int i = selected[j];
+      if (!isNull[i]) {
+        Decimal128 value = vector[i];
+        if (myagg.isNull) {
+          myagg.isNull = false;
+          myagg.sum.zeroClear();
+        }
+        myagg.sum.addDestructive(value, value.getScale());
+      }
+    }
+  }
+
+  /**
+   * Adds every selected entry of a null-free vector to myagg.
+   */
+  private void iterateSelectionNoNulls(
+      Aggregation myagg,
+      Decimal128[] vector,
+      int batchSize,
+      int[] selected) {
+
+    if (myagg.isNull) {
+      myagg.sum.zeroClear();
+      myagg.isNull = false;
+    }
+
+    for (int i=0; i< batchSize; ++i) {
+      Decimal128 value = vector[selected[i]];
+      myagg.sum.addDestructive(value, value.getScale());
+    }
+  }
+
+  /**
+   * Adds every non-null entry among the first batchSize entries of vector to
+   * myagg. The buffer leaves its null state only when the first non-null entry
+   * is found.
+   */
+  private void iterateNoSelectionHasNulls(
+      Aggregation myagg,
+      Decimal128[] vector,
+      int batchSize,
+      boolean[] isNull) {
+
+    for(int i=0;i<batchSize;++i) {
+      if (!isNull[i]) {
+        Decimal128 value = vector[i];
+        if (myagg.isNull) {
+          myagg.sum.zeroClear();
+          myagg.isNull = false;
+        }
+        myagg.sum.addDestructive(value, value.getScale());
+      }
+    }
+  }
+
+  /**
+   * Adds the first batchSize entries of a null-free vector to myagg.
+   */
+  private void iterateNoSelectionNoNulls(
+      Aggregation myagg,
+      Decimal128[] vector,
+      int batchSize) {
+    if (myagg.isNull) {
+      myagg.sum.zeroClear();
+      myagg.isNull = false;
+    }
+
+    for (int i=0;i<batchSize;++i) {
+      Decimal128 value = vector[i];
+      myagg.sum.addDestructive(value, value.getScale());
+    }
+  }
+
+  @Override
+  public AggregationBuffer getNewAggregationBuffer() throws HiveException {
+    return new Aggregation();
+  }
+
+  /**
+   * Marks the buffer as empty. The stored sum is left as is; it is cleared or
+   * overwritten when the next value is accumulated.
+   */
+  @Override
+  public void reset(AggregationBuffer agg) throws HiveException {
+    Aggregation myAgg = (Aggregation) agg;
+    myAgg.isNull = true;
+  }
+
+  /**
+   * @return null when the buffer is flagged as null, otherwise the sum
+   *         converted to a HiveDecimal
+   */
+  @Override
+  public Object evaluateOutput(AggregationBuffer agg) throws HiveException {
+    Aggregation myagg = (Aggregation) agg;
+    if (myagg.isNull) {
+      return null;
+    }
+    else {
+      return HiveDecimal.create(myagg.sum.toBigDecimal());
+    }
+  }
+
+  @Override
+  public ObjectInspector getOutputObjectInspector() {
+    return PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector;
+  }
+
+  /**
+   * @return the size of one object header rounded up to the data model's
+   *         memory alignment
+   */
+  @Override
+  public int getAggregationBufferFixedSize() {
+    JavaDataModel model = JavaDataModel.get();
+    return JavaDataModel.alignUp(
+        model.object(),
+        model.memoryAlign());
+  }
+
+  @Override
+  public void init(AggregationDesc desc) throws HiveException {
+    // No-op
+  }
+
+  public VectorExpression getInputExpression() {
+    return inputExpression;
+  }
+
+  public void setInputExpression(VectorExpression inputExpression) {
+    this.inputExpression = inputExpression;
+  }
+}
+
Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java?rev=1569850&r1=1569849&r2=1569850&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java Wed Feb 19 17:43:20 2014
@@ -174,15 +174,9 @@ public class GenericUDAFAverage extends
return PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(typeInfo);
}
- /**
- * The result type has the same number of integer digits and 4 more decimal digits.
- */
private DecimalTypeInfo deriveResultDecimalTypeInfo() {
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
- int scale = inputOI.scale();
- int intPart = inputOI.precision() - scale;
- scale = Math.min(scale + 4, HiveDecimal.MAX_SCALE - intPart);
- return TypeInfoFactory.getDecimalTypeInfo(intPart + scale, scale);
+ return GenericUDAFAverage.deriveSumTypeInfo(inputOI.scale(), inputOI.precision());
} else {
PrimitiveObjectInspector sfOI = (PrimitiveObjectInspector) sumFieldOI;
return (DecimalTypeInfo) sfOI.getTypeInfo();
@@ -367,4 +361,17 @@ public class GenericUDAFAverage extends
return doTerminate((AverageAggregationBuffer<TYPE>)aggregation);
}
}
+
+  /**
+   * Derives the decimal type used for the sum. The result keeps the input's
+   * number of integer digits and gains 4 decimal digits, with the scale capped
+   * at HiveDecimal.MAX_SCALE minus the number of integer digits.
+   * This is exposed as static so that the vectorized AVG operator uses the same logic.
+   *
+   * @param scale scale of the input decimal type
+   * @param precision precision of the input decimal type
+   * @return the decimal type info with the widened scale and matching precision
+   */
+  public static DecimalTypeInfo deriveSumTypeInfo(int scale, int precision) {
+    final int integerDigits = precision - scale;
+    final int widenedScale = Math.min(scale + 4, HiveDecimal.MAX_SCALE - integerDigits);
+    return TypeInfoFactory.getDecimalTypeInfo(integerDigits + widenedScale, widenedScale);
+  }
}