You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/01/16 04:07:49 UTC

spark git commit: [SPARK-12840] [SQL] Support passing arbitrary objects (not just expressions) into code generated classes

Repository: spark
Updated Branches:
  refs/heads/master 9039333c0 -> 242efb754


[SPARK-12840] [SQL] Support passing arbitrary objects (not just expressions) into code generated classes

This is a refactor to support codegen for aggregation and broadcast join.

Author: Davies Liu <da...@databricks.com>

Closes #10777 from davies/rename2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/242efb75
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/242efb75
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/242efb75

Branch: refs/heads/master
Commit: 242efb7546084592a5e8122549a27117977303fb
Parents: 9039333
Author: Davies Liu <da...@databricks.com>
Authored: Fri Jan 15 19:07:42 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Fri Jan 15 19:07:42 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/ScalaUDF.scala |  6 +++---
 .../expressions/codegen/CodeGenerator.scala       | 18 ++++++++----------
 .../expressions/codegen/CodegenFallback.scala     |  5 +++--
 .../codegen/GenerateMutableProjection.scala       | 14 +++++++-------
 .../expressions/codegen/GenerateOrdering.scala    | 10 +++++-----
 .../expressions/codegen/GeneratePredicate.scala   | 10 +++++-----
 .../codegen/GenerateSafeProjection.scala          | 12 ++++++------
 .../codegen/GenerateUnsafeProjection.scala        | 14 +++++++-------
 .../codegen/GenerateUnsafeRowJoiner.scala         |  2 +-
 .../sql/catalyst/expressions/predicates.scala     |  2 +-
 .../columnar/GenerateColumnAccessor.scala         |  4 ++--
 11 files changed, 48 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/242efb75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 4035c9d..6816947 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -985,7 +985,7 @@ case class ScalaUDF(
     ctx.addMutableState(converterClassName, converterTerm,
       s"this.$converterTerm = ($converterClassName)$typeConvertersClassName" +
         s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" +
-          s"expressions[$expressionIdx]).getChildren().apply($index))).dataType());")
+          s"references[$expressionIdx]).getChildren().apply($index))).dataType());")
     converterTerm
   }
 
@@ -1005,7 +1005,7 @@ case class ScalaUDF(
     val catalystConverterTermIdx = ctx.references.size - 1
     ctx.addMutableState(converterClassName, catalystConverterTerm,
       s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" +
-        s".createToCatalystConverter((($scalaUDFClassName)expressions" +
+        s".createToCatalystConverter((($scalaUDFClassName)references" +
           s"[$catalystConverterTermIdx]).dataType());")
 
     val resultTerm = ctx.freshName("result")
@@ -1020,7 +1020,7 @@ case class ScalaUDF(
     val funcTerm = ctx.freshName("udf")
     val funcExpressionIdx = ctx.references.size - 1
     ctx.addMutableState(funcClassName, funcTerm,
-      s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)expressions" +
+      s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)references" +
         s"[$funcExpressionIdx]).userDefinedFunc());")
 
     // codegen for children expressions

http://git-wip-us.apache.org/repos/asf/spark/blob/242efb75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 1c7083b..f3a39a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -45,16 +45,15 @@ import org.apache.spark.util.Utils
 case class ExprCode(var code: String, var isNull: String, var value: String)
 
 /**
- * A context for codegen, which is used to bookkeeping the expressions those are not supported
- * by codegen, then they are evaluated directly. The unsupported expression is appended at the
- * end of `references`, the position of it is kept in the code, used to access and evaluate it.
+ * A context for codegen, tracking a list of objects that could be passed into a generated Java
+ * function.
  */
 class CodegenContext {
 
   /**
-   * Holding all the expressions those do not support codegen, will be evaluated directly.
+   * Holding a list of objects that could be passed into the generated class.
    */
-  val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()
+  val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]()
 
   /**
    * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a
@@ -400,7 +399,7 @@ class CodegenContext {
     // Add each expression tree and compute the common subexpressions.
     expressions.foreach(equivalentExpressions.addExprTree(_))
 
-    // Get all the exprs that appear at least twice and set up the state for subexpression
+    // Get all the expressions that appear at least twice and set up the state for subexpression
     // elimination.
     val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
     commonExprs.foreach(e => {
@@ -465,7 +464,7 @@ class CodegenContext {
  * into generated class.
  */
 abstract class GeneratedClass {
-  def generate(expressions: Array[Expression]): Any
+  def generate(references: Array[Any]): Any
 }
 
 /**
@@ -475,8 +474,6 @@ abstract class GeneratedClass {
  */
 abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
 
-  protected val exprType: String = classOf[Expression].getName
-  protected val mutableRowType: String = classOf[MutableRow].getName
   protected val genericMutableRowType: String = classOf[GenericMutableRow].getName
 
   protected def declareMutableStates(ctx: CodegenContext): String = {
@@ -534,7 +531,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
       classOf[UnsafeArrayData].getName,
       classOf[MapData].getName,
       classOf[UnsafeMapData].getName,
-      classOf[MutableRow].getName
+      classOf[MutableRow].getName,
+      classOf[Expression].getName
     ))
     evaluator.setExtendedClass(classOf[GeneratedClass])
 

http://git-wip-us.apache.org/repos/asf/spark/blob/242efb75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index c98b735..cface21 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -30,12 +30,13 @@ trait CodegenFallback extends Expression {
       case _ =>
     }
 
+    val idx = ctx.references.length
     ctx.references += this
     val objectTerm = ctx.freshName("obj")
     if (nullable) {
       s"""
         /* expression: ${this.toCommentSafeString} */
-        Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW});
+        Object $objectTerm = ((Expression) references[$idx]).eval(${ctx.INPUT_ROW});
         boolean ${ev.isNull} = $objectTerm == null;
         ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
         if (!${ev.isNull}) {
@@ -46,7 +47,7 @@ trait CodegenFallback extends Expression {
       ev.isNull = "false"
       s"""
         /* expression: ${this.toCommentSafeString} */
-        Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW});
+        Object $objectTerm = ((Expression) references[$idx]).eval(${ctx.INPUT_ROW});
         ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
       """
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/242efb75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index a6ec242..63d13a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -99,24 +99,24 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
     val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates)
 
     val code = s"""
-      public java.lang.Object generate($exprType[] expr) {
-        return new SpecificMutableProjection(expr);
+      public java.lang.Object generate(Object[] references) {
+        return new SpecificMutableProjection(references);
       }
 
       class SpecificMutableProjection extends ${classOf[BaseMutableProjection].getName} {
 
-        private $exprType[] expressions;
-        private $mutableRowType mutableRow;
+        private Object[] references;
+        private MutableRow mutableRow;
         ${declareMutableStates(ctx)}
         ${declareAddedFunctions(ctx)}
 
-        public SpecificMutableProjection($exprType[] expr) {
-          expressions = expr;
+        public SpecificMutableProjection(Object[] references) {
+          this.references = references;
           mutableRow = new $genericMutableRowType(${expressions.size});
           ${initMutableStates(ctx)}
         }
 
-        public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) {
+        public ${classOf[BaseMutableProjection].getName} target(MutableRow row) {
           mutableRow = row;
           return this;
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/242efb75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 88bcf5b..e033f62 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -111,18 +111,18 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
     val ctx = newCodeGenContext()
     val comparisons = genComparisons(ctx, ordering)
     val code = s"""
-      public SpecificOrdering generate($exprType[] expr) {
-        return new SpecificOrdering(expr);
+      public SpecificOrdering generate(Object[] references) {
+        return new SpecificOrdering(references);
       }
 
       class SpecificOrdering extends ${classOf[BaseOrdering].getName} {
 
-        private $exprType[] expressions;
+        private Object[] references;
         ${declareMutableStates(ctx)}
         ${declareAddedFunctions(ctx)}
 
-        public SpecificOrdering($exprType[] expr) {
-          expressions = expr;
+        public SpecificOrdering(Object[] references) {
+          this.references = references;
           ${initMutableStates(ctx)}
         }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/242efb75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index 457b4f0..6fbe12f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -41,17 +41,17 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
     val ctx = newCodeGenContext()
     val eval = predicate.gen(ctx)
     val code = s"""
-      public SpecificPredicate generate($exprType[] expr) {
-        return new SpecificPredicate(expr);
+      public SpecificPredicate generate(Object[] references) {
+        return new SpecificPredicate(references);
       }
 
       class SpecificPredicate extends ${classOf[Predicate].getName} {
-        private final $exprType[] expressions;
+        private final Object[] references;
         ${declareMutableStates(ctx)}
         ${declareAddedFunctions(ctx)}
 
-        public SpecificPredicate($exprType[] expr) {
-          expressions = expr;
+        public SpecificPredicate(Object[] references) {
+          this.references = references;
           ${initMutableStates(ctx)}
         }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/242efb75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 8651707..10bd9c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -152,19 +152,19 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
     }
     val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes)
     val code = s"""
-      public java.lang.Object generate($exprType[] expr) {
-        return new SpecificSafeProjection(expr);
+      public java.lang.Object generate(Object[] references) {
+        return new SpecificSafeProjection(references);
       }
 
       class SpecificSafeProjection extends ${classOf[BaseProjection].getName} {
 
-        private $exprType[] expressions;
-        private $mutableRowType mutableRow;
+        private Object[] references;
+        private MutableRow mutableRow;
         ${declareMutableStates(ctx)}
         ${declareAddedFunctions(ctx)}
 
-        public SpecificSafeProjection($exprType[] expr) {
-          expressions = expr;
+        public SpecificSafeProjection(Object[] references) {
+          this.references = references;
           mutableRow = new $genericMutableRowType(${expressions.size});
           ${initMutableStates(ctx)}
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/242efb75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 3a92992..1a0565a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -320,8 +320,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     create(canonicalize(expressions), subexpressionEliminationEnabled)
   }
 
-  protected def create(expressions: Seq[Expression]): UnsafeProjection = {
-    create(expressions, subexpressionEliminationEnabled = false)
+  protected def create(references: Seq[Expression]): UnsafeProjection = {
+    create(references, subexpressionEliminationEnabled = false)
   }
 
   private def create(
@@ -331,20 +331,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     val eval = createCode(ctx, expressions, subexpressionEliminationEnabled)
 
     val code = s"""
-      public java.lang.Object generate($exprType[] exprs) {
-        return new SpecificUnsafeProjection(exprs);
+      public java.lang.Object generate(Object[] references) {
+        return new SpecificUnsafeProjection(references);
       }
 
       class SpecificUnsafeProjection extends ${classOf[UnsafeProjection].getName} {
 
-        private $exprType[] expressions;
+        private Object[] references;
 
         ${declareMutableStates(ctx)}
 
         ${declareAddedFunctions(ctx)}
 
-        public SpecificUnsafeProjection($exprType[] expressions) {
-          this.expressions = expressions;
+        public SpecificUnsafeProjection(Object[] references) {
+          this.references = references;
           ${initMutableStates(ctx)}
         }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/242efb75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index 88b3c5e..8781cc7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -158,7 +158,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
 
     // ------------------------ Finally, put everything together  --------------------------- //
     val code = s"""
-       |public java.lang.Object generate($exprType[] exprs) {
+       |public java.lang.Object generate(Object[] references) {
        |  return new SpecificUnsafeRowJoiner();
        |}
        |

http://git-wip-us.apache.org/repos/asf/spark/blob/242efb75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index a3c10c8..c290aa8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -221,7 +221,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
     val hsetTerm = ctx.freshName("hset")
     val hasNullTerm = ctx.freshName("hasNull")
     ctx.addMutableState(setName, hsetTerm,
-      s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();")
+      s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();")
     ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);")
     s"""
       ${childGen.code}

http://git-wip-us.apache.org/repos/asf/spark/blob/242efb75/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index 55e2c0e..7888e34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -123,7 +123,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
       import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
       import org.apache.spark.sql.execution.columnar.MutableUnsafeRow;
 
-      public SpecificColumnarIterator generate($exprType[] expr) {
+      public SpecificColumnarIterator generate(Object[] references) {
         return new SpecificColumnarIterator();
       }
 
@@ -190,6 +190,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
 
     logDebug(s"Generated ColumnarIterator: ${CodeFormatter.format(code)}")
 
-    compile(code).generate(ctx.references.toArray).asInstanceOf[ColumnarIterator]
+    compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator]
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org