You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2020/12/30 06:06:51 UTC

[spark] branch master updated: [SPARK-33890][SQL] Improve the implement of trim/trimleft/trimright

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 687f465  [SPARK-33890][SQL] Improve the implement of trim/trimleft/trimright
687f465 is described below

commit 687f465244301112a1f6cafa5d9361b2c7d7b4a5
Author: gengjiaan <ge...@360.cn>
AuthorDate: Wed Dec 30 06:06:17 2020 +0000

    [SPARK-33890][SQL] Improve the implement of trim/trimleft/trimright
    
    ### What changes were proposed in this pull request?
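    The current implementations of trim/trimleft/trimright are somewhat redundant:
    StringTrim, StringTrimLeft and StringTrimRight each duplicate nearly identical
    eval and doGenCode logic. This PR moves the shared logic into String2TrimExpression.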
    
    ### Why are the changes needed?
    To improve the implementation of trim/trimleft/trimright by removing duplicated code.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    
    ### How was this patch tested?
    Jenkins test
    
    Closes #30905 from beliefer/SPARK-33890.
    
    Lead-authored-by: gengjiaan <ge...@360.cn>
    Co-authored-by: beliefer <be...@163.com>
    Co-authored-by: Jiaan Geng <be...@163.com>
    Co-authored-by: Wenchen Fan <cl...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../catalyst/expressions/stringExpressions.scala   | 202 +++++++--------------
 1 file changed, 64 insertions(+), 138 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 6caf439..9317684 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -764,6 +764,55 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes {
   override def nullable: Boolean = children.exists(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
 
+  protected def doEval(srcString: UTF8String): UTF8String
+  protected def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String
+
+  override def eval(input: InternalRow): Any = {
+    val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
+    if (srcString == null) {
+      null
+    } else if (trimStr.isDefined) {
+      doEval(srcString, trimStr.get.eval(input).asInstanceOf[UTF8String])
+    } else {
+      doEval(srcString)
+    }
+  }
+
+  protected val trimMethod: String
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val evals = children.map(_.genCode(ctx))
+    val srcString = evals(0)
+
+    if (evals.length == 1) {
+      ev.copy(code = code"""
+         |${srcString.code}
+         |boolean ${ev.isNull} = false;
+         |UTF8String ${ev.value} = null;
+         |if (${srcString.isNull}) {
+         |  ${ev.isNull} = true;
+         |} else {
+         |  ${ev.value} = ${srcString.value}.$trimMethod();
+         |}""".stripMargin)
+    } else {
+      val trimString = evals(1)
+      ev.copy(code = code"""
+         |${srcString.code}
+         |boolean ${ev.isNull} = false;
+         |UTF8String ${ev.value} = null;
+         |if (${srcString.isNull}) {
+         |  ${ev.isNull} = true;
+         |} else {
+         |  ${trimString.code}
+         |  if (${trimString.isNull}) {
+         |    ${ev.isNull} = true;
+         |  } else {
+         |    ${ev.value} = ${srcString.value}.$trimMethod(${trimString.value});
+         |  }
+         |}""".stripMargin)
+    }
+  }
+
   override def sql: String = if (trimStr.isDefined) {
     s"TRIM($direction ${trimStr.get.sql} FROM ${srcStr.sql})"
   } else {
@@ -840,9 +889,7 @@ object StringTrim {
   """,
   since = "1.5.0",
   group = "string_funcs")
-case class StringTrim(
-    srcStr: Expression,
-    trimStr: Option[Expression] = None)
+case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None)
   extends String2TrimExpression {
 
   def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr))
@@ -853,51 +900,12 @@ case class StringTrim(
 
   override protected def direction: String = "BOTH"
 
-  override def eval(input: InternalRow): Any = {
-    val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
-    if (srcString == null) {
-      null
-    } else {
-      if (trimStr.isDefined) {
-        srcString.trim(trimStr.get.eval(input).asInstanceOf[UTF8String])
-      } else {
-        srcString.trim()
-      }
-    }
-  }
+  override def doEval(srcString: UTF8String): UTF8String = srcString.trim()
 
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val evals = children.map(_.genCode(ctx))
-    val srcString = evals(0)
+  override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String =
+    srcString.trim(trimString)
 
-    if (evals.length == 1) {
-      ev.copy(evals.map(_.code) :+ code"""
-        boolean ${ev.isNull} = false;
-        UTF8String ${ev.value} = null;
-        if (${srcString.isNull}) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.value} = ${srcString.value}.trim();
-        }""")
-    } else {
-      val trimString = evals(1)
-      val getTrimFunction =
-        s"""
-        if (${trimString.isNull}) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.value} = ${srcString.value}.trim(${trimString.value});
-        }"""
-      ev.copy(evals.map(_.code) :+ code"""
-        boolean ${ev.isNull} = false;
-        UTF8String ${ev.value} = null;
-        if (${srcString.isNull}) {
-          ${ev.isNull} = true;
-        } else {
-          $getTrimFunction
-        }""")
-    }
-  }
+  override val trimMethod: String = "trim"
 }
 
 object StringTrimLeft {
@@ -934,9 +942,7 @@ object StringTrimLeft {
   """,
   since = "1.5.0",
   group = "string_funcs")
-case class StringTrimLeft(
-    srcStr: Expression,
-    trimStr: Option[Expression] = None)
+case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None)
   extends String2TrimExpression {
 
   def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr))
@@ -947,51 +953,12 @@ case class StringTrimLeft(
 
   override protected def direction: String = "LEADING"
 
-  override def eval(input: InternalRow): Any = {
-    val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
-    if (srcString == null) {
-      null
-    } else {
-      if (trimStr.isDefined) {
-        srcString.trimLeft(trimStr.get.eval(input).asInstanceOf[UTF8String])
-      } else {
-        srcString.trimLeft()
-      }
-    }
-  }
+  override def doEval(srcString: UTF8String): UTF8String = srcString.trimLeft()
 
-  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val evals = children.map(_.genCode(ctx))
-    val srcString = evals(0)
+  override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String =
+    srcString.trimLeft(trimString)
 
-    if (evals.length == 1) {
-      ev.copy(evals.map(_.code) :+ code"""
-        boolean ${ev.isNull} = false;
-        UTF8String ${ev.value} = null;
-        if (${srcString.isNull}) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.value} = ${srcString.value}.trimLeft();
-        }""")
-    } else {
-      val trimString = evals(1)
-      val getTrimLeftFunction =
-        s"""
-        if (${trimString.isNull}) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.value} = ${srcString.value}.trimLeft(${trimString.value});
-        }"""
-      ev.copy(evals.map(_.code) :+ code"""
-        boolean ${ev.isNull} = false;
-        UTF8String ${ev.value} = null;
-        if (${srcString.isNull}) {
-          ${ev.isNull} = true;
-        } else {
-          $getTrimLeftFunction
-        }""")
-    }
-  }
+  override val trimMethod: String = "trimLeft"
 }
 
 object StringTrimRight {
@@ -1030,9 +997,7 @@ object StringTrimRight {
   since = "1.5.0",
   group = "string_funcs")
 // scalastyle:on line.size.limit
-case class StringTrimRight(
-    srcStr: Expression,
-    trimStr: Option[Expression] = None)
+case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = None)
   extends String2TrimExpression {
 
   def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr))
@@ -1043,51 +1008,12 @@ case class StringTrimRight(
 
   override protected def direction: String = "TRAILING"
 
-  override def eval(input: InternalRow): Any = {
-    val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
-    if (srcString == null) {
-      null
-    } else {
-      if (trimStr.isDefined) {
-        srcString.trimRight(trimStr.get.eval(input).asInstanceOf[UTF8String])
-      } else {
-        srcString.trimRight()
-      }
-    }
-  }
+  override def doEval(srcString: UTF8String): UTF8String = srcString.trimRight()
 
-  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val evals = children.map(_.genCode(ctx))
-    val srcString = evals(0)
+  override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String =
+    srcString.trimRight(trimString)
 
-    if (evals.length == 1) {
-      ev.copy(evals.map(_.code) :+ code"""
-        boolean ${ev.isNull} = false;
-        UTF8String ${ev.value} = null;
-        if (${srcString.isNull}) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.value} = ${srcString.value}.trimRight();
-        }""")
-    } else {
-      val trimString = evals(1)
-      val getTrimRightFunction =
-        s"""
-        if (${trimString.isNull}) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.value} = ${srcString.value}.trimRight(${trimString.value});
-        }"""
-      ev.copy(evals.map(_.code) :+ code"""
-        boolean ${ev.isNull} = false;
-        UTF8String ${ev.value} = null;
-        if (${srcString.isNull}) {
-          ${ev.isNull} = true;
-        } else {
-          $getTrimRightFunction
-        }""")
-    }
-  }
+  override val trimMethod: String = "trimRight"
 }
 
 /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org