You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hive.apache.org by ha...@apache.org on 2013/04/21 17:51:50 UTC

svn commit: r1470312 - in /hive/trunk/ql/src: java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFLeadLag.java test/queries/clientpositive/windowing_expressions.q test/results/clientpositive/windowing_expressions.q.out

Author: hashutosh
Date: Sun Apr 21 15:51:49 2013
New Revision: 1470312

URL: http://svn.apache.org/r1470312
Log:
HIVE-4130 : Bring the Lead/Lag UDFs interface in line with Lead/Lag UDAFs (Harish Butani via Ashutosh Chauhan)

Modified:
    hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFLeadLag.java
    hive/trunk/ql/src/test/queries/clientpositive/windowing_expressions.q
    hive/trunk/ql/src/test/results/clientpositive/windowing_expressions.q.out

Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFLeadLag.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFLeadLag.java?rev=1470312&r1=1470311&r2=1470312&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFLeadLag.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFLeadLag.java Sun Apr 21 15:51:49 2013
@@ -24,19 +24,23 @@ import org.apache.hadoop.hive.ql.exec.PT
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.IntWritable;
 
 public abstract class GenericUDFLeadLag extends GenericUDF
 {
 	transient ExprNodeEvaluator exprEvaluator;
 	transient PTFPartitionIterator<Object> pItr;
 	ObjectInspector firstArgOI;
-
-	private PrimitiveObjectInspector amtOI;
+	ObjectInspector defaultArgOI;
+	Converter defaultValueConverter;
+	int amt;
 
 	static{
 		PTFUtils.makeTransient(GenericUDFLeadLag.class, "exprEvaluator");
@@ -46,27 +50,30 @@ public abstract class GenericUDFLeadLag 
 	@Override
 	public Object evaluate(DeferredObject[] arguments) throws HiveException
 	{
-		DeferredObject amt = arguments[1];
-		int intAmt = 0;
-		try
-		{
-			intAmt = PrimitiveObjectInspectorUtils.getInt(amt.get(), amtOI);
-		}
-		catch (NullPointerException e)
-		{
-			intAmt = Integer.MAX_VALUE;
-		}
-		catch (NumberFormatException e)
-		{
-			intAmt = Integer.MAX_VALUE;
-		}
+    Object defaultVal = null;
+    if(arguments.length == 3){
+      defaultVal =  ObjectInspectorUtils.copyToStandardObject(
+          defaultValueConverter.convert(arguments[2].get()),
+          defaultArgOI);
+    }
 
 		int idx = pItr.getIndex() - 1;
+		int start = 0;
+		int end = pItr.getPartition().size();
 		try
 		{
-			Object row = getRow(intAmt);
-			Object ret = exprEvaluator.evaluate(row);
-			ret = ObjectInspectorUtils.copyToStandardObject(ret, firstArgOI, ObjectInspectorCopyOption.WRITABLE);
+		  Object ret = null;
+		  int newIdx = getIndex(amt);
+
+		  if(newIdx >= end || newIdx < start) {
+        ret = defaultVal;
+		  }
+		  else {
+        Object row = getRow(amt);
+        ret = exprEvaluator.evaluate(row);
+        ret = ObjectInspectorUtils.copyToStandardObject(ret,
+            firstArgOI, ObjectInspectorCopyOption.WRITABLE);
+		  }
 			return ret;
 		}
 		finally
@@ -83,25 +90,41 @@ public abstract class GenericUDFLeadLag 
 	public ObjectInspector initialize(ObjectInspector[] arguments)
 			throws UDFArgumentException
 	{
-		// index has to be a primitive
-		if (arguments[1] instanceof PrimitiveObjectInspector)
-		{
-			amtOI = (PrimitiveObjectInspector) arguments[1];
-		}
-		else
-		{
-			throw new UDFArgumentTypeException(1,
-					"Primitive Type is expected but "
-							+ arguments[1].getTypeName() + "\" is found");
-		}
-
-		firstArgOI = arguments[0];
-		return ObjectInspectorUtils.getStandardObjectInspector(firstArgOI,
-				ObjectInspectorCopyOption.WRITABLE);
+    if (!(arguments.length >= 1 && arguments.length <= 3)) {
+      throw new UDFArgumentTypeException(arguments.length - 1,
+          "Incorrect invocation of " + _getFnName() + ": _FUNC_(expr, amt, default)");
+    }
+
+    amt = 1;
+
+    if (arguments.length > 1) {
+      ObjectInspector amtOI = arguments[1];
+      if ( !ObjectInspectorUtils.isConstantObjectInspector(amtOI) ||
+          (amtOI.getCategory() != ObjectInspector.Category.PRIMITIVE) ||
+          ((PrimitiveObjectInspector)amtOI).getPrimitiveCategory() !=
+          PrimitiveObjectInspector.PrimitiveCategory.INT )
+      {
+        throw new UDFArgumentTypeException(0,
+            _getFnName() + " amount must be a integer value "
+            + amtOI.getTypeName() + " was passed as parameter 1.");
+      }
+      Object o = ((ConstantObjectInspector)amtOI).
+          getWritableConstantValue();
+      amt = ((IntWritable)o).get();
+    }
+
+    if (arguments.length == 3) {
+      defaultArgOI = arguments[2];
+      ObjectInspectorConverters.getConverter(arguments[2], arguments[0]);
+      defaultValueConverter = ObjectInspectorConverters.getConverter(arguments[2], arguments[0]);
+
+    }
+
+    firstArgOI = arguments[0];
+    return ObjectInspectorUtils.getStandardObjectInspector(firstArgOI,
+        ObjectInspectorCopyOption.WRITABLE);
 	}
 
-
-
 	public ExprNodeEvaluator getExprEvaluator()
 	{
 		return exprEvaluator;
@@ -122,7 +145,39 @@ public abstract class GenericUDFLeadLag 
 		this.pItr = pItr;
 	}
 
-	@Override
+	public ObjectInspector getFirstArgOI() {
+    return firstArgOI;
+  }
+
+  public void setFirstArgOI(ObjectInspector firstArgOI) {
+    this.firstArgOI = firstArgOI;
+  }
+
+  public ObjectInspector getDefaultArgOI() {
+    return defaultArgOI;
+  }
+
+  public void setDefaultArgOI(ObjectInspector defaultArgOI) {
+    this.defaultArgOI = defaultArgOI;
+  }
+
+  public Converter getDefaultValueConverter() {
+    return defaultValueConverter;
+  }
+
+  public void setDefaultValueConverter(Converter defaultValueConverter) {
+    this.defaultValueConverter = defaultValueConverter;
+  }
+
+  public int getAmt() {
+    return amt;
+  }
+
+  public void setAmt(int amt) {
+    this.amt = amt;
+  }
+
+  @Override
 	public String getDisplayString(String[] children)
 	{
 		assert (children.length == 2);
@@ -140,6 +195,8 @@ public abstract class GenericUDFLeadLag 
 
 	protected abstract Object getRow(int amt);
 
+	protected abstract int getIndex(int amt);
+
 	public static class GenericUDFLead extends GenericUDFLeadLag
 	{
 
@@ -150,6 +207,11 @@ public abstract class GenericUDFLeadLag 
 		}
 
 		@Override
+		protected int getIndex(int amt) {
+		  return pItr.getIndex() - 1 + amt;
+		}
+
+		@Override
 		protected Object getRow(int amt)
 		{
 			return pItr.lead(amt - 1);
@@ -166,6 +228,11 @@ public abstract class GenericUDFLeadLag 
 		}
 
 		@Override
+    protected int getIndex(int amt) {
+      return pItr.getIndex() - 1 - amt;
+    }
+
+		@Override
 		protected Object getRow(int amt)
 		{
 			return pItr.lag(amt + 1);

Modified: hive/trunk/ql/src/test/queries/clientpositive/windowing_expressions.q
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/test/queries/clientpositive/windowing_expressions.q?rev=1470312&r1=1470311&r2=1470312&view=diff
==============================================================================
--- hive/trunk/ql/src/test/queries/clientpositive/windowing_expressions.q (original)
+++ hive/trunk/ql/src/test/queries/clientpositive/windowing_expressions.q Sun Apr 21 15:51:49 2013
@@ -35,7 +35,7 @@ create table over10k(
 load data local inpath '../data/files/over10k' into table over10k;
 
 select p_mfgr, p_retailprice, p_size,
-round(sum(p_retailprice),2) = round((sum(lag(p_retailprice,1)) - first_value(p_retailprice)) + last_value(p_retailprice),2) 
+round(sum(p_retailprice),2) = round(sum(lag(p_retailprice,1,0.0)) + last_value(p_retailprice),2) 
   over(distribute by p_mfgr sort by p_retailprice),
 max(p_retailprice) - min(p_retailprice) = last_value(p_retailprice) - first_value(p_retailprice)
   over(distribute by p_mfgr sort by p_retailprice)
@@ -64,3 +64,9 @@ create table t2 (a1 int, b1 string);
 from (select sum(i) over (), s from over10k) tt insert overwrite table t1 select * insert overwrite table t2 select * ;
 select * from t1 limit 3;
 select * from t2 limit 3;
+
+select p_mfgr, p_retailprice, p_size,
+round(sum(p_retailprice),2) + 50.0 = round(sum(lag(p_retailprice,1,50.0)) + last_value(p_retailprice),2) 
+  over(distribute by p_mfgr sort by p_retailprice)
+from part
+limit 11;

Modified: hive/trunk/ql/src/test/results/clientpositive/windowing_expressions.q.out
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/test/results/clientpositive/windowing_expressions.q.out?rev=1470312&r1=1470311&r2=1470312&view=diff
==============================================================================
--- hive/trunk/ql/src/test/results/clientpositive/windowing_expressions.q.out (original)
+++ hive/trunk/ql/src/test/results/clientpositive/windowing_expressions.q.out Sun Apr 21 15:51:49 2013
@@ -77,7 +77,7 @@ POSTHOOK: query: load data local inpath 
 POSTHOOK: type: LOAD
 POSTHOOK: Output: default@over10k
 PREHOOK: query: select p_mfgr, p_retailprice, p_size,
-round(sum(p_retailprice),2) = round((sum(lag(p_retailprice,1)) - first_value(p_retailprice)) + last_value(p_retailprice),2) 
+round(sum(p_retailprice),2) = round(sum(lag(p_retailprice,1,0.0)) + last_value(p_retailprice),2) 
   over(distribute by p_mfgr sort by p_retailprice),
 max(p_retailprice) - min(p_retailprice) = last_value(p_retailprice) - first_value(p_retailprice)
   over(distribute by p_mfgr sort by p_retailprice)
@@ -86,7 +86,7 @@ PREHOOK: type: QUERY
 PREHOOK: Input: default@part
 #### A masked pattern was here ####
 POSTHOOK: query: select p_mfgr, p_retailprice, p_size,
-round(sum(p_retailprice),2) = round((sum(lag(p_retailprice,1)) - first_value(p_retailprice)) + last_value(p_retailprice),2) 
+round(sum(p_retailprice),2) = round(sum(lag(p_retailprice,1,0.0)) + last_value(p_retailprice),2) 
   over(distribute by p_mfgr sort by p_retailprice),
 max(p_retailprice) - min(p_retailprice) = last_value(p_retailprice) - first_value(p_retailprice)
   over(distribute by p_mfgr sort by p_retailprice)
@@ -718,3 +718,34 @@ POSTHOOK: Lineage: t2.b1 SCRIPT [(over10
 656584379	bob davidson
 656584379	alice zipper
 656584379	katie davidson
+PREHOOK: query: select p_mfgr, p_retailprice, p_size,
+round(sum(p_retailprice),2) + 50.0 = round(sum(lag(p_retailprice,1,50.0)) + last_value(p_retailprice),2) 
+  over(distribute by p_mfgr sort by p_retailprice)
+from part
+limit 11
+PREHOOK: type: QUERY
+PREHOOK: Input: default@part
+#### A masked pattern was here ####
+POSTHOOK: query: select p_mfgr, p_retailprice, p_size,
+round(sum(p_retailprice),2) + 50.0 = round(sum(lag(p_retailprice,1,50.0)) + last_value(p_retailprice),2) 
+  over(distribute by p_mfgr sort by p_retailprice)
+from part
+limit 11
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@part
+#### A masked pattern was here ####
+POSTHOOK: Lineage: t1.a1 SCRIPT [(over10k)over10k.FieldSchema(name:t, type:tinyint, comment:null), (over10k)over10k.FieldSchema(name:si, type:smallint, comment:null), (over10k)over10k.FieldSchema(name:i, type:int, comment:null), (over10k)over10k.FieldSchema(name:b, type:bigint, comment:null), (over10k)over10k.FieldSchema(name:f, type:float, comment:null), (over10k)over10k.FieldSchema(name:d, type:double, comment:null), (over10k)over10k.FieldSchema(name:bo, type:boolean, comment:null), (over10k)over10k.FieldSchema(name:s, type:string, comment:null), (over10k)over10k.FieldSchema(name:ts, type:timestamp, comment:null), (over10k)over10k.FieldSchema(name:dec, type:decimal, comment:null), (over10k)over10k.FieldSchema(name:bin, type:binary, comment:null), (over10k)over10k.FieldSchema(name:BLOCK__OFFSET__INSIDE__FILE, type:bigint, comment:), (over10k)over10k.FieldSchema(name:INPUT__FILE__NAME, type:string, comment:), ]
+POSTHOOK: Lineage: t1.b1 SCRIPT [(over10k)over10k.FieldSchema(name:t, type:tinyint, comment:null), (over10k)over10k.FieldSchema(name:si, type:smallint, comment:null), (over10k)over10k.FieldSchema(name:i, type:int, comment:null), (over10k)over10k.FieldSchema(name:b, type:bigint, comment:null), (over10k)over10k.FieldSchema(name:f, type:float, comment:null), (over10k)over10k.FieldSchema(name:d, type:double, comment:null), (over10k)over10k.FieldSchema(name:bo, type:boolean, comment:null), (over10k)over10k.FieldSchema(name:s, type:string, comment:null), (over10k)over10k.FieldSchema(name:ts, type:timestamp, comment:null), (over10k)over10k.FieldSchema(name:dec, type:decimal, comment:null), (over10k)over10k.FieldSchema(name:bin, type:binary, comment:null), (over10k)over10k.FieldSchema(name:BLOCK__OFFSET__INSIDE__FILE, type:bigint, comment:), (over10k)over10k.FieldSchema(name:INPUT__FILE__NAME, type:string, comment:), ]
+POSTHOOK: Lineage: t2.a1 SCRIPT [(over10k)over10k.FieldSchema(name:t, type:tinyint, comment:null), (over10k)over10k.FieldSchema(name:si, type:smallint, comment:null), (over10k)over10k.FieldSchema(name:i, type:int, comment:null), (over10k)over10k.FieldSchema(name:b, type:bigint, comment:null), (over10k)over10k.FieldSchema(name:f, type:float, comment:null), (over10k)over10k.FieldSchema(name:d, type:double, comment:null), (over10k)over10k.FieldSchema(name:bo, type:boolean, comment:null), (over10k)over10k.FieldSchema(name:s, type:string, comment:null), (over10k)over10k.FieldSchema(name:ts, type:timestamp, comment:null), (over10k)over10k.FieldSchema(name:dec, type:decimal, comment:null), (over10k)over10k.FieldSchema(name:bin, type:binary, comment:null), (over10k)over10k.FieldSchema(name:BLOCK__OFFSET__INSIDE__FILE, type:bigint, comment:), (over10k)over10k.FieldSchema(name:INPUT__FILE__NAME, type:string, comment:), ]
+POSTHOOK: Lineage: t2.b1 SCRIPT [(over10k)over10k.FieldSchema(name:t, type:tinyint, comment:null), (over10k)over10k.FieldSchema(name:si, type:smallint, comment:null), (over10k)over10k.FieldSchema(name:i, type:int, comment:null), (over10k)over10k.FieldSchema(name:b, type:bigint, comment:null), (over10k)over10k.FieldSchema(name:f, type:float, comment:null), (over10k)over10k.FieldSchema(name:d, type:double, comment:null), (over10k)over10k.FieldSchema(name:bo, type:boolean, comment:null), (over10k)over10k.FieldSchema(name:s, type:string, comment:null), (over10k)over10k.FieldSchema(name:ts, type:timestamp, comment:null), (over10k)over10k.FieldSchema(name:dec, type:decimal, comment:null), (over10k)over10k.FieldSchema(name:bin, type:binary, comment:null), (over10k)over10k.FieldSchema(name:BLOCK__OFFSET__INSIDE__FILE, type:bigint, comment:), (over10k)over10k.FieldSchema(name:INPUT__FILE__NAME, type:string, comment:), ]
+Manufacturer#1	1173.15	2	true
+Manufacturer#1	1173.15	2	true
+Manufacturer#1	1414.42	28	true
+Manufacturer#1	1602.59	6	true
+Manufacturer#1	1632.66	42	true
+Manufacturer#1	1753.76	34	true
+Manufacturer#2	1690.68	14	true
+Manufacturer#2	1698.66	25	true
+Manufacturer#2	1701.6	18	true
+Manufacturer#2	1800.7	40	true
+Manufacturer#2	2031.98	2	true