You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/09/18 19:12:38 UTC

spark git commit: [SPARK-14878][SQL] Trim characters string function support

Repository: spark
Updated Branches:
  refs/heads/master 3b049abf1 -> c66d64b3d


[SPARK-14878][SQL] Trim characters string function support

#### What changes were proposed in this pull request?

This PR enhances the TRIM function support in Spark SQL by allowing a set of trim
characters to be specified. The SQL syntax is shown below:

``` SQL
<trim function> ::= TRIM <left paren> <trim operands> <right paren>
<trim operands> ::= [ [ <trim specification> ] [ <trim character set> ] FROM ] <trim source>
<trim source> ::= <character value expression>
<trim specification> ::=
  LEADING
| TRAILING
| BOTH
<trim character set> ::= <characters value expression>
```
or
``` SQL
LTRIM (source-exp [, trim-exp])
RTRIM (source-exp [, trim-exp])
```

Here are links to the documentation describing how other mainstream databases support this feature.
- **Oracle:** [TRIM function](http://docs.oracle.com/cd/B28359_01/olap.111/b28126/dml_functions_2126.htm#OLADM704)
- **DB2:** [TRIM scalar function](https://www.ibm.com/support/knowledgecenter/en/SSMKHH_10.0.0/com.ibm.etools.mft.doc/ak05270_.htm)
- **MySQL:** [Trim function](http://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_trim)
- **Oracle:** [ltrim](https://docs.oracle.com/cd/B28359_01/olap.111/b28126/dml_functions_2018.htm#OLADM594)
- **DB2:** [ltrim](https://www.ibm.com/support/knowledgecenter/en/SSEPEK_11.0.0/sqlref/src/tpc/db2z_bif_ltrim.html)

This PR implements the above enhancement. In the implementation, the design principle is to keep the changes to a minimum. Also, the existing trim functions (which handle a special case, i.e., trimming space characters) are kept unchanged for performance reasons.
#### How was this patch tested?

The unit test cases are added in the following files:
- UTF8StringSuite.java
- StringExpressionsSuite.scala
- sql/SQLQuerySuite.scala
- StringFunctionsSuite.scala

Author: Kevin Yu <qy...@us.ibm.com>

Closes #12646 from kevinyu98/spark-14878.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c66d64b3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c66d64b3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c66d64b3

Branch: refs/heads/master
Commit: c66d64b3df9d9ffba0b16a62015680f6f876fc68
Parents: 3b049ab
Author: Kevin Yu <qy...@us.ibm.com>
Authored: Mon Sep 18 12:12:35 2017 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Mon Sep 18 12:12:35 2017 -0700

----------------------------------------------------------------------
 .../apache/spark/unsafe/types/UTF8String.java   |  93 ++++++
 .../spark/unsafe/types/UTF8StringSuite.java     |  57 ++++
 .../apache/spark/sql/catalyst/parser/SqlBase.g4 |   6 +
 .../expressions/stringExpressions.scala         | 280 +++++++++++++++++--
 .../spark/sql/catalyst/parser/AstBuilder.scala  |  24 +-
 .../expressions/StringExpressionsSuite.scala    |  65 ++++-
 .../sql/catalyst/parser/PlanParserSuite.scala   |  18 ++
 .../parser/TableIdentifierParserSuite.scala     |   2 +-
 .../scala/org/apache/spark/sql/functions.scala  |  27 ++
 .../apache/spark/sql/StringFunctionsSuite.scala |  14 +-
 10 files changed, 554 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c66d64b3/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index dd67f15..76db0fb 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -511,6 +511,21 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
     }
   }
 
+  /**
+   * Based on the given trim string, trim this string starting from both ends
+   * This method searches for each character in the source string, removes the character if it is found
+   * in the trim string, stops at the first not found. It calls the trimLeft first, then trimRight.
+   * It returns a new string in which both ends trim characters have been removed.
+   * @param trimString the trim character string
+   */
+  public UTF8String trim(UTF8String trimString) {
+    if (trimString != null) {
+      return trimLeft(trimString).trimRight(trimString);
+    } else {
+      return null;
+    }
+  }
+
   public UTF8String trimLeft() {
     int s = 0;
     // skip all of the space (0x20) in the left side
@@ -523,6 +538,40 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
     }
   }
 
+  /**
+   * Based on the given trim string, trim this string starting from left end
+   * This method searches each character in the source string starting from the left end, removes the character if it
+   * is in the trim string, stops at the first character which is not in the trim string, returns the new string.
+   * @param trimString the trim character string
+   */
+  public UTF8String trimLeft(UTF8String trimString) {
+    if (trimString == null) return null;
+    // the searching byte position in the source string
+    int srchIdx = 0;
+    // the first beginning byte position of a non-matching character
+    int trimIdx = 0;
+
+    while (srchIdx < numBytes) {
+      UTF8String searchChar = copyUTF8String(srchIdx, srchIdx + numBytesForFirstByte(this.getByte(srchIdx)) - 1);
+      int searchCharBytes = searchChar.numBytes;
+      // try to find the matching for the searchChar in the trimString set
+      if (trimString.find(searchChar, 0) >= 0) {
+        trimIdx += searchCharBytes;
+      } else {
+        // no matching, exit the search
+        break;
+      }
+      srchIdx += searchCharBytes;
+    }
+
+    if (trimIdx >= numBytes) {
+      // empty string
+      return EMPTY_UTF8;
+    } else {
+      return copyUTF8String(trimIdx, numBytes - 1);
+    }
+  }
+
   public UTF8String trimRight() {
     int e = numBytes - 1;
     // skip all of the space (0x20) in the right side
@@ -536,6 +585,50 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
     }
   }
 
+  /**
+   * Based on the given trim string, trim this string starting from right end
+   * This method searches each character in the source string starting from the right end, removes the character if it
+   * is in the trim string, stops at the first character which is not in the trim string, returns the new string.
+   * @param trimString the trim character string
+   */
+  public UTF8String trimRight(UTF8String trimString) {
+    if (trimString == null) return null;
+    int charIdx = 0;
+    // number of characters from the source string
+    int numChars = 0;
+    // array of character length for the source string
+    int[] stringCharLen = new int[numBytes];
+    // array of the first byte position for each character in the source string
+    int[] stringCharPos = new int[numBytes];
+    // build the position and length array
+    while (charIdx < numBytes) {
+      stringCharPos[numChars] = charIdx;
+      stringCharLen[numChars] = numBytesForFirstByte(getByte(charIdx));
+      charIdx += stringCharLen[numChars];
+      numChars ++;
+    }
+
+    // index trimEnd points to the first no matching byte position from the right side of the source string.
+    int trimEnd = numBytes - 1;
+    while (numChars > 0) {
+      UTF8String searchChar =
+        copyUTF8String(stringCharPos[numChars - 1], stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1);
+      if (trimString.find(searchChar, 0) >= 0) {
+        trimEnd -= stringCharLen[numChars - 1];
+      } else {
+        break;
+      }
+      numChars --;
+    }
+
+    if (trimEnd < 0) {
+      // empty string
+      return EMPTY_UTF8;
+    } else {
+      return copyUTF8String(0, trimEnd);
+    }
+  }
+
   public UTF8String reverse() {
     byte[] result = new byte[this.numBytes];
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c66d64b3/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index c376371..f086001 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -730,4 +730,61 @@ public class UTF8StringSuite {
       assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper));
     }
   }
+
+  @Test
+  public void trimBothWithTrimString() {
+    assertEquals(fromString("hello"), fromString("  hello ").trim(fromString(" ")));
+    assertEquals(fromString("o"), fromString("  hello ").trim(fromString(" hle")));
+    assertEquals(fromString("h e"), fromString("ooh e ooo").trim(fromString("o ")));
+    assertEquals(fromString(""), fromString("ooo...oooo").trim(fromString("o.")));
+    assertEquals(fromString("b"), fromString("%^b[]@").trim(fromString("][@^%")));
+
+    assertEquals(EMPTY_UTF8, fromString("  ").trim(fromString(" ")));
+
+    assertEquals(fromString("数据砖头"), fromString("  数据砖头 ").trim());
+    assertEquals(fromString("数"), fromString("a数b").trim(fromString("ab")));
+    assertEquals(fromString(""), fromString("a").trim(fromString("a数b")));
+    assertEquals(fromString(""), fromString("数数 数数数").trim(fromString("数 ")));
+    assertEquals(fromString("据砖头"), fromString("数]数[数据砖头#数数").trim(fromString("[数]#")));
+    assertEquals(fromString("据砖头数数 "), fromString("数数数据砖头数数 ").trim(fromString("数")));
+  }
+
+  @Test
+  public void trimLeftWithTrimString() {
+    assertEquals(fromString("  hello "), fromString("  hello ").trimLeft(fromString("")));
+    assertEquals(fromString(""), fromString("a").trimLeft(fromString("a")));
+    assertEquals(fromString("b"), fromString("b").trimLeft(fromString("a")));
+    assertEquals(fromString("ba"), fromString("ba").trimLeft(fromString("a")));
+    assertEquals(fromString(""), fromString("aaaaaaa").trimLeft(fromString("a")));
+    assertEquals(fromString("trim"), fromString("oabtrim").trimLeft(fromString("bao")));
+    assertEquals(fromString("rim "), fromString("ooootrim ").trimLeft(fromString("otm")));
+
+    assertEquals(EMPTY_UTF8, fromString("  ").trimLeft(fromString(" ")));
+
+    assertEquals(fromString("数据砖头 "), fromString("  数据砖头 ").trimLeft(fromString(" ")));
+    assertEquals(fromString("数"), fromString("数").trimLeft(fromString("a")));
+    assertEquals(fromString("a"), fromString("a").trimLeft(fromString("数")));
+    assertEquals(fromString("砖头数数"), fromString("数数数据砖头数数").trimLeft(fromString("据数")));
+    assertEquals(fromString("据砖头数数"), fromString(" 数数数据砖头数数").trimLeft(fromString("数 ")));
+    assertEquals(fromString("据砖头数数"), fromString("aa数数数据砖头数数").trimLeft(fromString("a数砖")));
+    assertEquals(fromString("$S,.$BR"), fromString(",,,,%$S,.$BR").trimLeft(fromString("%,")));
+  }
+
+  @Test
+  public void trimRightWithTrimString() {
+    assertEquals(fromString("  hello "), fromString("  hello ").trimRight(fromString("")));
+    assertEquals(fromString(""), fromString("a").trimRight(fromString("a")));
+    assertEquals(fromString("cc"), fromString("ccbaaaa").trimRight(fromString("ba")));
+    assertEquals(fromString(""), fromString("aabbbbaaa").trimRight(fromString("ab")));
+    assertEquals(fromString("  he"), fromString("  hello ").trimRight(fromString(" ol")));
+    assertEquals(fromString("oohell"), fromString("oohellooo../*&").trimRight(fromString("./,&%*o")));
+
+    assertEquals(EMPTY_UTF8, fromString("  ").trimRight(fromString(" ")));
+
+    assertEquals(fromString("  数据砖头"), fromString("  数据砖头 ").trimRight(fromString(" ")));
+    assertEquals(fromString("数数砖头"), fromString("数数砖头数aa数").trimRight(fromString("a数")));
+    assertEquals(fromString(""), fromString("数数数据砖ab").trimRight(fromString("数据砖ab")));
+    assertEquals(fromString("头"), fromString("头a???/").trimRight(fromString("数?/*&^%a")));
+    assertEquals(fromString("头"), fromString("头数b数数 [").trimRight(fromString(" []数b")));
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c66d64b3/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 33bc79a..d0a5428 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -580,6 +580,8 @@ primaryExpression
     | '(' query ')'                                                                            #subqueryExpression
     | qualifiedName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')'
        (OVER windowSpec)?                                                                      #functionCall
+    | qualifiedName '(' trimOption=(BOTH | LEADING | TRAILING) argument+=expression
+      FROM argument+=expression ')'                                                            #functionCall
     | value=primaryExpression '[' index=valueExpression ']'                                    #subscript
     | identifier                                                                               #columnReference
     | base=primaryExpression '.' fieldName=identifier                                          #dereference
@@ -748,6 +750,7 @@ nonReserved
     | UNBOUNDED | WHEN
     | DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT | CURRENT_DATE | CURRENT_TIMESTAMP
     | DIRECTORY
+    | BOTH | LEADING | TRAILING
     ;
 
 SELECT: 'SELECT';
@@ -861,6 +864,9 @@ COMMIT: 'COMMIT';
 ROLLBACK: 'ROLLBACK';
 MACRO: 'MACRO';
 IGNORE: 'IGNORE';
+BOTH: 'BOTH';
+LEADING: 'LEADING';
+TRAILING: 'TRAILING';
 
 IF: 'IF';
 POSITION: 'POSITION';

http://git-wip-us.apache.org/repos/asf/spark/blob/c66d64b3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 7ab45a6..83de515 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -24,6 +24,7 @@ import java.util.regex.Pattern
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -503,69 +504,304 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi
   override def prettyName: String = "find_in_set"
 }
 
+trait String2TrimExpression extends Expression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = StringType
+  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
+
+  override def nullable: Boolean = children.exists(_.nullable)
+  override def foldable: Boolean = children.forall(_.foldable)
+}
+
+object StringTrim {
+  def apply(str: Expression, trimStr: Expression) : StringTrim = StringTrim(str, Some(trimStr))
+  def apply(str: Expression) : StringTrim = StringTrim(str, None)
+}
+
 /**
- * A function that trim the spaces from both ends for the specified string.
+ * A function that takes a character string, removes the leading and trailing characters matching with any character
+ * in the trim string, returns the new string.
+ * If BOTH and trimStr keywords are not specified, it defaults to remove space character from both ends. The trim
+ * function will have one argument, which contains the source string.
+ * If BOTH and trimStr keywords are specified, it trims the characters from both ends, and the trim function will have
+ * two arguments, the first argument contains trimStr, the second argument contains the source string.
+ * trimStr: A character string to be trimmed from the source string, if it has multiple characters, the function
+ * searches for each character in the source string, removes the characters from the source string until it
+ * encounters the first non-match character.
+ * BOTH: removes any character from both ends of the source string that matches characters in the trim string.
  */
 @ExpressionDescription(
-  usage = "_FUNC_(str) - Removes the leading and trailing space characters from `str`.",
+  usage = """
+    _FUNC_(str) - Removes the leading and trailing space characters from `str`.
+    _FUNC_(BOTH trimStr FROM str) - Remove the leading and trailing trimString from `str`
+  """,
+  arguments = """
+    Arguments:
+      * str - a string expression
+      * trimString - the trim string
+      * BOTH, FROM - these are keyword to specify for trim string from both ends of the string
+  """,
   examples = """
     Examples:
       > SELECT _FUNC_('    SparkSQL   ');
        SparkSQL
+      > SELECT _FUNC_(BOTH 'SL' FROM 'SSparkSQLS');
+       parkSQ
   """)
-case class StringTrim(child: Expression)
-  extends UnaryExpression with String2StringExpression {
+case class StringTrim(
+    srcStr: Expression,
+    trimStr: Option[Expression] = None)
+  extends String2TrimExpression {
+
+  def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr))
 
-  def convert(v: UTF8String): UTF8String = v.trim()
+  def this(srcStr: Expression) = this(srcStr, None)
 
   override def prettyName: String = "trim"
 
+  override def children: Seq[Expression] = if (trimStr.isDefined) {
+    srcStr :: trimStr.get :: Nil
+  } else {
+    srcStr :: Nil
+  }
+  override def eval(input: InternalRow): Any = {
+    val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
+    if (srcString == null) {
+      null
+    } else {
+      if (trimStr.isDefined) {
+        srcString.trim(trimStr.get.eval(input).asInstanceOf[UTF8String])
+      } else {
+        srcString.trim()
+      }
+    }
+  }
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, c => s"($c).trim()")
+    val evals = children.map(_.genCode(ctx))
+    val srcString = evals(0)
+
+    if (evals.length == 1) {
+      ev.copy(evals.map(_.code).mkString + s"""
+        boolean ${ev.isNull} = false;
+        UTF8String ${ev.value} = null;
+        if (${srcString.isNull}) {
+          ${ev.isNull} = true;
+        } else {
+          ${ev.value} = ${srcString.value}.trim();
+        }""")
+    } else {
+      val trimString = evals(1)
+      val getTrimFunction =
+        s"""
+        if (${trimString.isNull}) {
+          ${ev.isNull} = true;
+        } else {
+          ${ev.value} = ${srcString.value}.trim(${trimString.value});
+        }"""
+      ev.copy(evals.map(_.code).mkString + s"""
+        boolean ${ev.isNull} = false;
+        UTF8String ${ev.value} = null;
+        if (${srcString.isNull}) {
+          ${ev.isNull} = true;
+        } else {
+          $getTrimFunction
+        }""")
+    }
   }
 }
 
+object StringTrimLeft {
+  def apply(str: Expression, trimStr: Expression) : StringTrimLeft = StringTrimLeft(str, Some(trimStr))
+  def apply(str: Expression) : StringTrimLeft = StringTrimLeft(str, None)
+}
+
 /**
- * A function that trim the spaces from left end for given string.
+ * A function that trims the characters from left end for a given string.
+ * If LEADING and trimStr keywords are not specified, it defaults to remove space character from the left end. The ltrim
+ * function will have one argument, which contains the source string.
+ * If LEADING and trimStr keywords are specified, it trims the characters from left end. The ltrim function will
+ * have two arguments, the first argument contains trimStr, the second argument contains the source string.
+ * trimStr: the function removes any character from the left end of the source string which matches with the characters
+ * from trimStr, it stops at the first non-match character.
+ * LEADING: removes any character from the left end of the source string that matches characters in the trim string.
  */
 @ExpressionDescription(
-  usage = "_FUNC_(str) - Removes the leading and trailing space characters from `str`.",
+  usage = """
+    _FUNC_(str) - Removes the leading space characters from `str`.
+    _FUNC_(trimStr, str) - Removes the leading string contains the characters from the trim string
+  """,
+  arguments = """
+    Arguments:
+      * str - a string expression
+      * trimString - the trim string
+      * BOTH, FROM - these are keyword to specify for trim string from both ends of the string
+  """,
   examples = """
     Examples:
-      > SELECT _FUNC_('    SparkSQL');
+      > SELECT _FUNC_('    SparkSQL   ');
        SparkSQL
+      > SELECT _FUNC_('Sp', 'SSparkSQLS');
+       arkSQLS
   """)
-case class StringTrimLeft(child: Expression)
-  extends UnaryExpression with String2StringExpression {
+case class StringTrimLeft(
+    srcStr: Expression,
+    trimStr: Option[Expression] = None)
+  extends String2TrimExpression {
 
-  def convert(v: UTF8String): UTF8String = v.trimLeft()
+  def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr))
+
+  def this(srcStr: Expression) = this(srcStr, None)
 
   override def prettyName: String = "ltrim"
 
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, c => s"($c).trimLeft()")
+  override def children: Seq[Expression] = if (trimStr.isDefined) {
+    srcStr :: trimStr.get :: Nil
+  } else {
+    srcStr :: Nil
   }
+
+  override def eval(input: InternalRow): Any = {
+    val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
+    if (srcString == null) {
+      null
+    } else {
+      if (trimStr.isDefined) {
+        srcString.trimLeft(trimStr.get.eval(input).asInstanceOf[UTF8String])
+      } else {
+        srcString.trimLeft()
+      }
+    }
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val evals = children.map(_.genCode(ctx))
+    val srcString = evals(0)
+
+    if (evals.length == 1) {
+      ev.copy(evals.map(_.code).mkString + s"""
+        boolean ${ev.isNull} = false;
+        UTF8String ${ev.value} = null;
+        if (${srcString.isNull}) {
+          ${ev.isNull} = true;
+        } else {
+          ${ev.value} = ${srcString.value}.trimLeft();
+        }""")
+    } else {
+      val trimString = evals(1)
+      val getTrimLeftFunction =
+        s"""
+        if (${trimString.isNull}) {
+          ${ev.isNull} = true;
+        } else {
+          ${ev.value} = ${srcString.value}.trimLeft(${trimString.value});
+        }"""
+      ev.copy(evals.map(_.code).mkString + s"""
+        boolean ${ev.isNull} = false;
+        UTF8String ${ev.value} = null;
+        if (${srcString.isNull}) {
+          ${ev.isNull} = true;
+        } else {
+          $getTrimLeftFunction
+        }""")
+    }
+  }
+}
+
+object StringTrimRight {
+  def apply(str: Expression, trimStr: Expression) : StringTrimRight = StringTrimRight(str, Some(trimStr))
+  def apply(str: Expression) : StringTrimRight = StringTrimRight(str, None)
 }
 
 /**
- * A function that trim the spaces from right end for given string.
+ * A function that trims the characters from right end for a given string.
+ * If TRAILING and trimStr keywords are not specified, it defaults to remove space character from the right end. The
+ * rtrim function will have one argument, which contains the source string.
+ * If TRAILING and trimStr keywords are specified, it trims the characters from right end. The rtrim function will
+ * have two arguments, the first argument contains trimStr, the second argument contains the source string.
+ * trimStr: the function removes any character from the right end of source string which matches with the characters
+ * from trimStr, it stops at the first non-match character.
+ * TRAILING: removes any character from the right end of the source string that matches characters in the trim string.
  */
 @ExpressionDescription(
-  usage = "_FUNC_(str) - Removes the trailing space characters from `str`.",
+  usage = """
+    _FUNC_(str) - Removes the trailing space characters from `str`.
+    _FUNC_(trimStr, str) - Removes the trailing string which contains the characters from the trim string from the `str`
+  """,
+  arguments = """
+    Arguments:
+      * str - a string expression
+      * trimString - the trim string
+      * BOTH, FROM - these are keyword to specify for trim string from both ends of the string
+  """,
   examples = """
     Examples:
       > SELECT _FUNC_('    SparkSQL   ');
-           SparkSQL
+       SparkSQL
+      > SELECT _FUNC_('LQSa', 'SSparkSQLS');
+       SSpark
   """)
-case class StringTrimRight(child: Expression)
-  extends UnaryExpression with String2StringExpression {
+case class StringTrimRight(
+    srcStr: Expression,
+    trimStr: Option[Expression] = None)
+  extends String2TrimExpression {
 
-  def convert(v: UTF8String): UTF8String = v.trimRight()
+  def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr))
+
+  def this(srcStr: Expression) = this(srcStr, None)
 
   override def prettyName: String = "rtrim"
 
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, c => s"($c).trimRight()")
+  override def children: Seq[Expression] = if (trimStr.isDefined) {
+    srcStr :: trimStr.get :: Nil
+  } else {
+    srcStr :: Nil
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
+    if (srcString == null) {
+      null
+    } else {
+      if (trimStr.isDefined) {
+        srcString.trimRight(trimStr.get.eval(input).asInstanceOf[UTF8String])
+      } else {
+        srcString.trimRight()
+      }
+    }
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val evals = children.map(_.genCode(ctx))
+    val srcString = evals(0)
+
+    if (evals.length == 1) {
+      ev.copy(evals.map(_.code).mkString + s"""
+        boolean ${ev.isNull} = false;
+        UTF8String ${ev.value} = null;
+        if (${srcString.isNull}) {
+          ${ev.isNull} = true;
+        } else {
+          ${ev.value} = ${srcString.value}.trimRight();
+        }""")
+    } else {
+      val trimString = evals(1)
+      val getTrimRightFunction =
+        s"""
+        if (${trimString.isNull}) {
+          ${ev.isNull} = true;
+        } else {
+          ${ev.value} = ${srcString.value}.trimRight(${trimString.value});
+        }"""
+      ev.copy(evals.map(_.code).mkString + s"""
+        boolean ${ev.isNull} = false;
+        UTF8String ${ev.value} = null;
+        if (${srcString.isNull}) {
+          ${ev.isNull} = true;
+        } else {
+          $getTrimRightFunction
+        }""")
+    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c66d64b3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 891f616..85b492e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1179,6 +1179,26 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    * Create a (windowed) Function expression.
    */
   override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) {
+    def replaceFunctions(
+        funcID: FunctionIdentifier,
+        ctx: FunctionCallContext): FunctionIdentifier = {
+      val opt = ctx.trimOption
+      if (opt != null) {
+        if (ctx.qualifiedName.getText.toLowerCase(Locale.ROOT) != "trim") {
+          throw new ParseException(s"The specified function ${ctx.qualifiedName.getText} " +
+            s"doesn't support with option ${opt.getText}.", ctx)
+        }
+        opt.getType match {
+          case SqlBaseParser.BOTH => funcID
+          case SqlBaseParser.LEADING => funcID.copy(funcName = "ltrim")
+          case SqlBaseParser.TRAILING => funcID.copy(funcName = "rtrim")
+          case _ => throw new ParseException("Function trim doesn't support with " +
+            s"type ${opt.getType}. Please use BOTH, LEADING or Trailing as trim type", ctx)
+        }
+      } else {
+        funcID
+      }
+    }
     // Create the function call.
     val name = ctx.qualifiedName.getText
     val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
@@ -1190,7 +1210,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
       case expressions =>
         expressions
     }
-    val function = UnresolvedFunction(visitFunctionName(ctx.qualifiedName), arguments, isDistinct)
+    val funcId = replaceFunctions(visitFunctionName(ctx.qualifiedName), ctx)
+    val function = UnresolvedFunction(funcId, arguments, isDistinct)
+
 
     // Check if the function is evaluated in a windowed context.
     ctx.windowSpec match {

http://git-wip-us.apache.org/repos/asf/spark/blob/c66d64b3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 4f08031..18ef4bc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
 
-
 class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("concat") {
@@ -406,26 +405,78 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     // scalastyle:on
   }
 
-  test("TRIM/LTRIM/RTRIM") {
+  test("TRIM") {
     val s = 'a.string.at(0)
     checkEvaluation(StringTrim(Literal(" aa  ")), "aa", create_row(" abdef "))
+    checkEvaluation(StringTrim("aa", "a"), "", create_row(" abdef "))
+    checkEvaluation(StringTrim(Literal(" aabbtrimccc"), "ab cd"), "trim", create_row("bdef"))
+    checkEvaluation(StringTrim(Literal("a<a >@>.,>"), "a.,@<>"), " ", create_row(" abdef "))
     checkEvaluation(StringTrim(s), "abdef", create_row(" abdef "))
+    checkEvaluation(StringTrim(s, "abd"), "ef", create_row("abdefa"))
+    checkEvaluation(StringTrim(s, "a"), "bdef", create_row("aaabdefaaaa"))
+    checkEvaluation(StringTrim(s, "SLSQ"), "park", create_row("SSparkSQLS"))
 
+    // scalastyle:off
+    // non ascii characters are not allowed in the source code, so we disable the scalastyle.
+    checkEvaluation(StringTrim(s), "花花世界", create_row("  花花世界 "))
+    checkEvaluation(StringTrim(s, "花世界"), "", create_row("花花世界花花"))
+    checkEvaluation(StringTrim(s, "花 "), "世界", create_row(" 花花世界花花"))
+    checkEvaluation(StringTrim(s, "花 "), "世界", create_row(" 花 花 世界 花 花 "))
+    checkEvaluation(StringTrim(s, "a花世"), "界", create_row("aa花花世界花花aa"))
+    checkEvaluation(StringTrim(s, "a@#( )"), "花花世界花花", create_row("aa()花花世界花花@ #"))
+    checkEvaluation(StringTrim(Literal("花trim"), "花 "), "trim", create_row(" abdef "))
+    // scalastyle:on
+    checkEvaluation(StringTrim(Literal("a"), Literal.create(null, StringType)), null)
+    checkEvaluation(StringTrim(Literal.create(null, StringType), Literal("a")), null)
+  }
+
+  test("LTRIM") {
+    val s = 'a.string.at(0)
     checkEvaluation(StringTrimLeft(Literal(" aa  ")), "aa  ", create_row(" abdef "))
+    checkEvaluation(StringTrimLeft(Literal("aa"), "a"), "", create_row(" abdef "))
+    checkEvaluation(StringTrimLeft(Literal("aa "), "a "), "", create_row(" abdef "))
+    checkEvaluation(StringTrimLeft(Literal("aabbcaaaa"), "ab"), "caaaa", create_row(" abdef "))
     checkEvaluation(StringTrimLeft(s), "abdef ", create_row(" abdef "))
+    checkEvaluation(StringTrimLeft(s, "a"), "bdefa", create_row("abdefa"))
+    checkEvaluation(StringTrimLeft(s, "a "), "bdefaaaa", create_row(" aaabdefaaaa"))
+    checkEvaluation(StringTrimLeft(s, "Spk"), "arkSQLS", create_row("SSparkSQLS"))
 
+    // scalastyle:off
+    // non ascii characters are not allowed in the source code, so we disable the scalastyle.
+    checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row("  花花世界 "))
+    checkEvaluation(StringTrimLeft(s, "花"), "世界花花", create_row("花花世界花花"))
+    checkEvaluation(StringTrimLeft(s, "花 世"), "界花花", create_row(" 花花世界花花"))
+    checkEvaluation(StringTrimLeft(s, "花"), "a花花世界花花 ", create_row("a花花世界花花 "))
+    checkEvaluation(StringTrimLeft(s, "a花界"), "世界花花aa", create_row("aa花花世界花花aa"))
+    checkEvaluation(StringTrimLeft(s, "a世界"), "花花世界花花", create_row("花花世界花花"))
+    // scalastyle:on
+    checkEvaluation(StringTrimLeft(Literal.create(null, StringType), Literal("a")), null)
+    checkEvaluation(StringTrimLeft(Literal("a"), Literal.create(null, StringType)), null)
+  }
+
+  test("RTRIM") {
+    val s = 'a.string.at(0)
     checkEvaluation(StringTrimRight(Literal(" aa  ")), " aa", create_row(" abdef "))
+    checkEvaluation(StringTrimRight(Literal("a"), "a"), "", create_row(" abdef "))
+    checkEvaluation(StringTrimRight(Literal("ab"), "ab"), "", create_row(" abdef "))
+    checkEvaluation(StringTrimRight(Literal("aabbaaaa %"), "a %"), "aabb", create_row("def"))
     checkEvaluation(StringTrimRight(s), " abdef", create_row(" abdef "))
+    checkEvaluation(StringTrimRight(s, "a"), "abdef", create_row("abdefa"))
+    checkEvaluation(StringTrimRight(s, "abf de"), "", create_row(" aaabdefaaaa"))
+    checkEvaluation(StringTrimRight(s, "S*&"), "SSparkSQL", create_row("SSparkSQLS*"))
 
     // scalastyle:off
     // non ascii characters are not allowed in the source code, so we disable the scalastyle.
+    checkEvaluation(StringTrimRight(Literal("a"), "花"), "a", create_row(" abdef "))
+    checkEvaluation(StringTrimRight(Literal("花"), "a"), "花", create_row(" abdef "))
+    checkEvaluation(StringTrimRight(Literal("花花世界"), "界花世"), "", create_row(" abdef "))
     checkEvaluation(StringTrimRight(s), "  花花世界", create_row("  花花世界 "))
-    checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row("  花花世界 "))
-    checkEvaluation(StringTrim(s), "花花世界", create_row("  花花世界 "))
+    checkEvaluation(StringTrimRight(s, "花a#"), "花花世界", create_row("花花世界花花###aa花"))
+    checkEvaluation(StringTrimRight(s, "花"), "", create_row("花花花花"))
+    checkEvaluation(StringTrimRight(s, "花 界b@"), " 花花世", create_row(" 花花世 b界@花花 "))
     // scalastyle:on
-    checkEvaluation(StringTrim(Literal.create(null, StringType)), null)
-    checkEvaluation(StringTrimLeft(Literal.create(null, StringType)), null)
-    checkEvaluation(StringTrimRight(Literal.create(null, StringType)), null)
+    checkEvaluation(StringTrimRight(Literal("a"), Literal.create(null, StringType)), null)
+    checkEvaluation(StringTrimRight(Literal.create(null, StringType), Literal("a")), null)
   }
 
   test("FORMAT") {

http://git-wip-us.apache.org/repos/asf/spark/blob/c66d64b3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index b0d2fb2..306e6f2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -651,4 +651,22 @@ class PlanParserSuite extends AnalysisTest {
       )
     )
   }
+
+  test("TRIM function") {
+    intercept("select ltrim(both 'S' from 'SS abc S'", "missing ')' at '<EOF>'")
+    intercept("select rtrim(trailing 'S' from 'SS abc S'", "missing ')' at '<EOF>'")
+
+    assertEqual(
+      "SELECT TRIM(BOTH '@$%&( )abc' FROM '@ $ % & ()abc ' )",
+        OneRowRelation().select('TRIM.function("@$%&( )abc", "@ $ % & ()abc "))
+    )
+    assertEqual(
+      "SELECT TRIM(LEADING 'c []' FROM '[ ccccbcc ')",
+        OneRowRelation().select('ltrim.function("c []", "[ ccccbcc "))
+    )
+    assertEqual(
+      "SELECT TRIM(TRAILING 'c&^,.' FROM 'bc...,,,&&&ccc')",
+      OneRowRelation().select('rtrim.function("c&^,.", "bc...,,,&&&ccc"))
+    )
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c66d64b3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
index 76be6ee..cc80a41 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
@@ -51,7 +51,7 @@ class TableIdentifierParserSuite extends SparkFunSuite {
     "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger",
     "true", "truncate", "update", "user", "values", "with", "regexp", "rlike",
     "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float",
-    "int", "smallint", "timestamp", "at", "position")
+    "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing")
 
   val hiveStrictNonReservedKeyword = Seq("anti", "full", "inner", "left", "semi", "right",
     "natural", "union", "intersect", "except", "database", "on", "join", "cross", "select", "from",

http://git-wip-us.apache.org/repos/asf/spark/blob/c66d64b3/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 47324ed..c6d0d86 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2334,6 +2334,15 @@ object functions {
   def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr) }
 
   /**
+   * Trim any characters found in `trimString` from the left end of the specified string column.
+   * @group string_funcs
+   * @since 2.3.0
+   */
+  def ltrim(e: Column, trimString: String): Column = withExpr {
+    StringTrimLeft(e.expr, Literal(trimString))
+  }
+
+  /**
    * Extract a specific group matched by a Java regex, from the specified string column.
    * If the regex did not match, or the specified group did not match, an empty string is returned.
    *
@@ -2411,6 +2420,15 @@ object functions {
   def rtrim(e: Column): Column = withExpr { StringTrimRight(e.expr) }
 
   /**
+   * Trim any characters found in `trimString` from the right end of the specified string column.
+   * @group string_funcs
+   * @since 2.3.0
+   */
+  def rtrim(e: Column, trimString: String): Column = withExpr {
+    StringTrimRight(e.expr, Literal(trimString))
+  }
+
+  /**
    * Returns the soundex code for the specified expression.
    *
    * @group string_funcs
@@ -2478,6 +2496,15 @@ object functions {
   def trim(e: Column): Column = withExpr { StringTrim(e.expr) }
 
   /**
+   * Trim any characters found in `trimString` from both ends of the specified string column.
+   * @group string_funcs
+   * @since 2.3.0
+   */
+  def trim(e: Column, trimString: String): Column = withExpr {
+    StringTrim(e.expr, Literal(trimString))
+  }
+
+  /**
    * Converts a string column to upper case.
    *
    * @group string_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/c66d64b3/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index a12efc8..3d76b9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -161,13 +161,25 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
   }
 
   test("string trim functions") {
-    val df = Seq(("  example  ", "")).toDF("a", "b")
+    val df = Seq(("  example  ", "", "example")).toDF("a", "b", "c")
 
     checkAnswer(
       df.select(ltrim($"a"), rtrim($"a"), trim($"a")),
       Row("example  ", "  example", "example"))
 
     checkAnswer(
+      df.select(ltrim($"c", "e"), rtrim($"c", "e"), trim($"c", "e")),
+      Row("xample", "exampl", "xampl"))
+
+    checkAnswer(
+      df.select(ltrim($"c", "xe"), rtrim($"c", "emlp"), trim($"c", "elxp")),
+      Row("ample", "exa", "am"))
+
+    checkAnswer(
+      df.select(trim($"c", "xyz")),
+      Row("example"))
+
+    checkAnswer(
       df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"),
       Row("example  ", "  example", "example"))
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org