You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/07/09 18:23:46 UTC
spark git commit: [SPARK-8830] [SQL] native levenshtein distance
Repository: spark
Updated Branches:
refs/heads/master 23448a9e9 -> a1964e9d9
[SPARK-8830] [SQL] native levenshtein distance
Jira: https://issues.apache.org/jira/browse/SPARK-8830
rxin and HuJiayin, can you have a look at it?
Author: Tarek Auel <ta...@googlemail.com>
Closes #7236 from tarekauel/native-levenshtein-distance and squashes the following commits:
ee4c4de [Tarek Auel] [SPARK-8830] implemented improvement proposals
c252e71 [Tarek Auel] [SPARK-8830] removed charAt; use unsafe method for byte array comparison
ddf2222 [Tarek Auel] Merge branch 'master' into native-levenshtein-distance
179920a [Tarek Auel] [SPARK-8830] added description
5e9ed54 [Tarek Auel] [SPARK-8830] removed StringUtils import
dce4308 [Tarek Auel] [SPARK-8830] native levenshtein distance
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a1964e9d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a1964e9d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a1964e9d
Branch: refs/heads/master
Commit: a1964e9d902bb31f001893da8bc81f6dce08c908
Parents: 23448a9
Author: Tarek Auel <ta...@googlemail.com>
Authored: Thu Jul 9 09:22:24 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Thu Jul 9 09:23:35 2015 -0700
----------------------------------------------------------------------
.../catalyst/expressions/stringOperations.scala | 9 ++-
.../expressions/StringFunctionsSuite.scala | 5 ++
.../apache/spark/unsafe/types/UTF8String.java | 66 +++++++++++++++++++-
.../spark/unsafe/types/UTF8StringSuite.java | 24 +++++++
4 files changed, 97 insertions(+), 7 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/a1964e9d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 47fc7cd..57f4364 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -284,13 +284,12 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
override def dataType: DataType = IntegerType
- protected override def nullSafeEval(input1: Any, input2: Any): Any =
- StringUtils.getLevenshteinDistance(input1.toString, input2.toString)
+ // Both operands are cast to UTF8String and compared on their UTF-8 bytes by
+ // UTF8String.levenshteinDistance, avoiding the java.lang.String conversion of the removed
+ // StringUtils version.
+ protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any =
+ leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String])
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val stringUtils = classOf[StringUtils].getName
- defineCodeGen(ctx, ev, (left, right) =>
- s"$stringUtils.getLevenshteinDistance($left.toString(), $right.toString())")
+ // The generated code calls UTF8String.levenshteinDistance directly on the operand values
+ // (no toString() conversion) and assigns the result to ev.primitive.
+ nullSafeCodeGen(ctx, ev, (left, right) =>
+ s"${ev.primitive} = $left.levenshteinDistance($right);")
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/a1964e9d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
index 1efbe1a..69bef1c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
@@ -282,5 +282,10 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Levenshtein(Literal("abc"), Literal("abc")), 0)
checkEvaluation(Levenshtein(Literal("kitten"), Literal("sitting")), 3)
checkEvaluation(Levenshtein(Literal("frog"), Literal("fog")), 1)
+ // scalastyle:off
+ // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+ checkEvaluation(Levenshtein(Literal("千世"), Literal("fog")), 3)
+ checkEvaluation(Levenshtein(Literal("世界千世"), Literal("大a界b")), 4)
+ // scalastyle:on
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/a1964e9d/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index d2a2509..847d80a 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -99,8 +99,6 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
/**
* Returns the number of code points in it.
- *
- * This is only used by Substring() when `start` is negative.
*/
public int numChars() {
int len = 0;
@@ -254,6 +252,70 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
}
}
+ /**
+ * Levenshtein distance is a metric for measuring the distance of two strings. The distance is
+ * defined by the minimum number of single-character edits (i.e. insertions, deletions or
+ * substitutions) that are required to change one of the strings into the other.
+ *
+ * Characters are code points: each one is compared as its complete UTF-8 byte sequence, so a
+ * multi-byte character counts as a single edit rather than one edit per byte.
+ *
+ * @param other the string to compare against
+ * @return the number of character insertions, deletions or substitutions needed to turn this
+ *         string into {@code other}
+ */
+ public int levenshteinDistance(UTF8String other) {
+ // Implementation adopted from org.apache.commons.lang3.StringUtils.getLevenshteinDistance
+
+ // Lengths in characters (not bytes); the DP below is indexed by character position.
+ int n = numChars();
+ int m = other.numChars();
+
+ // The distance to an empty string is simply the other string's length.
+ if (n == 0) {
+ return m;
+ } else if (m == 0) {
+ return n;
+ }
+
+ UTF8String s, t;
+
+ // Make s the shorter string (length n) so the two DP rows need only n + 1 slots.
+ if (n <= m) {
+ s = this;
+ t = other;
+ } else {
+ s = other;
+ t = this;
+ int swap;
+ swap = n;
+ n = m;
+ m = swap;
+ }
+
+ // p is the previous DP row and d the row being filled; they are swapped after each row.
+ int p[] = new int[n + 1];
+ int d[] = new int[n + 1];
+ int swap[];
+
+ // i / j are character indices into s / t; i_bytes / j_bytes are the matching byte offsets.
+ int i, i_bytes, j, j_bytes, num_bytes_j, cost;
+
+ // Initial row: turning the first i characters of s into "" takes i edits.
+ for (i = 0; i <= n; i++) {
+ p[i] = i;
+ }
+
+ for (j = 0, j_bytes = 0; j < m; j_bytes += num_bytes_j, j++) {
+ // Encoded byte length of the j-th character of t, derived from its first byte.
+ num_bytes_j = numBytesForFirstByte(t.getByte(j_bytes));
+ d[0] = j + 1;
+
+ for (i = 0, i_bytes = 0; i < n; i_bytes += numBytesForFirstByte(s.getByte(i_bytes)), i++) {
+ // Cheap rejection on first byte / encoded length before comparing whole byte sequences.
+ // NOTE(review): the length check looks implied by equal first bytes, presumably kept as a
+ // defensive guard -- confirm numBytesForFirstByte depends only on its argument.
+ if (s.getByte(i_bytes) != t.getByte(j_bytes) ||
+ num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) {
+ cost = 1;
+ } else {
+ cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base,
+ s.offset + i_bytes, num_bytes_j)) ? 0 : 1;
+ }
+ // Minimum over the two single-character insert/delete neighbours (+1) and the
+ // diagonal (+cost: 0 when the characters match, 1 for a substitution).
+ d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost);
+ }
+
+ // The row just filled becomes the previous row for the next character of t.
+ swap = p;
+ p = d;
+ d = swap;
+ }
+
+ // After the final swap p holds the last computed row; its last cell is the full distance.
+ return p[n];
+ }
+
@Override
public int hashCode() {
int result = 1;
http://git-wip-us.apache.org/repos/asf/spark/blob/a1964e9d/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
----------------------------------------------------------------------
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 8ec69eb..fb463ba 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -128,4 +128,28 @@ public class UTF8StringSuite {
assertEquals(fromString("数据砖头").substring(3, 5), fromString("头"));
assertEquals(fromString("ߵ梷").substring(0, 2), fromString("ߵ梷"));
}
+
+ @Test
+ public void levenshteinDistance() {
+ // JUnit's assertEquals takes (expected, actual); keeping that order makes failure messages
+ // report the expected and actual distances the right way round.
+ assertEquals(0,
+ UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("")));
+ assertEquals(1,
+ UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("a")));
+ assertEquals(7,
+ UTF8String.fromString("aaapppp").levenshteinDistance(UTF8String.fromString("")));
+ assertEquals(1,
+ UTF8String.fromString("frog").levenshteinDistance(UTF8String.fromString("fog")));
+ assertEquals(3,
+ UTF8String.fromString("fly").levenshteinDistance(UTF8String.fromString("ant")));
+ assertEquals(7,
+ UTF8String.fromString("elephant").levenshteinDistance(UTF8String.fromString("hippo")));
+ assertEquals(7,
+ UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("elephant")));
+ assertEquals(8,
+ UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("zzzzzzzz")));
+ assertEquals(1,
+ UTF8String.fromString("hello").levenshteinDistance(UTF8String.fromString("hallo")));
+ assertEquals(4,
+ UTF8String.fromString("世界千世").levenshteinDistance(UTF8String.fromString("千a世b")));
+ // U+4E16 and U+4E17 share the UTF-8 lead byte 0xE4 and the 3-byte length, differing only in
+ // the last byte, so this exercises the full byte-sequence comparison, not just the
+ // first-byte shortcut.
+ assertEquals(1,
+ UTF8String.fromString("\u4e16").levenshteinDistance(UTF8String.fromString("\u4e17")));
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org