You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/07/05 12:49:16 UTC

spark git commit: [SPARK-24361][SQL] Polish code block manipulation API

Repository: spark
Updated Branches:
  refs/heads/master 4be9f0c02 -> 32cfd3e75


[SPARK-24361][SQL] Polish code block manipulation API

## What changes were proposed in this pull request?

Current code block manipulation API is immature and hacky. We need a formal API to manipulate code blocks.

The basic idea is making `JavaCode`  as `TreeNode`. So we can use familiar `transform` API to manipulate code blocks and expressions in code blocks.

For example, we can replace `SimpleExprValue` in a code block like this:

```scala
code.transformExprValues {
  case SimpleExprValue("1 + 1", _) => aliasedParam
}
```

The example use case is splitting code to methods.

For example, we have an `ExprCode` containing generated code. But it is too long and we need to split it into a method. Because statement-based expressions can't be directly passed into a method as parameters, we need to transform them into variables first:

```scala

def getExprValues(block: Block): Set[ExprValue] = block match {
  case c: CodeBlock =>
    c.blockInputs.collect {
      case e: ExprValue => e
    }.toSet
  case _ => Set.empty
}

def currentCodegenInputs(ctx: CodegenContext): Set[ExprValue] = {
  // Collects current variables in ctx.currentVars and ctx.INPUT_ROW.
  // It looks roughly like...
  ctx.currentVars.flatMap { v =>
    getExprValues(v.code) ++ Set(v.value, v.isNull)
  }.toSet + ctx.INPUT_ROW
}

// A code block of an expression contains too long code, making it as method
if (eval.code.length > 1024) {
  val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
    ...
  } else {
    ""
  }

  // Pick up variables and statements necessary to pass in.
  val currentVars = currentCodegenInputs(ctx)
  val varsPassIn = getExprValues(eval.code).intersect(currentVars)
  val aliasedExprs = HashMap.empty[SimpleExprValue, VariableValue]

  // Replace statement-based expressions which can't be directly passed in the method.
  val newCode = eval.code.transform {
    case block =>
      block.transformExprValues {
        case s @ SimpleExprValue(_, javaType) if varsPassIn.contains(s) =>
          if (aliasedExprs.contains(s)) {
            aliasedExprs(s)
          } else {
            val aliasedVariable = JavaCode.variable(ctx.freshName("aliasedVar"), javaType)
            aliasedExprs += s -> aliasedVariable
            varsPassIn += aliasedVariable
            aliasedVariable
          }
      }
  }

  val params = varsPassIn.filter(!_.isInstanceOf[SimpleExprValue]).map { variable =>
    s"${variable.javaType.getName} ${variable.variableName}"
  }.mkString(", ")

  val funcName = ctx.freshName("nodeName")
  val javaType = CodeGenerator.javaType(dataType)
  val newValue = JavaCode.variable(ctx.freshName("value"), dataType)
  val funcFullName = ctx.addNewFunction(funcName,
    s"""
      |private $javaType $funcName($params) {
      |  $newCode
      |  $setIsNull
      |  return ${eval.value};
      |}
    """.stripMargin))

  eval.value = newValue
  val args = varsPassIn.filter(!_.isInstanceOf[SimpleExprValue]).map { variable =>
    s"${variable.variableName}"
  }

  // Create a code block to assign statements to aliased variables.
  val createVariables = aliasedExprs.foldLeft(EmptyBlock) { case (block, (statement, variable)) =>
    block + code"${statement.javaType.getName} $variable = $statement;"
  }
  eval.code = createVariables + code"$javaType $newValue = $funcFullName($args);"
}
```

## How was this patch tested?

Added unit tests.

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #21405 from viirya/codeblock-api.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/32cfd3e7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/32cfd3e7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/32cfd3e7

Branch: refs/heads/master
Commit: 32cfd3e75a5ca65696fedfa4d49681e6fc3e698d
Parents: 4be9f0c
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Thu Jul 5 20:48:55 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Jul 5 20:48:55 2018 +0800

----------------------------------------------------------------------
 .../catalyst/expressions/codegen/javaCode.scala | 48 +++++++++----
 .../expressions/codegen/CodeBlockSuite.scala    | 75 ++++++++++++++++++--
 2 files changed, 104 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/32cfd3e7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
index 44f63e2..2f8c853 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
@@ -22,6 +22,7 @@ import java.lang.{Boolean => JBool}
 import scala.collection.mutable.ArrayBuffer
 import scala.language.{existentials, implicitConversions}
 
+import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types.{BooleanType, DataType}
 
 /**
@@ -118,12 +119,9 @@ object JavaCode {
 /**
  * A trait representing a block of java code.
  */
-trait Block extends JavaCode {
+trait Block extends TreeNode[Block] with JavaCode {
   import Block._
 
-  // The expressions to be evaluated inside this block.
-  def exprValues: Set[ExprValue]
-
   // Returns java code string for this code block.
   override def toString: String = _marginChar match {
     case Some(c) => code.stripMargin(c).trim
@@ -148,11 +146,41 @@ trait Block extends JavaCode {
     this
   }
 
+  /**
+   * Apply a map function to each java expression codes present in this java code, and return a new
+   * java code based on the mapped java expression codes.
+   */
+  def transformExprValues(f: PartialFunction[ExprValue, ExprValue]): this.type = {
+    var changed = false
+
+    @inline def transform(e: ExprValue): ExprValue = {
+      val newE = f lift e
+      if (!newE.isDefined || newE.get.equals(e)) {
+        e
+      } else {
+        changed = true
+        newE.get
+      }
+    }
+
+    def doTransform(arg: Any): AnyRef = arg match {
+      case e: ExprValue => transform(e)
+      case Some(value) => Some(doTransform(value))
+      case seq: Traversable[_] => seq.map(doTransform)
+      case other: AnyRef => other
+    }
+
+    val newArgs = mapProductIterator(doTransform)
+    if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this
+  }
+
   // Concatenates this block with other block.
   def + (other: Block): Block = other match {
     case EmptyBlock => this
     case _ => code"$this\n$other"
   }
+
+  override def verboseString: String = toString
 }
 
 object Block {
@@ -219,12 +247,8 @@ object Block {
  * method splitting.
  */
 case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block {
-  override lazy val exprValues: Set[ExprValue] = {
-    blockInputs.flatMap {
-      case b: Block => b.exprValues
-      case e: ExprValue => Set(e)
-    }.toSet
-  }
+  override def children: Seq[Block] =
+    blockInputs.filter(_.isInstanceOf[Block]).asInstanceOf[Seq[Block]]
 
   override lazy val code: String = {
     val strings = codeParts.iterator
@@ -239,9 +263,9 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends
   }
 }
 
-object EmptyBlock extends Block with Serializable {
+case object EmptyBlock extends Block with Serializable {
   override val code: String = ""
-  override val exprValues: Set[ExprValue] = Set.empty
+  override def children: Seq[Block] = Seq.empty
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/32cfd3e7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
index d2c6420..55569b6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
@@ -65,7 +65,9 @@ class CodeBlockSuite extends SparkFunSuite {
            |boolean $isNull = false;
            |int $value = -1;
           """.stripMargin
-    val exprValues = code.exprValues
+    val exprValues = code.asInstanceOf[CodeBlock].blockInputs.collect {
+      case e: ExprValue => e
+    }.toSet
     assert(exprValues.size == 2)
     assert(exprValues === Set(value, isNull))
   }
@@ -94,7 +96,9 @@ class CodeBlockSuite extends SparkFunSuite {
 
     assert(code.toString == expected)
 
-    val exprValues = code.exprValues
+    val exprValues = code.children.flatMap(_.asInstanceOf[CodeBlock].blockInputs.collect {
+      case e: ExprValue => e
+    }).toSet
     assert(exprValues.size == 5)
     assert(exprValues === Set(isNull1, value1, isNull2, value2, literal))
   }
@@ -107,7 +111,7 @@ class CodeBlockSuite extends SparkFunSuite {
     assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}"))
   }
 
-  test("replace expr values in code block") {
+  test("transform expr in code block") {
     val expr = JavaCode.expression("1 + 1", IntegerType)
     val isNull = JavaCode.isNullVariable("expr1_isNull")
     val exprInFunc = JavaCode.variable("expr1", IntegerType)
@@ -120,11 +124,11 @@ class CodeBlockSuite extends SparkFunSuite {
            |}""".stripMargin
 
     val aliasedParam = JavaCode.variable("aliased", expr.javaType)
-    val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map {
-      case _: SimpleExprValue => aliasedParam
-      case other => other
+
+    // We want to replace all occurrences of `expr` with the variable `aliasedParam`.
+    val aliasedCode = code.transformExprValues {
+      case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam
     }
-    val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin
     val expected =
       code"""
            |callFunc(int $aliasedParam) {
@@ -133,4 +137,61 @@ class CodeBlockSuite extends SparkFunSuite {
            |}""".stripMargin
     assert(aliasedCode.toString == expected.toString)
   }
+
+  test ("transform expr in nested blocks") {
+    val expr = JavaCode.expression("1 + 1", IntegerType)
+    val isNull = JavaCode.isNullVariable("expr1_isNull")
+    val exprInFunc = JavaCode.variable("expr1", IntegerType)
+
+    val funcs = Seq("callFunc1", "callFunc2", "callFunc3")
+    val subBlocks = funcs.map { funcName =>
+      code"""
+           |$funcName(int $expr) {
+           |  boolean $isNull = false;
+           |  int $exprInFunc = $expr + 1;
+           |}""".stripMargin
+    }
+
+    val aliasedParam = JavaCode.variable("aliased", expr.javaType)
+
+    val block = code"${subBlocks(0)}\n${subBlocks(1)}\n${subBlocks(2)}"
+    val transformedBlock = block.transform {
+      case b: Block => b.transformExprValues {
+        case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam
+      }
+    }.asInstanceOf[CodeBlock]
+
+    val expected1 =
+      code"""
+        |callFunc1(int aliased) {
+        |  boolean expr1_isNull = false;
+        |  int expr1 = aliased + 1;
+        |}""".stripMargin
+
+    val expected2 =
+      code"""
+        |callFunc2(int aliased) {
+        |  boolean expr1_isNull = false;
+        |  int expr1 = aliased + 1;
+        |}""".stripMargin
+
+    val expected3 =
+      code"""
+        |callFunc3(int aliased) {
+        |  boolean expr1_isNull = false;
+        |  int expr1 = aliased + 1;
+        |}""".stripMargin
+
+    val exprValues = transformedBlock.children.flatMap { block =>
+      block.asInstanceOf[CodeBlock].blockInputs.collect {
+        case e: ExprValue => e
+      }
+    }.toSet
+
+    assert(transformedBlock.children(0).toString == expected1.toString)
+    assert(transformedBlock.children(1).toString == expected2.toString)
+    assert(transformedBlock.children(2).toString == expected3.toString)
+    assert(transformedBlock.toString == (expected1 + expected2 + expected3).toString)
+    assert(exprValues === Set(isNull, exprInFunc, aliasedParam))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org