You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/07/09 18:23:46 UTC

spark git commit: [SPARK-8830] [SQL] native levenshtein distance

Repository: spark
Updated Branches:
  refs/heads/master 23448a9e9 -> a1964e9d9


[SPARK-8830] [SQL] native levenshtein distance

Jira: https://issues.apache.org/jira/browse/SPARK-8830

rxin and HuJiayin, can you have a look at it?

Author: Tarek Auel <ta...@googlemail.com>

Closes #7236 from tarekauel/native-levenshtein-distance and squashes the following commits:

ee4c4de [Tarek Auel] [SPARK-8830] implemented improvement proposals
c252e71 [Tarek Auel] [SPARK-8830] removed charAt; use unsafe method for byte array comparison
ddf2222 [Tarek Auel] Merge branch 'master' into native-levenshtein-distance
179920a [Tarek Auel] [SPARK-8830] added description
5e9ed54 [Tarek Auel] [SPARK-8830] removed StringUtils import
dce4308 [Tarek Auel] [SPARK-8830] native levenshtein distance


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a1964e9d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a1964e9d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a1964e9d

Branch: refs/heads/master
Commit: a1964e9d902bb31f001893da8bc81f6dce08c908
Parents: 23448a9
Author: Tarek Auel <ta...@googlemail.com>
Authored: Thu Jul 9 09:22:24 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Thu Jul 9 09:23:35 2015 -0700

----------------------------------------------------------------------
 .../catalyst/expressions/stringOperations.scala |  9 ++-
 .../expressions/StringFunctionsSuite.scala      |  5 ++
 .../apache/spark/unsafe/types/UTF8String.java   | 66 +++++++++++++++++++-
 .../spark/unsafe/types/UTF8StringSuite.java     | 24 +++++++
 4 files changed, 97 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a1964e9d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 47fc7cd..57f4364 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -284,13 +284,12 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
 
   override def dataType: DataType = IntegerType
 
-  protected override def nullSafeEval(input1: Any, input2: Any): Any =
-    StringUtils.getLevenshteinDistance(input1.toString, input2.toString)
+  protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any =
+    leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String])
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    val stringUtils = classOf[StringUtils].getName
-    defineCodeGen(ctx, ev, (left, right) =>
-      s"$stringUtils.getLevenshteinDistance($left.toString(), $right.toString())")
+    nullSafeCodeGen(ctx, ev, (left, right) =>
+      s"${ev.primitive} = $left.levenshteinDistance($right);")
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a1964e9d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
index 1efbe1a..69bef1c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
@@ -282,5 +282,10 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Levenshtein(Literal("abc"), Literal("abc")), 0)
     checkEvaluation(Levenshtein(Literal("kitten"), Literal("sitting")), 3)
     checkEvaluation(Levenshtein(Literal("frog"), Literal("fog")), 1)
+    // scalastyle:off
+    // Non-ASCII characters are not allowed in the code, so we disable scalastyle here.
+    checkEvaluation(Levenshtein(Literal("千世"), Literal("fog")), 3)
+    checkEvaluation(Levenshtein(Literal("世界千世"), Literal("大a界b")), 4)
+    // scalastyle:on
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a1964e9d/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index d2a2509..847d80a 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -99,8 +99,6 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
 
   /**
    * Returns the number of code points in it.
-   *
-   * This is only used by Substring() when `start` is negative.
    */
   public int numChars() {
     int len = 0;
@@ -254,6 +252,70 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
     }
   }
 
+  /**
+   * Levenshtein distance is a metric for measuring the distance of two strings. The distance is
+   * defined by the minimum number of single-character edits (i.e. insertions, deletions or
+   * substitutions) that are required to change one of the strings into the other.
+   */
+  public int levenshteinDistance(UTF8String other) {
+    // Implementation adapted from org.apache.commons.lang3.StringUtils.getLevenshteinDistance
+
+    int n = numChars();
+    int m = other.numChars();
+
+    if (n == 0) {
+      return m;
+    } else if (m == 0) {
+      return n;
+    }
+
+    UTF8String s, t;
+
+    if (n <= m) {
+      s = this;
+      t = other;
+    } else {
+      s = other;
+      t = this;
+      int swap;
+      swap = n;
+      n = m;
+      m = swap;
+    }
+
+    int p[] = new int[n + 1];
+    int d[] = new int[n + 1];
+    int swap[];
+
+    int i, i_bytes, j, j_bytes, num_bytes_j, cost;
+
+    for (i = 0; i <= n; i++) {
+      p[i] = i;
+    }
+
+    for (j = 0, j_bytes = 0; j < m; j_bytes += num_bytes_j, j++) {
+      num_bytes_j = numBytesForFirstByte(t.getByte(j_bytes));
+      d[0] = j + 1;
+
+      for (i = 0, i_bytes = 0; i < n; i_bytes += numBytesForFirstByte(s.getByte(i_bytes)), i++) {
+        if (s.getByte(i_bytes) != t.getByte(j_bytes) ||
+              num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) {
+          cost = 1;
+        } else {
+          cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base,
+              s.offset + i_bytes, num_bytes_j)) ? 0 : 1;
+        }
+        d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost);
+      }
+
+      swap = p;
+      p = d;
+      d = swap;
+    }
+
+    return p[n];
+  }
+
   @Override
   public int hashCode() {
     int result = 1;

http://git-wip-us.apache.org/repos/asf/spark/blob/a1964e9d/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
----------------------------------------------------------------------
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 8ec69eb..fb463ba 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -128,4 +128,28 @@ public class UTF8StringSuite {
     assertEquals(fromString("数据砖头").substring(3, 5), fromString("头"));
     assertEquals(fromString("ߵ梷").substring(0, 2), fromString("ߵ梷"));
   }
+  
+  @Test
+  public void levenshteinDistance() {
+    assertEquals(
+        UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("")), 0);
+    assertEquals(
+        UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("a")), 1);
+    assertEquals(
+        UTF8String.fromString("aaapppp").levenshteinDistance(UTF8String.fromString("")), 7);
+    assertEquals(
+        UTF8String.fromString("frog").levenshteinDistance(UTF8String.fromString("fog")), 1);
+    assertEquals(
+        UTF8String.fromString("fly").levenshteinDistance(UTF8String.fromString("ant")),3);
+    assertEquals(
+        UTF8String.fromString("elephant").levenshteinDistance(UTF8String.fromString("hippo")), 7);
+    assertEquals(
+        UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("elephant")), 7);
+    assertEquals(
+        UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("zzzzzzzz")), 8);
+    assertEquals(
+        UTF8String.fromString("hello").levenshteinDistance(UTF8String.fromString("hallo")),1);
+    assertEquals(
+        UTF8String.fromString("世界千世").levenshteinDistance(UTF8String.fromString("千a世b")),4);
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org