You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2020/02/12 07:00:44 UTC

[spark] 01/02: Revert "[SPARK-30625][SQL] Support `escape` as third parameter of the `like` function"

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git

commit 9c739358487acf3cd2d5171cebdd053c098aeb8c
Author: Maxim Gekk <ma...@gmail.com>
AuthorDate: Tue Feb 11 10:15:34 2020 -0800

    Revert "[SPARK-30625][SQL] Support `escape` as third parameter of the `like` function"
    
    In this PR, I propose to revert commit 8aebc80e0e67bcb1aa300b8c8b1a209159237632.
    
    See the concerns raised in https://github.com/apache/spark/pull/27355#issuecomment-584344438
    
    No user-facing change.
    
    Tested by the existing test suites.
    
    Closes #27531 from MaxGekk/revert-like-3-args.
    
    Authored-by: Maxim Gekk <ma...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../apache/spark/sql/catalyst/dsl/package.scala    |  2 +-
 .../catalyst/expressions/regexpExpressions.scala   | 85 +++++++---------------
 .../spark/sql/catalyst/parser/AstBuilder.scala     |  4 +-
 .../apache/spark/sql/DataFrameFunctionsSuite.scala | 15 ----
 4 files changed, 31 insertions(+), 75 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 4099808..b4a8baf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -99,7 +99,7 @@ package object dsl {
     }
 
     def like(other: Expression, escapeChar: Char = '\\'): Expression =
-      Like(expr, other, Literal(escapeChar.toString))
+      Like(expr, other, escapeChar)
     def rlike(other: Expression): Expression = RLike(expr, other)
     def contains(other: Expression): Expression = Contains(expr, other)
     def startsWith(other: Expression): Expression = StartsWith(expr, other)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index f84c476..32a653d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -22,7 +22,6 @@ import java.util.regex.{MatchResult, Pattern}
 
 import org.apache.commons.text.StringEscapeUtils
 
-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
@@ -30,19 +29,17 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 
-trait StringRegexExpression extends Expression
+abstract class StringRegexExpression extends BinaryExpression
   with ImplicitCastInputTypes with NullIntolerant {
 
-  def str: Expression
-  def pattern: Expression
-
   def escape(v: String): String
   def matches(regex: Pattern, str: String): Boolean
 
   override def dataType: DataType = BooleanType
+  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
 
   // try cache the pattern for Literal
-  private lazy val cache: Pattern = pattern match {
+  private lazy val cache: Pattern = right match {
     case Literal(value: String, StringType) => compile(value)
     case _ => null
   }
@@ -54,9 +51,10 @@ trait StringRegexExpression extends Expression
     Pattern.compile(escape(str))
   }
 
-  def nullSafeMatch(input1: Any, input2: Any): Any = {
-    val s = input2.asInstanceOf[UTF8String].toString
-    val regex = if (cache == null) compile(s) else cache
+  protected def pattern(str: String) = if (cache == null) compile(str) else cache
+
+  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val regex = pattern(input2.asInstanceOf[UTF8String].toString)
     if(regex == null) {
       null
     } else {
@@ -64,7 +62,7 @@ trait StringRegexExpression extends Expression
     }
   }
 
-  override def sql: String = s"${str.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${pattern.sql}"
+  override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}"
 }
 
 // scalastyle:off line.contains.tab
@@ -109,65 +107,46 @@ trait StringRegexExpression extends Expression
       true
       > SELECT '%SystemDrive%/Users/John' _FUNC_ '/%SystemDrive/%//Users%' ESCAPE '/';
       true
-      > SELECT _FUNC_('_Apache Spark_', '__%Spark__', '_');
-      true
   """,
   note = """
     Use RLIKE to match with standard regular expressions.
   """,
   since = "1.0.0")
 // scalastyle:on line.contains.tab
-case class Like(str: Expression, pattern: Expression, escape: Expression)
-  extends TernaryExpression with StringRegexExpression {
-
-  def this(str: Expression, pattern: Expression) = this(str, pattern, Literal("\\"))
-
-  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
-  override def children: Seq[Expression] = Seq(str, pattern, escape)
+case class Like(left: Expression, right: Expression, escapeChar: Char)
+  extends StringRegexExpression {
 
-  private lazy val escapeChar: Char = if (escape.foldable) {
-    escape.eval() match {
-      case s: UTF8String if s != null && s.numChars() == 1 => s.toString.charAt(0)
-      case s => throw new AnalysisException(
-        s"The 'escape' parameter must be a string literal of one char but it is $s.")
-    }
-  } else {
-    throw new AnalysisException("The 'escape' parameter must be a string literal.")
-  }
+  def this(left: Expression, right: Expression) = this(left, right, '\\')
 
   override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar)
 
   override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()
 
   override def toString: String = escapeChar match {
-    case '\\' => s"$str LIKE $pattern"
-    case c => s"$str LIKE $pattern ESCAPE '$c'"
-  }
-
-  protected override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
-    nullSafeMatch(input1, input2)
+    case '\\' => s"$left LIKE $right"
+    case c => s"$left LIKE $right ESCAPE '$c'"
   }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val patternClass = classOf[Pattern].getName
     val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex"
 
-    if (pattern.foldable) {
-      val patternVal = pattern.eval()
-      if (patternVal != null) {
+    if (right.foldable) {
+      val rVal = right.eval()
+      if (rVal != null) {
         val regexStr =
-          StringEscapeUtils.escapeJava(escape(patternVal.asInstanceOf[UTF8String].toString()))
-        val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern",
+          StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
+        val pattern = ctx.addMutableState(patternClass, "patternLike",
           v => s"""$v = $patternClass.compile("$regexStr");""")
 
         // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
-        val eval = str.genCode(ctx)
+        val eval = left.genCode(ctx)
         ev.copy(code = code"""
           ${eval.code}
           boolean ${ev.isNull} = ${eval.isNull};
           ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
           if (!${ev.isNull}) {
-            ${ev.value} = $compiledPattern.matcher(${eval.value}.toString()).matches();
+            ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches();
           }
         """)
       } else {
@@ -177,8 +156,8 @@ case class Like(str: Expression, pattern: Expression, escape: Expression)
         """)
       }
     } else {
-      val patternStr = ctx.freshName("patternStr")
-      val compiledPattern = ctx.freshName("compiledPattern")
+      val pattern = ctx.freshName("pattern")
+      val rightStr = ctx.freshName("rightStr")
       // We need double escape to avoid org.codehaus.commons.compiler.CompileException.
       // '\\' will cause exception 'Single quote must be backslash-escaped in character literal'.
       // '\"' will cause exception 'Line break in literal not allowed'.
@@ -187,12 +166,12 @@ case class Like(str: Expression, pattern: Expression, escape: Expression)
       } else {
         escapeChar
       }
-      nullSafeCodeGen(ctx, ev, (eval1, eval2, _) => {
+      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
         s"""
-          String $patternStr = $eval2.toString();
-          $patternClass $compiledPattern = $patternClass.compile(
-            $escapeFunc($patternStr, '$newEscapeChar'));
-          ${ev.value} = $compiledPattern.matcher($eval1.toString()).matches();
+          String $rightStr = $eval2.toString();
+          $patternClass $pattern = $patternClass.compile(
+            $escapeFunc($rightStr, '$newEscapeChar'));
+          ${ev.value} = $pattern.matcher($eval1.toString()).matches();
         """
       })
     }
@@ -231,20 +210,12 @@ case class Like(str: Expression, pattern: Expression, escape: Expression)
   """,
   since = "1.0.0")
 // scalastyle:on line.contains.tab
-case class RLike(left: Expression, right: Expression)
-  extends BinaryExpression with StringRegexExpression {
-
-  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
-
-  override def str: Expression = left
-  override def pattern: Expression = right
+case class RLike(left: Expression, right: Expression) extends StringRegexExpression {
 
   override def escape(v: String): String = v
   override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
   override def toString: String = s"$left RLIKE $right"
 
-  protected override def nullSafeEval(input1: Any, input2: Any): Any = nullSafeMatch(input1, input2)
-
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val patternClass = classOf[Pattern].getName
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 6fc65e1..62e5685 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1392,9 +1392,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
             throw new ParseException("Invalid escape string." +
               "Escape string must contains only one character.", ctx)
           }
-          str
+          str.charAt(0)
         }.getOrElse('\\')
-        invertIfNotDefined(Like(e, expression(ctx.pattern), Literal(escapeChar)))
+        invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar))
       case SqlBaseParser.RLIKE =>
         invertIfNotDefined(RLike(e, expression(ctx.pattern)))
       case SqlBaseParser.NULL if ctx.NOT != null =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 9e9d8c3..6012678 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -3560,21 +3560,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       Seq(Row(1)))
   }
 
-  test("the like function with the escape parameter") {
-    val df = Seq(("abc", "a_c", "!")).toDF("str", "pattern", "escape")
-    checkAnswer(df.selectExpr("like(str, pattern, '@')"), Row(true))
-
-    val longEscapeError = intercept[AnalysisException] {
-      df.selectExpr("like(str, pattern, '@%')").collect()
-    }.getMessage
-    assert(longEscapeError.contains("The 'escape' parameter must be a string literal of one char"))
-
-    val nonFoldableError = intercept[AnalysisException] {
-      df.selectExpr("like(str, pattern, escape)").collect()
-    }.getMessage
-    assert(nonFoldableError.contains("The 'escape' parameter must be a string literal"))
-  }
-
   test("SPARK-29462: Empty array of NullType for array function with no arguments") {
     Seq((true, StringType), (false, NullType)).foreach {
       case (arrayDefaultToString, expectedType) =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org