You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hive.apache.org by ha...@apache.org on 2013/04/21 17:51:50 UTC
svn commit: r1470312 - in /hive/trunk/ql/src:
java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFLeadLag.java
test/queries/clientpositive/windowing_expressions.q
test/results/clientpositive/windowing_expressions.q.out
Author: hashutosh
Date: Sun Apr 21 15:51:49 2013
New Revision: 1470312
URL: http://svn.apache.org/r1470312
Log:
HIVE-4130 : Bring the Lead/Lag UDFs interface in line with Lead/Lag UDAFs (Harish Butani via Ashutosh Chauhan)
Modified:
hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFLeadLag.java
hive/trunk/ql/src/test/queries/clientpositive/windowing_expressions.q
hive/trunk/ql/src/test/results/clientpositive/windowing_expressions.q.out
Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFLeadLag.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFLeadLag.java?rev=1470312&r1=1470311&r2=1470312&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFLeadLag.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFLeadLag.java Sun Apr 21 15:51:49 2013
@@ -24,19 +24,23 @@ import org.apache.hadoop.hive.ql.exec.PT
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.IntWritable;
public abstract class GenericUDFLeadLag extends GenericUDF
{
transient ExprNodeEvaluator exprEvaluator;
transient PTFPartitionIterator<Object> pItr;
ObjectInspector firstArgOI;
-
- private PrimitiveObjectInspector amtOI;
+ ObjectInspector defaultArgOI;
+ Converter defaultValueConverter;
+ int amt;
static{
PTFUtils.makeTransient(GenericUDFLeadLag.class, "exprEvaluator");
@@ -46,27 +50,30 @@ public abstract class GenericUDFLeadLag
@Override
public Object evaluate(DeferredObject[] arguments) throws HiveException
{
- DeferredObject amt = arguments[1];
- int intAmt = 0;
- try
- {
- intAmt = PrimitiveObjectInspectorUtils.getInt(amt.get(), amtOI);
- }
- catch (NullPointerException e)
- {
- intAmt = Integer.MAX_VALUE;
- }
- catch (NumberFormatException e)
- {
- intAmt = Integer.MAX_VALUE;
- }
+ Object defaultVal = null;
+ if(arguments.length == 3){
+ defaultVal = ObjectInspectorUtils.copyToStandardObject(
+ defaultValueConverter.convert(arguments[2].get()),
+ defaultArgOI);
+ }
int idx = pItr.getIndex() - 1;
+ int start = 0;
+ int end = pItr.getPartition().size();
try
{
- Object row = getRow(intAmt);
- Object ret = exprEvaluator.evaluate(row);
- ret = ObjectInspectorUtils.copyToStandardObject(ret, firstArgOI, ObjectInspectorCopyOption.WRITABLE);
+ Object ret = null;
+ int newIdx = getIndex(amt);
+
+ if(newIdx >= end || newIdx < start) {
+ ret = defaultVal;
+ }
+ else {
+ Object row = getRow(amt);
+ ret = exprEvaluator.evaluate(row);
+ ret = ObjectInspectorUtils.copyToStandardObject(ret,
+ firstArgOI, ObjectInspectorCopyOption.WRITABLE);
+ }
return ret;
}
finally
@@ -83,25 +90,41 @@ public abstract class GenericUDFLeadLag
public ObjectInspector initialize(ObjectInspector[] arguments)
throws UDFArgumentException
{
- // index has to be a primitive
- if (arguments[1] instanceof PrimitiveObjectInspector)
- {
- amtOI = (PrimitiveObjectInspector) arguments[1];
- }
- else
- {
- throw new UDFArgumentTypeException(1,
- "Primitive Type is expected but "
- + arguments[1].getTypeName() + "\" is found");
- }
-
- firstArgOI = arguments[0];
- return ObjectInspectorUtils.getStandardObjectInspector(firstArgOI,
- ObjectInspectorCopyOption.WRITABLE);
+ if (!(arguments.length >= 1 && arguments.length <= 3)) {
+ throw new UDFArgumentTypeException(arguments.length - 1,
+ "Incorrect invocation of " + _getFnName() + ": _FUNC_(expr, amt, default)");
+ }
+
+ amt = 1;
+
+ if (arguments.length > 1) {
+ ObjectInspector amtOI = arguments[1];
+ if ( !ObjectInspectorUtils.isConstantObjectInspector(amtOI) ||
+ (amtOI.getCategory() != ObjectInspector.Category.PRIMITIVE) ||
+ ((PrimitiveObjectInspector)amtOI).getPrimitiveCategory() !=
+ PrimitiveObjectInspector.PrimitiveCategory.INT )
+ {
+ throw new UDFArgumentTypeException(0,
+ _getFnName() + " amount must be a integer value "
+ + amtOI.getTypeName() + " was passed as parameter 1.");
+ }
+ Object o = ((ConstantObjectInspector)amtOI).
+ getWritableConstantValue();
+ amt = ((IntWritable)o).get();
+ }
+
+ if (arguments.length == 3) {
+ defaultArgOI = arguments[2];
+ ObjectInspectorConverters.getConverter(arguments[2], arguments[0]);
+ defaultValueConverter = ObjectInspectorConverters.getConverter(arguments[2], arguments[0]);
+
+ }
+
+ firstArgOI = arguments[0];
+ return ObjectInspectorUtils.getStandardObjectInspector(firstArgOI,
+ ObjectInspectorCopyOption.WRITABLE);
}
-
-
public ExprNodeEvaluator getExprEvaluator()
{
return exprEvaluator;
@@ -122,7 +145,39 @@ public abstract class GenericUDFLeadLag
this.pItr = pItr;
}
- @Override
+ public ObjectInspector getFirstArgOI() {
+ return firstArgOI;
+ }
+
+ public void setFirstArgOI(ObjectInspector firstArgOI) {
+ this.firstArgOI = firstArgOI;
+ }
+
+ public ObjectInspector getDefaultArgOI() {
+ return defaultArgOI;
+ }
+
+ public void setDefaultArgOI(ObjectInspector defaultArgOI) {
+ this.defaultArgOI = defaultArgOI;
+ }
+
+ public Converter getDefaultValueConverter() {
+ return defaultValueConverter;
+ }
+
+ public void setDefaultValueConverter(Converter defaultValueConverter) {
+ this.defaultValueConverter = defaultValueConverter;
+ }
+
+ public int getAmt() {
+ return amt;
+ }
+
+ public void setAmt(int amt) {
+ this.amt = amt;
+ }
+
+ @Override
public String getDisplayString(String[] children)
{
assert (children.length == 2);
@@ -140,6 +195,8 @@ public abstract class GenericUDFLeadLag
protected abstract Object getRow(int amt);
+ protected abstract int getIndex(int amt);
+
public static class GenericUDFLead extends GenericUDFLeadLag
{
@@ -150,6 +207,11 @@ public abstract class GenericUDFLeadLag
}
@Override
+ protected int getIndex(int amt) {
+ return pItr.getIndex() - 1 + amt;
+ }
+
+ @Override
protected Object getRow(int amt)
{
return pItr.lead(amt - 1);
@@ -166,6 +228,11 @@ public abstract class GenericUDFLeadLag
}
@Override
+ protected int getIndex(int amt) {
+ return pItr.getIndex() - 1 - amt;
+ }
+
+ @Override
protected Object getRow(int amt)
{
return pItr.lag(amt + 1);
Modified: hive/trunk/ql/src/test/queries/clientpositive/windowing_expressions.q
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/test/queries/clientpositive/windowing_expressions.q?rev=1470312&r1=1470311&r2=1470312&view=diff
==============================================================================
--- hive/trunk/ql/src/test/queries/clientpositive/windowing_expressions.q (original)
+++ hive/trunk/ql/src/test/queries/clientpositive/windowing_expressions.q Sun Apr 21 15:51:49 2013
@@ -35,7 +35,7 @@ create table over10k(
load data local inpath '../data/files/over10k' into table over10k;
select p_mfgr, p_retailprice, p_size,
-round(sum(p_retailprice),2) = round((sum(lag(p_retailprice,1)) - first_value(p_retailprice)) + last_value(p_retailprice),2)
+round(sum(p_retailprice),2) = round(sum(lag(p_retailprice,1,0.0)) + last_value(p_retailprice),2)
over(distribute by p_mfgr sort by p_retailprice),
max(p_retailprice) - min(p_retailprice) = last_value(p_retailprice) - first_value(p_retailprice)
over(distribute by p_mfgr sort by p_retailprice)
@@ -64,3 +64,9 @@ create table t2 (a1 int, b1 string);
from (select sum(i) over (), s from over10k) tt insert overwrite table t1 select * insert overwrite table t2 select * ;
select * from t1 limit 3;
select * from t2 limit 3;
+
+select p_mfgr, p_retailprice, p_size,
+round(sum(p_retailprice),2) + 50.0 = round(sum(lag(p_retailprice,1,50.0)) + last_value(p_retailprice),2)
+ over(distribute by p_mfgr sort by p_retailprice)
+from part
+limit 11;
Modified: hive/trunk/ql/src/test/results/clientpositive/windowing_expressions.q.out
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/test/results/clientpositive/windowing_expressions.q.out?rev=1470312&r1=1470311&r2=1470312&view=diff
==============================================================================
--- hive/trunk/ql/src/test/results/clientpositive/windowing_expressions.q.out (original)
+++ hive/trunk/ql/src/test/results/clientpositive/windowing_expressions.q.out Sun Apr 21 15:51:49 2013
@@ -77,7 +77,7 @@ POSTHOOK: query: load data local inpath
POSTHOOK: type: LOAD
POSTHOOK: Output: default@over10k
PREHOOK: query: select p_mfgr, p_retailprice, p_size,
-round(sum(p_retailprice),2) = round((sum(lag(p_retailprice,1)) - first_value(p_retailprice)) + last_value(p_retailprice),2)
+round(sum(p_retailprice),2) = round(sum(lag(p_retailprice,1,0.0)) + last_value(p_retailprice),2)
over(distribute by p_mfgr sort by p_retailprice),
max(p_retailprice) - min(p_retailprice) = last_value(p_retailprice) - first_value(p_retailprice)
over(distribute by p_mfgr sort by p_retailprice)
@@ -86,7 +86,7 @@ PREHOOK: type: QUERY
PREHOOK: Input: default@part
#### A masked pattern was here ####
POSTHOOK: query: select p_mfgr, p_retailprice, p_size,
-round(sum(p_retailprice),2) = round((sum(lag(p_retailprice,1)) - first_value(p_retailprice)) + last_value(p_retailprice),2)
+round(sum(p_retailprice),2) = round(sum(lag(p_retailprice,1,0.0)) + last_value(p_retailprice),2)
over(distribute by p_mfgr sort by p_retailprice),
max(p_retailprice) - min(p_retailprice) = last_value(p_retailprice) - first_value(p_retailprice)
over(distribute by p_mfgr sort by p_retailprice)
@@ -718,3 +718,34 @@ POSTHOOK: Lineage: t2.b1 SCRIPT [(over10
656584379 bob davidson
656584379 alice zipper
656584379 katie davidson
+PREHOOK: query: select p_mfgr, p_retailprice, p_size,
+round(sum(p_retailprice),2) + 50.0 = round(sum(lag(p_retailprice,1,50.0)) + last_value(p_retailprice),2)
+ over(distribute by p_mfgr sort by p_retailprice)
+from part
+limit 11
+PREHOOK: type: QUERY
+PREHOOK: Input: default@part
+#### A masked pattern was here ####
+POSTHOOK: query: select p_mfgr, p_retailprice, p_size,
+round(sum(p_retailprice),2) + 50.0 = round(sum(lag(p_retailprice,1,50.0)) + last_value(p_retailprice),2)
+ over(distribute by p_mfgr sort by p_retailprice)
+from part
+limit 11
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@part
+#### A masked pattern was here ####
+POSTHOOK: Lineage: t1.a1 SCRIPT [(over10k)over10k.FieldSchema(name:t, type:tinyint, comment:null), (over10k)over10k.FieldSchema(name:si, type:smallint, comment:null), (over10k)over10k.FieldSchema(name:i, type:int, comment:null), (over10k)over10k.FieldSchema(name:b, type:bigint, comment:null), (over10k)over10k.FieldSchema(name:f, type:float, comment:null), (over10k)over10k.FieldSchema(name:d, type:double, comment:null), (over10k)over10k.FieldSchema(name:bo, type:boolean, comment:null), (over10k)over10k.FieldSchema(name:s, type:string, comment:null), (over10k)over10k.FieldSchema(name:ts, type:timestamp, comment:null), (over10k)over10k.FieldSchema(name:dec, type:decimal, comment:null), (over10k)over10k.FieldSchema(name:bin, type:binary, comment:null), (over10k)over10k.FieldSchema(name:BLOCK__OFFSET__INSIDE__FILE, type:bigint, comment:), (over10k)over10k.FieldSchema(name:INPUT__FILE__NAME, type:string, comment:), ]
+POSTHOOK: Lineage: t1.b1 SCRIPT [(over10k)over10k.FieldSchema(name:t, type:tinyint, comment:null), (over10k)over10k.FieldSchema(name:si, type:smallint, comment:null), (over10k)over10k.FieldSchema(name:i, type:int, comment:null), (over10k)over10k.FieldSchema(name:b, type:bigint, comment:null), (over10k)over10k.FieldSchema(name:f, type:float, comment:null), (over10k)over10k.FieldSchema(name:d, type:double, comment:null), (over10k)over10k.FieldSchema(name:bo, type:boolean, comment:null), (over10k)over10k.FieldSchema(name:s, type:string, comment:null), (over10k)over10k.FieldSchema(name:ts, type:timestamp, comment:null), (over10k)over10k.FieldSchema(name:dec, type:decimal, comment:null), (over10k)over10k.FieldSchema(name:bin, type:binary, comment:null), (over10k)over10k.FieldSchema(name:BLOCK__OFFSET__INSIDE__FILE, type:bigint, comment:), (over10k)over10k.FieldSchema(name:INPUT__FILE__NAME, type:string, comment:), ]
+POSTHOOK: Lineage: t2.a1 SCRIPT [(over10k)over10k.FieldSchema(name:t, type:tinyint, comment:null), (over10k)over10k.FieldSchema(name:si, type:smallint, comment:null), (over10k)over10k.FieldSchema(name:i, type:int, comment:null), (over10k)over10k.FieldSchema(name:b, type:bigint, comment:null), (over10k)over10k.FieldSchema(name:f, type:float, comment:null), (over10k)over10k.FieldSchema(name:d, type:double, comment:null), (over10k)over10k.FieldSchema(name:bo, type:boolean, comment:null), (over10k)over10k.FieldSchema(name:s, type:string, comment:null), (over10k)over10k.FieldSchema(name:ts, type:timestamp, comment:null), (over10k)over10k.FieldSchema(name:dec, type:decimal, comment:null), (over10k)over10k.FieldSchema(name:bin, type:binary, comment:null), (over10k)over10k.FieldSchema(name:BLOCK__OFFSET__INSIDE__FILE, type:bigint, comment:), (over10k)over10k.FieldSchema(name:INPUT__FILE__NAME, type:string, comment:), ]
+POSTHOOK: Lineage: t2.b1 SCRIPT [(over10k)over10k.FieldSchema(name:t, type:tinyint, comment:null), (over10k)over10k.FieldSchema(name:si, type:smallint, comment:null), (over10k)over10k.FieldSchema(name:i, type:int, comment:null), (over10k)over10k.FieldSchema(name:b, type:bigint, comment:null), (over10k)over10k.FieldSchema(name:f, type:float, comment:null), (over10k)over10k.FieldSchema(name:d, type:double, comment:null), (over10k)over10k.FieldSchema(name:bo, type:boolean, comment:null), (over10k)over10k.FieldSchema(name:s, type:string, comment:null), (over10k)over10k.FieldSchema(name:ts, type:timestamp, comment:null), (over10k)over10k.FieldSchema(name:dec, type:decimal, comment:null), (over10k)over10k.FieldSchema(name:bin, type:binary, comment:null), (over10k)over10k.FieldSchema(name:BLOCK__OFFSET__INSIDE__FILE, type:bigint, comment:), (over10k)over10k.FieldSchema(name:INPUT__FILE__NAME, type:string, comment:), ]
+Manufacturer#1 1173.15 2 true
+Manufacturer#1 1173.15 2 true
+Manufacturer#1 1414.42 28 true
+Manufacturer#1 1602.59 6 true
+Manufacturer#1 1632.66 42 true
+Manufacturer#1 1753.76 34 true
+Manufacturer#2 1690.68 14 true
+Manufacturer#2 1698.66 25 true
+Manufacturer#2 1701.6 18 true
+Manufacturer#2 1800.7 40 true
+Manufacturer#2 2031.98 2 true