You are viewing a plain text version of this content. The link to its canonical version was not preserved in this plain-text rendering.
Posted to commits@spark.apache.org by ma...@apache.org on 2014/07/15 23:12:02 UTC

git commit: SPARK-2407: Added internal implementation of SQL SUBSTR()

Repository: spark
Updated Branches:
  refs/heads/master 8af46d584 -> 61de65bc6


SPARK-2407: Added internal implementation of SQL SUBSTR()

This replaces the Hive UDF for SUBSTR(ING) with an implementation in Catalyst
and adds tests to verify correct operation.

Author: William Benton <wi...@redhat.com>

Closes #1359 from willb/internalSqlSubstring and squashes the following commits:

ccedc47 [William Benton] Fixed too-long line.
a30a037 [William Benton] replace view bounds with implicit parameters
ec35c80 [William Benton] Adds fixes from review:
4f3bfdb [William Benton] Added internal implementation of SQL SUBSTR()


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/61de65bc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/61de65bc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/61de65bc

Branch: refs/heads/master
Commit: 61de65bc69f9a5fc396b76713193c6415436d452
Parents: 8af46d5
Author: William Benton <wi...@redhat.com>
Authored: Tue Jul 15 14:11:57 2014 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Jul 15 14:11:57 2014 -0700

----------------------------------------------------------------------
 .../catalyst/expressions/stringOperations.scala | 77 +++++++++++++++++++-
 .../expressions/ExpressionEvaluationSuite.scala | 49 +++++++++++++
 .../org/apache/spark/sql/hive/HiveQl.scala      |  5 ++
 3 files changed, 128 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/61de65bc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index b385053..4bd7bf5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.util.regex.Pattern
 
-import org.apache.spark.sql.catalyst.types.DataType
-import org.apache.spark.sql.catalyst.types.StringType
-import org.apache.spark.sql.catalyst.types.BooleanType
+import scala.collection.IndexedSeqOptimized
+
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.types.{BinaryType, BooleanType, DataType, StringType}
 
 trait StringRegexExpression {
   self: BinaryExpression =>
@@ -205,3 +207,72 @@ case class EndsWith(left: Expression, right: Expression)
     extends BinaryExpression with StringComparison {
   def compare(l: String, r: String) = l.endsWith(r)
 }
+
+/**
+ * A function that takes a substring of its first argument starting at a given position.
+ * Defined for String and Binary types; null if any argument evaluates to null.
+ */
+case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression {
+
+  type EvaluatedType = Any
+
+  def nullable: Boolean = true
+  def dataType: DataType = {
+    if (!resolved) {
+      throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved")
+    }
+    if (str.dataType == BinaryType) str.dataType else StringType
+  }
+
+  def references = children.flatMap(_.references).toSet
+
+  override def children = str :: pos :: len :: Nil
+
+  /**
+   * Slices `str` with SQL SUBSTR semantics. Positions are one-based; 0 means the first
+   * element and a negative position counts back from the end. Bounds outside the sequence
+   * are clamped by `slice`, so an out-of-range request yields a shorter or empty result.
+   */
+  @inline
+  def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int)
+      (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = {
+    val seqLen = str.length
+    val start = startPos match {
+      case p if p > 0 => p - 1
+      case n if n < 0 => seqLen + n
+      case _ => 0
+    }
+    // Compute the end in Long: `start + sliceLen` can leave the Int range (e.g. position 3
+    // with a length near Integer.MAX_VALUE) and wrap. Any end <= 0 selects nothing, so
+    // clamping into [0, Int.MaxValue] preserves every non-overflowing result.
+    val end = math.max(0L, math.min(start.toLong + sliceLen, Int.MaxValue.toLong)).toInt
+    str.slice(start, end)
+  }
+
+  /**
+   * Returns null if the string, position, or length evaluates to null. Array[Byte] input
+   * is sliced byte-wise; any other value is sliced as its `toString`.
+   */
+  override def eval(input: Row): Any = {
+    val string = str.eval(input)
+    val po = pos.eval(input)
+    val ln = len.eval(input)
+    if ((string == null) || (po == null) || (ln == null)) {
+      null
+    } else {
+      val start = po.asInstanceOf[Int]
+      val length = ln.asInstanceOf[Int]
+      string match {
+        case ba: Array[Byte] => slice(ba, start, length)
+        case other => slice(other.toString, start, length)
+      }
+    }
+  }
+
+  // HiveQl encodes 2-arg SUBSTR as a Literal(Integer.MAX_VALUE) length. `len` is an
+  // Expression, so match on the literal's value; comparing `len` itself to an Int never holds.
+  override def toString = len match {
+    case Literal(v, _) if v == Integer.MAX_VALUE => s"SUBSTR($str, $pos)"
+    case _ => s"SUBSTR($str, $pos, $len)"
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/61de65bc/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 84d7281..f1d7aed 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -466,5 +466,54 @@ class ExpressionEvaluationSuite extends FunSuite {
     checkEvaluation(c1 === c2, false, row)
     checkEvaluation(c1 !== c2, true, row)
   }
+  
+  test("Substring") {
+    // Column 0 is a string; column 1 holds the same text as bytes (not exercised below).
+    val row = new GenericRow(Array[Any]("example", "example".toArray.map(_.toByte)))
+    val s = 'a.string.at(0)
+    // Builds SUBSTR(s, pos, len); a null argument becomes a null IntegerType literal.
+    def substr(pos: Any, len: Any) =
+      Substring(s, Literal(pos, IntegerType), Literal(len, IntegerType))
+
+    // substring from zero position with less-than-full length
+    checkEvaluation(substr(0, 2), "ex", row)
+    checkEvaluation(substr(1, 2), "ex", row)
+
+    // substring from zero position with full length
+    checkEvaluation(substr(0, 7), "example", row)
+    checkEvaluation(substr(1, 7), "example", row)
+
+    // substring from zero position with greater-than-full length
+    checkEvaluation(substr(0, 100), "example", row)
+    checkEvaluation(substr(1, 100), "example", row)
+
+    // substring from nonzero position with less-than-full, full, and greater-than-full length
+    checkEvaluation(substr(2, 2), "xa", row)
+    checkEvaluation(substr(2, 6), "xample", row)
+    checkEvaluation(substr(2, 100), "xample", row)
+
+    // negative positions count back from the end of the string
+    checkEvaluation(substr(-3, 2), "pl", row)
+    checkEvaluation(substr(-7, 7), "example", row)
+
+    // zero-length substring (within and beyond string bounds)
+    checkEvaluation(substr(0, 0), "", row)
+    checkEvaluation(substr(100, 4), "", row)
+
+    // substring(null, _, _) -> null
+    checkEvaluation(substr(100, 4), null, new GenericRow(Array[Any](null)))
+
+    // substring(_, null, _) -> null
+    checkEvaluation(substr(null, 4), null, row)
+
+    // substring(_, _, null) -> null
+    checkEvaluation(substr(100, null), null, row)
+
+    // 2-arg substring (HiveQl passes Integer.MAX_VALUE as the length)
+    checkEvaluation(substr(0, Integer.MAX_VALUE), "example", row)
+    checkEvaluation(substr(1, Integer.MAX_VALUE), "example", row)
+    checkEvaluation(substr(2, Integer.MAX_VALUE), "xample", row)
+    checkEvaluation(substr(-3, Integer.MAX_VALUE), "ple", row)
+  }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/61de65bc/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 56aa27a..300e249 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -860,6 +860,7 @@ private[hive] object HiveQl {
   val BETWEEN = "(?i)BETWEEN".r
   val WHEN = "(?i)WHEN".r
   val CASE = "(?i)CASE".r
+  val SUBSTR = """(?i)SUBSTR(?:ING)?""".r
 
   protected def nodeToExpr(node: Node): Expression = node match {
     /* Attribute References */
@@ -984,6 +985,10 @@ private[hive] object HiveQl {
 
     /* Other functions */
     case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand
+    case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => 
+      Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType))
+    case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => 
+      Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length))
 
     /* UDFs - Must be last otherwise will preempt built in functions */
     case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>