You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/07/31 17:29:24 UTC

spark git commit: [SPARK-9500] add TernaryExpression to simplify ternary expressions

Repository: spark
Updated Branches:
  refs/heads/master a3a85d73d -> 6bba7509a


[SPARK-9500] add TernaryExpression to simplify ternary expressions

There is a lot of duplicated code in the ternary expressions; create a TernaryExpression base class for them to reduce the duplication.

cc chenghao-intel

Author: Davies Liu <da...@databricks.com>

Closes #7816 from davies/ternary and squashes the following commits:

ed2bf76 [Davies Liu] add TernaryExpression


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6bba7509
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6bba7509
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6bba7509

Branch: refs/heads/master
Commit: 6bba7509a932aa4d39266df2d15b1370b7aabbec
Parents: a3a85d7
Author: Davies Liu <da...@databricks.com>
Authored: Fri Jul 31 08:28:05 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Fri Jul 31 08:28:05 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/Expression.scala   |  85 +++++
 .../expressions/codegen/CodeGenerator.scala     |   2 +-
 .../spark/sql/catalyst/expressions/math.scala   |  66 +---
 .../catalyst/expressions/stringOperations.scala | 356 +++++--------------
 4 files changed, 183 insertions(+), 326 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6bba7509/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 8fc1826..2842b3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -432,3 +432,88 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
 private[sql] object BinaryOperator {
   def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right))
 }
+
+/**
+ * An expression with three inputs and one output. The output is by default evaluated to null
+ * if any input is evaluated to null.
+ */
+abstract class TernaryExpression extends Expression {
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  /**
+   * Default evaluation matching the default nullability: null as soon as any child is null.
+   * A subclass of TernaryExpression that overrides nullable should probably override this too.
+   */
+  override def eval(input: InternalRow): Any = {
+    val exprs = children
+    val value1 = exprs(0).eval(input)
+    if (value1 != null) {
+      val value2 = exprs(1).eval(input)
+      if (value2 != null) {
+        val value3 = exprs(2).eval(input)
+        if (value3 != null) {
+          return nullSafeEval(value1, value2, value3)
+        }
+      }
+    }
+    null
+  }
+
+  /**
+   * Called by the default [[eval]] implementation with three non-null inputs. Subclasses that
+   * keep the default nullability can override this method to skip the null checks; subclasses
+   * that need full control of the evaluation process should override [[eval]] instead.
+   */
+  protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any =
+    sys.error("TernaryExpressions must override either eval or nullSafeEval")
+
+  /**
+   * Short hand for generating ternary evaluation code.
+   * If any of the sub-expressions is null, the result of this computation
+   * is assumed to be null.
+   *
+   * @param f accepts three variable names and returns Java code to compute the output.
+   */
+  protected def defineCodeGen(
+    ctx: CodeGenContext,
+    ev: GeneratedExpressionCode,
+    f: (String, String, String) => String): String = {
+    nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => {
+      s"${ev.primitive} = ${f(eval1, eval2, eval3)};"
+    })
+  }
+
+  /**
+   * Short hand for generating ternary evaluation code.
+   * If any of the sub-expressions is null, the result of this computation
+   * is assumed to be null.
+   *
+   * @param f function that accepts the 3 non-null evaluation result names of children
+   *          and returns Java code to compute the output.
+   */
+  protected def nullSafeCodeGen(
+    ctx: CodeGenContext,
+    ev: GeneratedExpressionCode,
+    f: (String, String, String) => String): String = {
+    val evals = children.map(_.gen(ctx))
+    val resultCode = f(evals(0).primitive, evals(1).primitive, evals(2).primitive)
+    s"""
+      ${evals(0).code}
+      boolean ${ev.isNull} = true;
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      if (!${evals(0).isNull}) {
+        ${evals(1).code}
+        if (!${evals(1).isNull}) {
+          ${evals(2).code}
+          if (!${evals(2).isNull}) {
+            ${ev.isNull} = false;  // resultCode could change nullability
+            $resultCode
+          }
+        }
+      }
+    """
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6bba7509/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 60e2863..e50ec27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -305,7 +305,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
       evaluator.cook(code)
     } catch {
       case e: Exception =>
-        val msg = "failed to compile:\n " + CodeFormatter.format(code)
+        val msg = s"failed to compile: $e\n" + CodeFormatter.format(code)
         logError(msg, e)
         throw new Exception(msg, e)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/6bba7509/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index e6d807f..15ceb91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -165,69 +165,29 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH"
  * @param toBaseExpr to which base
  */
 case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
-  extends Expression with ImplicitCastInputTypes {
-
-  override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable
-
-  override def nullable: Boolean = numExpr.nullable || fromBaseExpr.nullable || toBaseExpr.nullable
+  extends TernaryExpression with ImplicitCastInputTypes {
 
   override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr)
-
   override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType)
-
   override def dataType: DataType = StringType
 
-  /** Returns the result of evaluating this expression on a given input Row */
-  override def eval(input: InternalRow): Any = {
-    val num = numExpr.eval(input)
-    if (num != null) {
-      val fromBase = fromBaseExpr.eval(input)
-      if (fromBase != null) {
-        val toBase = toBaseExpr.eval(input)
-        if (toBase != null) {
-          NumberConverter.convert(
-            num.asInstanceOf[UTF8String].getBytes,
-            fromBase.asInstanceOf[Int],
-            toBase.asInstanceOf[Int])
-        } else {
-          null
-        }
-      } else {
-        null
-      }
-    } else {
-      null
-    }
+  override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = {
+    NumberConverter.convert(
+      num.asInstanceOf[UTF8String].getBytes,
+      fromBase.asInstanceOf[Int],
+      toBase.asInstanceOf[Int])
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val numGen = numExpr.gen(ctx)
-    val from = fromBaseExpr.gen(ctx)
-    val to = toBaseExpr.gen(ctx)
-
     val numconv = NumberConverter.getClass.getName.stripSuffix("$")
-    s"""
-       ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
-       ${numGen.code}
-       boolean ${ev.isNull} = ${numGen.isNull};
-       if (!${ev.isNull}) {
-         ${from.code}
-         if (!${from.isNull}) {
-           ${to.code}
-           if (!${to.isNull}) {
-             ${ev.primitive} = $numconv.convert(${numGen.primitive}.getBytes(),
-               ${from.primitive}, ${to.primitive});
-             if (${ev.primitive} == null) {
-               ${ev.isNull} = true;
-             }
-           } else {
-             ${ev.isNull} = true;
-           }
-         } else {
-           ${ev.isNull} = true;
-         }
+    nullSafeCodeGen(ctx, ev, (num, from, to) =>
+      s"""
+       ${ev.primitive} = $numconv.convert($num.getBytes(), $from, $to);
+       if (${ev.primitive} == null) {
+         ${ev.isNull} = true;
        }
-     """
+       """
+    )
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6bba7509/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 99a6234..684eac1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -426,15 +426,13 @@ case class StringInstr(str: Expression, substr: Expression)
  * in given string after position pos.
  */
 case class StringLocate(substr: Expression, str: Expression, start: Expression)
-  extends Expression with ImplicitCastInputTypes with CodegenFallback {
+  extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback {
 
   def this(substr: Expression, str: Expression) = {
     this(substr, str, Literal(0))
   }
 
   override def children: Seq[Expression] = substr :: str :: start :: Nil
-  override def foldable: Boolean = children.forall(_.foldable)
-  override def nullable: Boolean = substr.nullable || str.nullable
   override def dataType: DataType = IntegerType
   override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
 
@@ -467,60 +465,18 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
  * Returns str, left-padded with pad to a length of len.
  */
 case class StringLPad(str: Expression, len: Expression, pad: Expression)
-  extends Expression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes {
 
   override def children: Seq[Expression] = str :: len :: pad :: Nil
-  override def foldable: Boolean = children.forall(_.foldable)
-  override def nullable: Boolean = children.exists(_.nullable)
   override def dataType: DataType = StringType
   override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType)
 
-  override def eval(input: InternalRow): Any = {
-    val s = str.eval(input)
-    if (s == null) {
-      null
-    } else {
-      val l = len.eval(input)
-      if (l == null) {
-        null
-      } else {
-        val p = pad.eval(input)
-        if (p == null) {
-          null
-        } else {
-          val len = l.asInstanceOf[Int]
-          val str = s.asInstanceOf[UTF8String]
-          val pad = p.asInstanceOf[UTF8String]
-
-          str.lpad(len, pad)
-        }
-      }
-    }
+  override def nullSafeEval(str: Any, len: Any, pad: Any): Any = {
+    str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String])
   }
 
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val lenGen = len.gen(ctx)
-    val strGen = str.gen(ctx)
-    val padGen = pad.gen(ctx)
-
-    s"""
-      ${lenGen.code}
-      boolean ${ev.isNull} = ${lenGen.isNull};
-      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        ${strGen.code}
-        if (!${strGen.isNull}) {
-          ${padGen.code}
-          if (!${padGen.isNull}) {
-            ${ev.primitive} = ${strGen.primitive}.lpad(${lenGen.primitive}, ${padGen.primitive});
-          } else {
-            ${ev.isNull} = true;
-          }
-        } else {
-          ${ev.isNull} = true;
-        }
-      }
-     """
+    defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)")
   }
 
   override def prettyName: String = "lpad"
@@ -530,60 +486,18 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression)
  * Returns str, right-padded with pad to a length of len.
  */
 case class StringRPad(str: Expression, len: Expression, pad: Expression)
-  extends Expression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes {
 
   override def children: Seq[Expression] = str :: len :: pad :: Nil
-  override def foldable: Boolean = children.forall(_.foldable)
-  override def nullable: Boolean = children.exists(_.nullable)
   override def dataType: DataType = StringType
   override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType)
 
-  override def eval(input: InternalRow): Any = {
-    val s = str.eval(input)
-    if (s == null) {
-      null
-    } else {
-      val l = len.eval(input)
-      if (l == null) {
-        null
-      } else {
-        val p = pad.eval(input)
-        if (p == null) {
-          null
-        } else {
-          val len = l.asInstanceOf[Int]
-          val str = s.asInstanceOf[UTF8String]
-          val pad = p.asInstanceOf[UTF8String]
-
-          str.rpad(len, pad)
-        }
-      }
-    }
+  override def nullSafeEval(str: Any, len: Any, pad: Any): Any = {
+    str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String])
   }
 
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val lenGen = len.gen(ctx)
-    val strGen = str.gen(ctx)
-    val padGen = pad.gen(ctx)
-
-    s"""
-      ${lenGen.code}
-      boolean ${ev.isNull} = ${lenGen.isNull};
-      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        ${strGen.code}
-        if (!${strGen.isNull}) {
-          ${padGen.code}
-          if (!${padGen.isNull}) {
-            ${ev.primitive} = ${strGen.primitive}.rpad(${lenGen.primitive}, ${padGen.primitive});
-          } else {
-            ${ev.isNull} = true;
-          }
-        } else {
-          ${ev.isNull} = true;
-        }
-      }
-     """
+    defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)")
   }
 
   override def prettyName: String = "rpad"
@@ -745,68 +659,24 @@ case class StringSplit(str: Expression, pattern: Expression)
  * Defined for String and Binary types.
  */
 case class Substring(str: Expression, pos: Expression, len: Expression)
-  extends Expression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes {
 
   def this(str: Expression, pos: Expression) = {
     this(str, pos, Literal(Integer.MAX_VALUE))
   }
 
-  override def foldable: Boolean = str.foldable && pos.foldable && len.foldable
-  override def nullable: Boolean = str.nullable || pos.nullable || len.nullable
-
   override def dataType: DataType = StringType
 
   override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
 
   override def children: Seq[Expression] = str :: pos :: len :: Nil
 
-  override def eval(input: InternalRow): Any = {
-    val stringEval = str.eval(input)
-    if (stringEval != null) {
-      val posEval = pos.eval(input)
-      if (posEval != null) {
-        val lenEval = len.eval(input)
-        if (lenEval != null) {
-          stringEval.asInstanceOf[UTF8String]
-            .substringSQL(posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int])
-        } else {
-          null
-        }
-      } else {
-        null
-      }
-    } else {
-      null
-    }
+  override def nullSafeEval(string: Any, pos: Any, len: Any): Any = {
+    string.asInstanceOf[UTF8String].substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int])
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val strGen = str.gen(ctx)
-    val posGen = pos.gen(ctx)
-    val lenGen = len.gen(ctx)
-
-    val start = ctx.freshName("start")
-    val end = ctx.freshName("end")
-
-    s"""
-      ${strGen.code}
-      boolean ${ev.isNull} = ${strGen.isNull};
-      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        ${posGen.code}
-        if (!${posGen.isNull}) {
-          ${lenGen.code}
-          if (!${lenGen.isNull}) {
-            ${ev.primitive} = ${strGen.primitive}
-              .substringSQL(${posGen.primitive}, ${lenGen.primitive});
-          } else {
-            ${ev.isNull} = true;
-          }
-        } else {
-          ${ev.isNull} = true;
-        }
-      }
-     """
+    defineCodeGen(ctx, ev, (str, pos, len) => s"$str.substringSQL($pos, $len)")
   }
 }
 
@@ -986,7 +856,7 @@ case class Encode(value: Expression, charset: Expression)
  * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
  */
 case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression)
-  extends Expression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes {
 
   // last regex in string, we will update the pattern iff regexp value changed.
   @transient private var lastRegex: UTF8String = _
@@ -998,40 +868,26 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
   // result buffer write by Matcher
   @transient private val result: StringBuffer = new StringBuffer
 
-  override def nullable: Boolean = subject.nullable || regexp.nullable || rep.nullable
-  override def foldable: Boolean = subject.foldable && regexp.foldable && rep.foldable
-
-  override def eval(input: InternalRow): Any = {
-    val s = subject.eval(input)
-    if (null != s) {
-      val p = regexp.eval(input)
-      if (null != p) {
-        val r = rep.eval(input)
-        if (null != r) {
-          if (!p.equals(lastRegex)) {
-            // regex value changed
-            lastRegex = p.asInstanceOf[UTF8String]
-            pattern = Pattern.compile(lastRegex.toString)
-          }
-          if (!r.equals(lastReplacementInUTF8)) {
-            // replacement string changed
-            lastReplacementInUTF8 = r.asInstanceOf[UTF8String]
-            lastReplacement = lastReplacementInUTF8.toString
-          }
-          val m = pattern.matcher(s.toString())
-          result.delete(0, result.length())
-
-          while (m.find) {
-            m.appendReplacement(result, lastReplacement)
-          }
-          m.appendTail(result)
+  override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
+    if (!p.equals(lastRegex)) {
+      // regex value changed
+      lastRegex = p.asInstanceOf[UTF8String]
+      pattern = Pattern.compile(lastRegex.toString)
+    }
+    if (!r.equals(lastReplacementInUTF8)) {
+      // replacement string changed
+      lastReplacementInUTF8 = r.asInstanceOf[UTF8String]
+      lastReplacement = lastReplacementInUTF8.toString
+    }
+    val m = pattern.matcher(s.toString())
+    result.delete(0, result.length())
 
-          return UTF8String.fromString(result.toString)
-        }
-      }
+    while (m.find) {
+      m.appendReplacement(result, lastReplacement)
     }
+    m.appendTail(result)
 
-    null
+    UTF8String.fromString(result.toString)
   }
 
   override def dataType: DataType = StringType
@@ -1048,59 +904,43 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
 
     val termResult = ctx.freshName("result")
 
-    val classNameUTF8String = classOf[UTF8String].getCanonicalName
     val classNamePattern = classOf[Pattern].getCanonicalName
-    val classNameString = classOf[java.lang.String].getCanonicalName
     val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName
 
-    ctx.addMutableState(classNameUTF8String,
+    ctx.addMutableState("UTF8String",
       termLastRegex, s"${termLastRegex} = null;")
     ctx.addMutableState(classNamePattern,
       termPattern, s"${termPattern} = null;")
-    ctx.addMutableState(classNameString,
+    ctx.addMutableState("String",
       termLastReplacement, s"${termLastReplacement} = null;")
-    ctx.addMutableState(classNameUTF8String,
+    ctx.addMutableState("UTF8String",
       termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;")
     ctx.addMutableState(classNameStringBuffer,
       termResult, s"${termResult} = new $classNameStringBuffer();")
 
-    val evalSubject = subject.gen(ctx)
-    val evalRegexp = regexp.gen(ctx)
-    val evalRep = rep.gen(ctx)
-
+    nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => {
     s"""
-      ${evalSubject.code}
-      boolean ${ev.isNull} = true;
-      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
-      if (!${evalSubject.isNull}) {
-        ${evalRegexp.code}
-        if (!${evalRegexp.isNull}) {
-          ${evalRep.code}
-          if (!${evalRep.isNull}) {
-            if (!${evalRegexp.primitive}.equals(${termLastRegex})) {
-              // regex value changed
-              ${termLastRegex} = ${evalRegexp.primitive};
-              ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
-            }
-            if (!${evalRep.primitive}.equals(${termLastReplacementInUTF8})) {
-              // replacement string changed
-              ${termLastReplacementInUTF8} = ${evalRep.primitive};
-              ${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
-            }
-            ${termResult}.delete(0, ${termResult}.length());
-            ${classOf[java.util.regex.Matcher].getCanonicalName} m =
-                                   ${termPattern}.matcher(${evalSubject.primitive}.toString());
+      if (!$regexp.equals(${termLastRegex})) {
+        // regex value changed
+        ${termLastRegex} = $regexp;
+        ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
+      }
+      if (!$rep.equals(${termLastReplacementInUTF8})) {
+        // replacement string changed
+        ${termLastReplacementInUTF8} = $rep;
+        ${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
+      }
+      ${termResult}.delete(0, ${termResult}.length());
+      java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString());
 
-            while (m.find()) {
-              m.appendReplacement(${termResult}, ${termLastReplacement});
-            }
-            m.appendTail(${termResult});
-            ${ev.primitive} = ${classNameUTF8String}.fromString(${termResult}.toString());
-            ${ev.isNull} = false;
-          }
-        }
+      while (m.find()) {
+        m.appendReplacement(${termResult}, ${termLastReplacement});
       }
+      m.appendTail(${termResult});
+      ${ev.primitive} = UTF8String.fromString(${termResult}.toString());
+      ${ev.isNull} = false;
     """
+    })
   }
 }
 
@@ -1110,7 +950,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
  * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
  */
 case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
-  extends Expression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes {
   def this(s: Expression, r: Expression) = this(s, r, Literal(1))
 
   // last regex in string, we will update the pattern iff regexp value changed.
@@ -1118,32 +958,19 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
   // last regex pattern, we cache it for performance concern
   @transient private var pattern: Pattern = _
 
-  override def nullable: Boolean = subject.nullable || regexp.nullable || idx.nullable
-  override def foldable: Boolean = subject.foldable && regexp.foldable && idx.foldable
-
-  override def eval(input: InternalRow): Any = {
-    val s = subject.eval(input)
-    if (null != s) {
-      val p = regexp.eval(input)
-      if (null != p) {
-        val r = idx.eval(input)
-        if (null != r) {
-          if (!p.equals(lastRegex)) {
-            // regex value changed
-            lastRegex = p.asInstanceOf[UTF8String]
-            pattern = Pattern.compile(lastRegex.toString)
-          }
-          val m = pattern.matcher(s.toString())
-          if (m.find) {
-            val mr: MatchResult = m.toMatchResult
-            return UTF8String.fromString(mr.group(r.asInstanceOf[Int]))
-          }
-          return UTF8String.EMPTY_UTF8
-        }
-      }
+  override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
+    if (!p.equals(lastRegex)) {
+      // regex value changed
+      lastRegex = p.asInstanceOf[UTF8String]
+      pattern = Pattern.compile(lastRegex.toString)
+    }
+    val m = pattern.matcher(s.toString())
+    if (m.find) {
+      val mr: MatchResult = m.toMatchResult
+      UTF8String.fromString(mr.group(r.asInstanceOf[Int]))
+    } else {
+      UTF8String.EMPTY_UTF8
     }
-
-    null
   }
 
   override def dataType: DataType = StringType
@@ -1154,44 +981,29 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val termLastRegex = ctx.freshName("lastRegex")
     val termPattern = ctx.freshName("pattern")
-    val classNameUTF8String = classOf[UTF8String].getCanonicalName
     val classNamePattern = classOf[Pattern].getCanonicalName
 
-    ctx.addMutableState(classNameUTF8String, termLastRegex, s"${termLastRegex} = null;")
+    ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;")
     ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")
 
-    val evalSubject = subject.gen(ctx)
-    val evalRegexp = regexp.gen(ctx)
-    val evalIdx = idx.gen(ctx)
-
-    s"""
-      ${evalSubject.code}
-      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
-      boolean ${ev.isNull} = true;
-      if (!${evalSubject.isNull}) {
-        ${evalRegexp.code}
-        if (!${evalRegexp.isNull}) {
-          ${evalIdx.code}
-          if (!${evalIdx.isNull}) {
-            if (!${evalRegexp.primitive}.equals(${termLastRegex})) {
-              // regex value changed
-              ${termLastRegex} = ${evalRegexp.primitive};
-              ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
-            }
-            ${classOf[java.util.regex.Matcher].getCanonicalName} m =
-              ${termPattern}.matcher(${evalSubject.primitive}.toString());
-            if (m.find()) {
-              ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult();
-              ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive}));
-              ${ev.isNull} = false;
-            } else {
-              ${ev.primitive} = ${classNameUTF8String}.EMPTY_UTF8;
-              ${ev.isNull} = false;
-            }
-          }
-        }
+    nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
+      s"""
+      if (!$regexp.equals(${termLastRegex})) {
+        // regex value changed
+        ${termLastRegex} = $regexp;
+        ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
       }
-    """
+      java.util.regex.Matcher m =
+        ${termPattern}.matcher($subject.toString());
+      if (m.find()) {
+        java.util.regex.MatchResult mr = m.toMatchResult();
+        ${ev.primitive} = UTF8String.fromString(mr.group($idx));
+        ${ev.isNull} = false;
+      } else {
+        ${ev.primitive} = UTF8String.EMPTY_UTF8;
+        ${ev.isNull} = false;
+      }"""
+    })
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org