You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/07/05 12:49:16 UTC
spark git commit: [SPARK-24361][SQL] Polish code block manipulation
API
Repository: spark
Updated Branches:
refs/heads/master 4be9f0c02 -> 32cfd3e75
[SPARK-24361][SQL] Polish code block manipulation API
## What changes were proposed in this pull request?
Current code block manipulation API is immature and hacky. We need a formal API to manipulate code blocks.
The basic idea is making `JavaCode` as `TreeNode`. So we can use familiar `transform` API to manipulate code blocks and expressions in code blocks.
For example, we can replace `SimpleExprValue` in a code block like this:
```scala
code.transformExprValues {
case SimpleExprValue("1 + 1", _) => aliasedParam
}
```
The example use case is splitting code to methods.
For example, we have an `ExprCode` containing generated code, but it is too long and we need to split it into a method. Because statement-based expressions can't be directly passed into a method, we need to transform them into variables first:
```scala
def getExprValues(block: Block): Set[ExprValue] = block match {
case c: CodeBlock =>
c.blockInputs.collect {
case e: ExprValue => e
}.toSet
case _ => Set.empty
}
def currentCodegenInputs(ctx: CodegenContext): Set[ExprValue] = {
// Collects current variables in ctx.currentVars and ctx.INPUT_ROW.
// It looks roughly like...
ctx.currentVars.flatMap { v =>
getExprValues(v.code) ++ Set(v.value, v.isNull)
}.toSet + ctx.INPUT_ROW
}
// If the code block of an expression is too long, split it out into a method
if (eval.code.length > 1024) {
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
...
} else {
""
}
// Pick up variables and statements necessary to pass in.
val currentVars = currentCodegenInputs(ctx)
val varsPassIn = getExprValues(eval.code).intersect(currentVars)
val aliasedExprs = HashMap.empty[SimpleExprValue, VariableValue]
// Replace statement-based expressions which can't be directly passed in the method.
val newCode = eval.code.transform {
case block =>
block.transformExprValues {
case s @ SimpleExprValue(_, javaType) if varsPassIn.contains(s) =>
if (aliasedExprs.contains(s)) {
aliasedExprs(s)
} else {
val aliasedVariable = JavaCode.variable(ctx.freshName("aliasedVar"), javaType)
aliasedExprs += s -> aliasedVariable
varsPassIn += aliasedVariable
aliasedVariable
}
}
}
val params = varsPassIn.filter(!_.isInstanceOf[SimpleExprValue]).map { variable =>
s"${variable.javaType.getName} ${variable.variableName}"
}.mkString(", ")
val funcName = ctx.freshName("nodeName")
val javaType = CodeGenerator.javaType(dataType)
val newValue = JavaCode.variable(ctx.freshName("value"), dataType)
val funcFullName = ctx.addNewFunction(funcName,
s"""
|private $javaType $funcName($params) {
| $newCode
| $setIsNull
| return ${eval.value};
|}
""".stripMargin)
eval.value = newValue
val args = varsPassIn.filter(!_.isInstanceOf[SimpleExprValue]).map { variable =>
s"${variable.variableName}"
}
// Create a code block to assign statements to aliased variables.
val createVariables = aliasedExprs.foldLeft(EmptyBlock: Block) { case (block, (statement, variable)) =>
block + code"${statement.javaType.getName} $variable = $statement;"
}
eval.code = createVariables + code"$javaType $newValue = $funcFullName($args);"
}
```
## How was this patch tested?
Added unit tests.
Author: Liang-Chi Hsieh <vi...@gmail.com>
Closes #21405 from viirya/codeblock-api.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/32cfd3e7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/32cfd3e7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/32cfd3e7
Branch: refs/heads/master
Commit: 32cfd3e75a5ca65696fedfa4d49681e6fc3e698d
Parents: 4be9f0c
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Thu Jul 5 20:48:55 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Jul 5 20:48:55 2018 +0800
----------------------------------------------------------------------
.../catalyst/expressions/codegen/javaCode.scala | 48 +++++++++----
.../expressions/codegen/CodeBlockSuite.scala | 75 ++++++++++++++++++--
2 files changed, 104 insertions(+), 19 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/32cfd3e7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
index 44f63e2..2f8c853 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
@@ -22,6 +22,7 @@ import java.lang.{Boolean => JBool}
import scala.collection.mutable.ArrayBuffer
import scala.language.{existentials, implicitConversions}
+import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.{BooleanType, DataType}
/**
@@ -118,12 +119,9 @@ object JavaCode {
/**
* A trait representing a block of java code.
*/
-trait Block extends JavaCode {
+trait Block extends TreeNode[Block] with JavaCode {
import Block._
- // The expressions to be evaluated inside this block.
- def exprValues: Set[ExprValue]
-
// Returns java code string for this code block.
override def toString: String = _marginChar match {
case Some(c) => code.stripMargin(c).trim
@@ -148,11 +146,41 @@ trait Block extends JavaCode {
this
}
+ /**
+ * Apply a map function to each java expression codes present in this java code, and return a new
+ * java code based on the mapped java expression codes.
+ */
+ def transformExprValues(f: PartialFunction[ExprValue, ExprValue]): this.type = {
+ var changed = false
+
+ @inline def transform(e: ExprValue): ExprValue = {
+ val newE = f lift e
+ if (!newE.isDefined || newE.get.equals(e)) {
+ e
+ } else {
+ changed = true
+ newE.get
+ }
+ }
+
+ def doTransform(arg: Any): AnyRef = arg match {
+ case e: ExprValue => transform(e)
+ case Some(value) => Some(doTransform(value))
+ case seq: Traversable[_] => seq.map(doTransform)
+ case other: AnyRef => other
+ }
+
+ val newArgs = mapProductIterator(doTransform)
+ if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this
+ }
+
// Concatenates this block with other block.
def + (other: Block): Block = other match {
case EmptyBlock => this
case _ => code"$this\n$other"
}
+
+ override def verboseString: String = toString
}
object Block {
@@ -219,12 +247,8 @@ object Block {
* method splitting.
*/
case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block {
- override lazy val exprValues: Set[ExprValue] = {
- blockInputs.flatMap {
- case b: Block => b.exprValues
- case e: ExprValue => Set(e)
- }.toSet
- }
+ override def children: Seq[Block] =
+ blockInputs.filter(_.isInstanceOf[Block]).asInstanceOf[Seq[Block]]
override lazy val code: String = {
val strings = codeParts.iterator
@@ -239,9 +263,9 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends
}
}
-object EmptyBlock extends Block with Serializable {
+case object EmptyBlock extends Block with Serializable {
override val code: String = ""
- override val exprValues: Set[ExprValue] = Set.empty
+ override def children: Seq[Block] = Seq.empty
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/32cfd3e7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
index d2c6420..55569b6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
@@ -65,7 +65,9 @@ class CodeBlockSuite extends SparkFunSuite {
|boolean $isNull = false;
|int $value = -1;
""".stripMargin
- val exprValues = code.exprValues
+ val exprValues = code.asInstanceOf[CodeBlock].blockInputs.collect {
+ case e: ExprValue => e
+ }.toSet
assert(exprValues.size == 2)
assert(exprValues === Set(value, isNull))
}
@@ -94,7 +96,9 @@ class CodeBlockSuite extends SparkFunSuite {
assert(code.toString == expected)
- val exprValues = code.exprValues
+ val exprValues = code.children.flatMap(_.asInstanceOf[CodeBlock].blockInputs.collect {
+ case e: ExprValue => e
+ }).toSet
assert(exprValues.size == 5)
assert(exprValues === Set(isNull1, value1, isNull2, value2, literal))
}
@@ -107,7 +111,7 @@ class CodeBlockSuite extends SparkFunSuite {
assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}"))
}
- test("replace expr values in code block") {
+ test("transform expr in code block") {
val expr = JavaCode.expression("1 + 1", IntegerType)
val isNull = JavaCode.isNullVariable("expr1_isNull")
val exprInFunc = JavaCode.variable("expr1", IntegerType)
@@ -120,11 +124,11 @@ class CodeBlockSuite extends SparkFunSuite {
|}""".stripMargin
val aliasedParam = JavaCode.variable("aliased", expr.javaType)
- val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map {
- case _: SimpleExprValue => aliasedParam
- case other => other
+
+ // We want to replace all occurrences of `expr` with the variable `aliasedParam`.
+ val aliasedCode = code.transformExprValues {
+ case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam
}
- val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin
val expected =
code"""
|callFunc(int $aliasedParam) {
@@ -133,4 +137,61 @@ class CodeBlockSuite extends SparkFunSuite {
|}""".stripMargin
assert(aliasedCode.toString == expected.toString)
}
+
+ test ("transform expr in nested blocks") {
+ val expr = JavaCode.expression("1 + 1", IntegerType)
+ val isNull = JavaCode.isNullVariable("expr1_isNull")
+ val exprInFunc = JavaCode.variable("expr1", IntegerType)
+
+ val funcs = Seq("callFunc1", "callFunc2", "callFunc3")
+ val subBlocks = funcs.map { funcName =>
+ code"""
+ |$funcName(int $expr) {
+ | boolean $isNull = false;
+ | int $exprInFunc = $expr + 1;
+ |}""".stripMargin
+ }
+
+ val aliasedParam = JavaCode.variable("aliased", expr.javaType)
+
+ val block = code"${subBlocks(0)}\n${subBlocks(1)}\n${subBlocks(2)}"
+ val transformedBlock = block.transform {
+ case b: Block => b.transformExprValues {
+ case SimpleExprValue("1 + 1", java.lang.Integer.TYPE) => aliasedParam
+ }
+ }.asInstanceOf[CodeBlock]
+
+ val expected1 =
+ code"""
+ |callFunc1(int aliased) {
+ | boolean expr1_isNull = false;
+ | int expr1 = aliased + 1;
+ |}""".stripMargin
+
+ val expected2 =
+ code"""
+ |callFunc2(int aliased) {
+ | boolean expr1_isNull = false;
+ | int expr1 = aliased + 1;
+ |}""".stripMargin
+
+ val expected3 =
+ code"""
+ |callFunc3(int aliased) {
+ | boolean expr1_isNull = false;
+ | int expr1 = aliased + 1;
+ |}""".stripMargin
+
+ val exprValues = transformedBlock.children.flatMap { block =>
+ block.asInstanceOf[CodeBlock].blockInputs.collect {
+ case e: ExprValue => e
+ }
+ }.toSet
+
+ assert(transformedBlock.children(0).toString == expected1.toString)
+ assert(transformedBlock.children(1).toString == expected2.toString)
+ assert(transformedBlock.children(2).toString == expected3.toString)
+ assert(transformedBlock.toString == (expected1 + expected2 + expected3).toString)
+ assert(exprValues === Set(isNull, exprInFunc, aliasedParam))
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org