Posted to commits@spark.apache.org by ma...@apache.org on 2014/11/03 22:20:39 UTC

git commit: [SPARK-4202][SQL] Simple DSL support for Scala UDF

Repository: spark
Updated Branches:
  refs/heads/master 24544fbce -> c238fb423


[SPARK-4202][SQL] Simple DSL support for Scala UDF

This feature is based on an offline discussion with mengxr; hopefully it can be useful for the new MLlib pipeline API.

Given the following test snippet:

```scala
case class KeyValue(key: Int, value: String)
val testData = sc.parallelize(1 to 10).map(i => KeyValue(i, i.toString)).toSchemaRDD
val foo = (a: Int, b: String) => a.toString + b
```

the newly introduced DSL enables the following syntax:

```scala
import org.apache.spark.sql.catalyst.dsl._
testData.select(Star(None), foo.call('key, 'value) as 'result)
```

which is equivalent to:

```scala
testData.registerTempTable("testData")
sqlContext.registerFunction("foo", foo)
sql("SELECT *, foo(key, value) AS result FROM testData")
```
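
Under the hood (see the `package.scala` changes below), an implicit `functionToUdfBuilder` conversion wraps the function value in a `ScalaUdfBuilder`, whose `call` method builds a `ScalaUdf` expression with the return type derived from the function's `TypeTag`. A rough sketch of the desugaring, reusing `foo` from above (the intermediate names are illustrative only, not part of the API):

```scala
import org.apache.spark.sql.catalyst.dsl._

val foo = (a: Int, b: String) => a.toString + b

// `foo.call('key, 'value)` desugars roughly into these two steps:
val builder: ScalaUdfBuilder[String] = foo  // implicit functionToUdfBuilder kicks in
val udfExpr = builder.call('key, 'value)    // builds a ScalaUdf expression over the columns
```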

Author: Cheng Lian <li...@databricks.com>

Closes #3067 from liancheng/udf-dsl and squashes the following commits:

f132818 [Cheng Lian] Adds DSL support for Scala UDF


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c238fb42
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c238fb42
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c238fb42

Branch: refs/heads/master
Commit: c238fb423d1011bd1b1e6201d769b72e52664fc6
Parents: 24544fb
Author: Cheng Lian <li...@databricks.com>
Authored: Mon Nov 3 13:20:33 2014 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Mon Nov 3 13:20:33 2014 -0800

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/dsl/package.scala | 59 ++++++++++++++++++++
 .../org/apache/spark/sql/DslQuerySuite.scala    | 17 ++++--
 2 files changed, 72 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c238fb42/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 7e6d770..3314e15 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
 import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 import scala.language.implicitConversions
+import scala.reflect.runtime.universe.{TypeTag, typeTag}
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions._
@@ -285,4 +286,62 @@ package object dsl {
       def writeToFile(path: String) = WriteToFile(path, logicalPlan)
     }
   }
+
+  case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) {
+    def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args)
+  }
+
+  // scalastyle:off
+  /** functionToUdfBuilder 1-22 were generated by this script
+
+    (1 to 22).map { x =>
+      val argTypes = Seq.fill(x)("_").mkString(", ")
+      s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]) = ScalaUdfBuilder(func)"
+    }
+  */
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+  implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+  // scalastyle:on
 }
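
For reference, the `functionToUdfBuilder` overloads above were generated by the script embedded in the comment; a minimal sketch of running it in a Scala REPL to reproduce them (the `.foreach(println)` is added here for illustration):

```scala
(1 to 22).map { x =>
  val argTypes = Seq.fill(x)("_").mkString(", ")
  s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]) = ScalaUdfBuilder(func)"
}.foreach(println)
```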

http://git-wip-us.apache.org/repos/asf/spark/blob/c238fb42/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 45e58af..e70ad89 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.test._
 
 /* Implicits */
-import TestSQLContext._
+import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.test.TestSQLContext._
 
 class DslQuerySuite extends QueryTest {
-  import TestData._
+  import org.apache.spark.sql.TestData._
 
   test("table scan") {
     checkAnswer(
@@ -216,4 +215,14 @@ class DslQuerySuite extends QueryTest {
       (4, "d") :: Nil)
     checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
   }
+
+  test("udf") {
+    val foo = (a: Int, b: String) => a.toString + b
+
+    checkAnswer(
+      // SELECT *, foo(key, value) FROM testData
+      testData.select(Star(None), foo.call('key, 'value)).limit(3),
+      (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil
+    )
+  }
 }

