You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2020/12/30 06:06:51 UTC
[spark] branch master updated: [SPARK-33890][SQL] Improve the
implement of trim/trimleft/trimright
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 687f465 [SPARK-33890][SQL] Improve the implement of trim/trimleft/trimright
687f465 is described below
commit 687f465244301112a1f6cafa5d9361b2c7d7b4a5
Author: gengjiaan <ge...@360.cn>
AuthorDate: Wed Dec 30 06:06:17 2020 +0000
[SPARK-33890][SQL] Improve the implement of trim/trimleft/trimright
### What changes were proposed in this pull request?
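The current implementations of trim/trimleft/trimright are somewhat redundant: StringTrim, StringTrimLeft and StringTrimRight each duplicate the same `eval` and `doGenCode` logic, differing only in which UTF8String method (`trim`, `trimLeft` or `trimRight`) they call.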
### Why are the changes needed?
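To remove this duplication, the shared `eval` and `doGenCode` logic is moved into the `String2TrimExpression` trait. Each subclass now only supplies its two `doEval` overloads and the name of the UTF8String trim method to use in generated code (`trimMethod`).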
### Does this PR introduce _any_ user-facing change?
'No'.
### How was this patch tested?
Covered by the existing test suite, run on Jenkins.
Closes #30905 from beliefer/SPARK-33890.
Lead-authored-by: gengjiaan <ge...@360.cn>
Co-authored-by: beliefer <be...@163.com>
Co-authored-by: Jiaan Geng <be...@163.com>
Co-authored-by: Wenchen Fan <cl...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../catalyst/expressions/stringExpressions.scala | 202 +++++++--------------
1 file changed, 64 insertions(+), 138 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 6caf439..9317684 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -764,6 +764,55 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes {
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
+ protected def doEval(srcString: UTF8String): UTF8String
+ protected def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String
+
+ override def eval(input: InternalRow): Any = {
+ val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
+ if (srcString == null) {
+ null
+ } else if (trimStr.isDefined) {
+ doEval(srcString, trimStr.get.eval(input).asInstanceOf[UTF8String])
+ } else {
+ doEval(srcString)
+ }
+ }
+
+ protected val trimMethod: String
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val evals = children.map(_.genCode(ctx))
+ val srcString = evals(0)
+
+ if (evals.length == 1) {
+ ev.copy(code = code"""
+ |${srcString.code}
+ |boolean ${ev.isNull} = false;
+ |UTF8String ${ev.value} = null;
+ |if (${srcString.isNull}) {
+ | ${ev.isNull} = true;
+ |} else {
+ | ${ev.value} = ${srcString.value}.$trimMethod();
+ |}""".stripMargin)
+ } else {
+ val trimString = evals(1)
+ ev.copy(code = code"""
+ |${srcString.code}
+ |boolean ${ev.isNull} = false;
+ |UTF8String ${ev.value} = null;
+ |if (${srcString.isNull}) {
+ | ${ev.isNull} = true;
+ |} else {
+ | ${trimString.code}
+ | if (${trimString.isNull}) {
+ | ${ev.isNull} = true;
+ | } else {
+ | ${ev.value} = ${srcString.value}.$trimMethod(${trimString.value});
+ | }
+ |}""".stripMargin)
+ }
+ }
+
override def sql: String = if (trimStr.isDefined) {
s"TRIM($direction ${trimStr.get.sql} FROM ${srcStr.sql})"
} else {
@@ -840,9 +889,7 @@ object StringTrim {
""",
since = "1.5.0",
group = "string_funcs")
-case class StringTrim(
- srcStr: Expression,
- trimStr: Option[Expression] = None)
+case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None)
extends String2TrimExpression {
def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr))
@@ -853,51 +900,12 @@ case class StringTrim(
override protected def direction: String = "BOTH"
- override def eval(input: InternalRow): Any = {
- val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
- if (srcString == null) {
- null
- } else {
- if (trimStr.isDefined) {
- srcString.trim(trimStr.get.eval(input).asInstanceOf[UTF8String])
- } else {
- srcString.trim()
- }
- }
- }
+ override def doEval(srcString: UTF8String): UTF8String = srcString.trim()
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val evals = children.map(_.genCode(ctx))
- val srcString = evals(0)
+ override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String =
+ srcString.trim(trimString)
- if (evals.length == 1) {
- ev.copy(evals.map(_.code) :+ code"""
- boolean ${ev.isNull} = false;
- UTF8String ${ev.value} = null;
- if (${srcString.isNull}) {
- ${ev.isNull} = true;
- } else {
- ${ev.value} = ${srcString.value}.trim();
- }""")
- } else {
- val trimString = evals(1)
- val getTrimFunction =
- s"""
- if (${trimString.isNull}) {
- ${ev.isNull} = true;
- } else {
- ${ev.value} = ${srcString.value}.trim(${trimString.value});
- }"""
- ev.copy(evals.map(_.code) :+ code"""
- boolean ${ev.isNull} = false;
- UTF8String ${ev.value} = null;
- if (${srcString.isNull}) {
- ${ev.isNull} = true;
- } else {
- $getTrimFunction
- }""")
- }
- }
+ override val trimMethod: String = "trim"
}
object StringTrimLeft {
@@ -934,9 +942,7 @@ object StringTrimLeft {
""",
since = "1.5.0",
group = "string_funcs")
-case class StringTrimLeft(
- srcStr: Expression,
- trimStr: Option[Expression] = None)
+case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None)
extends String2TrimExpression {
def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr))
@@ -947,51 +953,12 @@ case class StringTrimLeft(
override protected def direction: String = "LEADING"
- override def eval(input: InternalRow): Any = {
- val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
- if (srcString == null) {
- null
- } else {
- if (trimStr.isDefined) {
- srcString.trimLeft(trimStr.get.eval(input).asInstanceOf[UTF8String])
- } else {
- srcString.trimLeft()
- }
- }
- }
+ override def doEval(srcString: UTF8String): UTF8String = srcString.trimLeft()
- override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val evals = children.map(_.genCode(ctx))
- val srcString = evals(0)
+ override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String =
+ srcString.trimLeft(trimString)
- if (evals.length == 1) {
- ev.copy(evals.map(_.code) :+ code"""
- boolean ${ev.isNull} = false;
- UTF8String ${ev.value} = null;
- if (${srcString.isNull}) {
- ${ev.isNull} = true;
- } else {
- ${ev.value} = ${srcString.value}.trimLeft();
- }""")
- } else {
- val trimString = evals(1)
- val getTrimLeftFunction =
- s"""
- if (${trimString.isNull}) {
- ${ev.isNull} = true;
- } else {
- ${ev.value} = ${srcString.value}.trimLeft(${trimString.value});
- }"""
- ev.copy(evals.map(_.code) :+ code"""
- boolean ${ev.isNull} = false;
- UTF8String ${ev.value} = null;
- if (${srcString.isNull}) {
- ${ev.isNull} = true;
- } else {
- $getTrimLeftFunction
- }""")
- }
- }
+ override val trimMethod: String = "trimLeft"
}
object StringTrimRight {
@@ -1030,9 +997,7 @@ object StringTrimRight {
since = "1.5.0",
group = "string_funcs")
// scalastyle:on line.size.limit
-case class StringTrimRight(
- srcStr: Expression,
- trimStr: Option[Expression] = None)
+case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = None)
extends String2TrimExpression {
def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr))
@@ -1043,51 +1008,12 @@ case class StringTrimRight(
override protected def direction: String = "TRAILING"
- override def eval(input: InternalRow): Any = {
- val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
- if (srcString == null) {
- null
- } else {
- if (trimStr.isDefined) {
- srcString.trimRight(trimStr.get.eval(input).asInstanceOf[UTF8String])
- } else {
- srcString.trimRight()
- }
- }
- }
+ override def doEval(srcString: UTF8String): UTF8String = srcString.trimRight()
- override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val evals = children.map(_.genCode(ctx))
- val srcString = evals(0)
+ override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String =
+ srcString.trimRight(trimString)
- if (evals.length == 1) {
- ev.copy(evals.map(_.code) :+ code"""
- boolean ${ev.isNull} = false;
- UTF8String ${ev.value} = null;
- if (${srcString.isNull}) {
- ${ev.isNull} = true;
- } else {
- ${ev.value} = ${srcString.value}.trimRight();
- }""")
- } else {
- val trimString = evals(1)
- val getTrimRightFunction =
- s"""
- if (${trimString.isNull}) {
- ${ev.isNull} = true;
- } else {
- ${ev.value} = ${srcString.value}.trimRight(${trimString.value});
- }"""
- ev.copy(evals.map(_.code) :+ code"""
- boolean ${ev.isNull} = false;
- UTF8String ${ev.value} = null;
- if (${srcString.isNull}) {
- ${ev.isNull} = true;
- } else {
- $getTrimRightFunction
- }""")
- }
- }
+ override val trimMethod: String = "trimRight"
}
/**
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org