You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/01/28 01:08:43 UTC
[2/5] spark git commit: [SPARK-5097][SQL] DataFrame
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala
new file mode 100644
index 0000000..29c3d26
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/dsl/package.scala
@@ -0,0 +1,495 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.sql.{Timestamp, Date}
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.{TypeTag, typeTag}
+
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.DataType
+
+
+package object dsl {
+
+ implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
+
+ /** Converts $"col name" into a [[Column]]. */
+ implicit class StringToColumn(val sc: StringContext) extends AnyVal {
+ def $(args: Any*): ColumnName = {
+ new ColumnName(sc.s(args :_*))
+ }
+ }
+
+ private[this] implicit def toColumn(expr: Expression): Column = new Column(expr)
+
+ def sum(e: Column): Column = Sum(e.expr)
+ def sumDistinct(e: Column): Column = SumDistinct(e.expr)
+ def count(e: Column): Column = Count(e.expr)
+
+ @scala.annotation.varargs
+ def countDistinct(expr: Column, exprs: Column*): Column =
+ CountDistinct((expr +: exprs).map(_.expr))
+
+ def avg(e: Column): Column = Average(e.expr)
+ def first(e: Column): Column = First(e.expr)
+ def last(e: Column): Column = Last(e.expr)
+ def min(e: Column): Column = Min(e.expr)
+ def max(e: Column): Column = Max(e.expr)
+ def upper(e: Column): Column = Upper(e.expr)
+ def lower(e: Column): Column = Lower(e.expr)
+ def sqrt(e: Column): Column = Sqrt(e.expr)
+ def abs(e: Column): Column = Abs(e.expr)
+
+ // scalastyle:off
+
+ object literals {
+
+ implicit def booleanToLiteral(b: Boolean): Column = Literal(b)
+
+ implicit def byteToLiteral(b: Byte): Column = Literal(b)
+
+ implicit def shortToLiteral(s: Short): Column = Literal(s)
+
+ implicit def intToLiteral(i: Int): Column = Literal(i)
+
+ implicit def longToLiteral(l: Long): Column = Literal(l)
+
+ implicit def floatToLiteral(f: Float): Column = Literal(f)
+
+ implicit def doubleToLiteral(d: Double): Column = Literal(d)
+
+ implicit def stringToLiteral(s: String): Column = Literal(s)
+
+ implicit def dateToLiteral(d: Date): Column = Literal(d)
+
+ implicit def bigDecimalToLiteral(d: BigDecimal): Column = Literal(d.underlying())
+
+ implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Column = Literal(d)
+
+ implicit def timestampToLiteral(t: Timestamp): Column = Literal(t)
+
+ implicit def binaryToLiteral(a: Array[Byte]): Column = Literal(a)
+ }
+
+
+ /* Use the following code to generate:
+ (0 to 22).map { x =>
+ val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"})
+ val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _)
+ val args = (1 to x).map(i => s"arg$i: Column").mkString(", ")
+ val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ")
+ println(s"""
+ /**
+ * Call a Scala function of ${x} arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[$typeTags](f: Function$x[$types]${if (args.length > 0) ", " + args else ""}): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq($argsInUdf))
+ }""")
+ }
+
+ (0 to 22).map { x =>
+ val args = (1 to x).map(i => s"arg$i: Column").mkString(", ")
+ val fTypes = Seq.fill(x + 1)("_").mkString(", ")
+ val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ")
+ println(s"""
+ /**
+ * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = {
+ ScalaUdf(f, returnType, Seq($argsInUdf))
+ }""")
+ }
+ }
+ */
+ /**
+ * Call a Scala function of 0 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag](f: Function0[RT]): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq())
+ }
+
+ /**
+ * Call a Scala function of 1 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT], arg1: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr))
+ }
+
+ /**
+ * Call a Scala function of 2 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT], arg1: Column, arg2: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr))
+ }
+
+ /**
+ * Call a Scala function of 3 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT], arg1: Column, arg2: Column, arg3: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr))
+ }
+
+ /**
+ * Call a Scala function of 4 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr))
+ }
+
+ /**
+ * Call a Scala function of 5 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr))
+ }
+
+ /**
+ * Call a Scala function of 6 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr))
+ }
+
+ /**
+ * Call a Scala function of 7 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr))
+ }
+
+ /**
+ * Call a Scala function of 8 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr))
+ }
+
+ /**
+ * Call a Scala function of 9 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr))
+ }
+
+ /**
+ * Call a Scala function of 10 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr))
+ }
+
+ /**
+ * Call a Scala function of 11 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](f: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr))
+ }
+
+ /**
+ * Call a Scala function of 12 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](f: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr))
+ }
+
+ /**
+ * Call a Scala function of 13 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](f: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr))
+ }
+
+ /**
+ * Call a Scala function of 14 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](f: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr))
+ }
+
+ /**
+ * Call a Scala function of 15 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](f: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr))
+ }
+
+ /**
+ * Call a Scala function of 16 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](f: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr))
+ }
+
+ /**
+ * Call a Scala function of 17 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](f: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr))
+ }
+
+ /**
+ * Call a Scala function of 18 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](f: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr))
+ }
+
+ /**
+ * Call a Scala function of 19 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](f: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr))
+ }
+
+ /**
+ * Call a Scala function of 20 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](f: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr))
+ }
+
+ /**
+ * Call a Scala function of 21 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](f: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr))
+ }
+
+ /**
+ * Call a Scala function of 22 arguments as user-defined function (UDF), and automatically
+ * infer the data types based on the function's signature.
+ */
+ def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](f: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = {
+ ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr))
+ }
+
+ //////////////////////////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Call a Scala function of 0 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function0[_], returnType: DataType): Column = {
+ ScalaUdf(f, returnType, Seq())
+ }
+
+ /**
+ * Call a Scala function of 1 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr))
+ }
+
+ /**
+ * Call a Scala function of 2 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr))
+ }
+
+ /**
+ * Call a Scala function of 3 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr))
+ }
+
+ /**
+ * Call a Scala function of 4 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr))
+ }
+
+ /**
+ * Call a Scala function of 5 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr))
+ }
+
+ /**
+ * Call a Scala function of 6 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr))
+ }
+
+ /**
+ * Call a Scala function of 7 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr))
+ }
+
+ /**
+ * Call a Scala function of 8 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr))
+ }
+
+ /**
+ * Call a Scala function of 9 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr))
+ }
+
+ /**
+ * Call a Scala function of 10 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr))
+ }
+
+ /**
+ * Call a Scala function of 11 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr))
+ }
+
+ /**
+ * Call a Scala function of 12 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr))
+ }
+
+ /**
+ * Call a Scala function of 13 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr))
+ }
+
+ /**
+ * Call a Scala function of 14 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr))
+ }
+
+ /**
+ * Call a Scala function of 15 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr))
+ }
+
+ /**
+ * Call a Scala function of 16 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr))
+ }
+
+ /**
+ * Call a Scala function of 17 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr))
+ }
+
+ /**
+ * Call a Scala function of 18 arguments as user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr))
+ }
+
+ /**
+ * Call a Scala function of 19 arguments as a user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr))
+ }
+
+ /**
+ * Call a Scala function of 20 arguments as a user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr))
+ }
+
+ /**
+ * Call a Scala function of 21 arguments as a user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr))
+ }
+
+ /**
+ * Call a Scala function of 22 arguments as a user-defined function (UDF). This requires
+ * you to specify the return data type.
+ */
+ def callUDF(f: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = {
+ ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr))
+ }
+
+ // scalastyle:on
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 52a31f0..6fba76c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Row, Attribute}
import org.apache.spark.sql.catalyst.plans.logical
@@ -137,7 +137,9 @@ case class CacheTableCommand(
isLazy: Boolean) extends RunnableCommand {
override def run(sqlContext: SQLContext) = {
- plan.foreach(p => new SchemaRDD(sqlContext, p).registerTempTable(tableName))
+ plan.foreach { logicalPlan =>
+ sqlContext.registerRDDAsTable(new DataFrame(sqlContext, logicalPlan), tableName)
+ }
sqlContext.cacheTable(tableName)
if (!isLazy) {
@@ -159,7 +161,7 @@ case class CacheTableCommand(
case class UncacheTableCommand(tableName: String) extends RunnableCommand {
override def run(sqlContext: SQLContext) = {
- sqlContext.table(tableName).unpersist()
+ sqlContext.table(tableName).unpersist(blocking = false)
Seq.empty[Row]
}
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 4d7e338..aeb0960 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.HashSet
import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext._
-import org.apache.spark.sql.{SchemaRDD, Row}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.types._
@@ -42,7 +42,7 @@ package object debug {
* Augments SchemaRDDs with debug methods.
*/
@DeveloperApi
- implicit class DebugQuery(query: SchemaRDD) {
+ implicit class DebugQuery(query: DataFrame) {
def debug(): Unit = {
val plan = query.queryExecution.executedPlan
val visited = new collection.mutable.HashSet[TreeNodeRef]()
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/main/scala/org/apache/spark/sql/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index 6dd39be..7c49b52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -37,5 +37,5 @@ package object sql {
* Converts a logical plan into zero or more SparkPlans.
*/
@DeveloperApi
- type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan]
+ protected[sql] type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan]
}
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
index 02ce1b3..0b312ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
-import org.apache.spark.sql.{SQLContext, SchemaRDD}
+import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.util
import org.apache.spark.util.Utils
@@ -100,7 +100,7 @@ trait ParquetTest {
*/
protected def withParquetRDD[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
- (f: SchemaRDD => Unit): Unit = {
+ (f: DataFrame => Unit): Unit = {
withParquetFile(data)(path => f(parquetFile(path)))
}
@@ -120,7 +120,7 @@ trait ParquetTest {
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
withParquetRDD(data) { rdd =>
- rdd.registerTempTable(tableName)
+ sqlContext.registerRDDAsTable(rdd, tableName)
withTempTable(tableName)(f)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index 37853d4..d13f2ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -18,19 +18,18 @@
package org.apache.spark.sql.sources
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Row
-import org.apache.spark.sql._
+import org.apache.spark.sql.{Row, Strategy}
import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution
/**
* A Strategy for planning scans over data sources defined using the sources API.
*/
private[sql] object DataSourceStrategy extends Strategy {
- def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) =>
pruneFilterProjectRaw(
l,
@@ -112,23 +111,26 @@ private[sql] object DataSourceStrategy extends Strategy {
}
}
+ /** Turn Catalyst [[Expression]]s into data source [[Filter]]s. */
protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect {
- case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v)
- case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v)
+ case expressions.EqualTo(a: Attribute, expressions.Literal(v, _)) => EqualTo(a.name, v)
+ case expressions.EqualTo(expressions.Literal(v, _), a: Attribute) => EqualTo(a.name, v)
- case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v)
- case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v)
+ case expressions.GreaterThan(a: Attribute, expressions.Literal(v, _)) => GreaterThan(a.name, v)
+ case expressions.GreaterThan(expressions.Literal(v, _), a: Attribute) => LessThan(a.name, v)
- case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v)
- case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v)
+ case expressions.LessThan(a: Attribute, expressions.Literal(v, _)) => LessThan(a.name, v)
+ case expressions.LessThan(expressions.Literal(v, _), a: Attribute) => GreaterThan(a.name, v)
- case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
+ case expressions.GreaterThanOrEqual(a: Attribute, expressions.Literal(v, _)) =>
GreaterThanOrEqual(a.name, v)
- case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
+ case expressions.GreaterThanOrEqual(expressions.Literal(v, _), a: Attribute) =>
LessThanOrEqual(a.name, v)
- case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v)
- case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v)
+ case expressions.LessThanOrEqual(a: Attribute, expressions.Literal(v, _)) =>
+ LessThanOrEqual(a.name, v)
+ case expressions.LessThanOrEqual(expressions.Literal(v, _), a: Attribute) =>
+ GreaterThanOrEqual(a.name, v)
case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 171b816..b4af91a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.sources
import scala.language.implicitConversions
import org.apache.spark.Logging
-import org.apache.spark.sql.{SchemaRDD, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.execution.RunnableCommand
@@ -225,7 +225,8 @@ private [sql] case class CreateTempTableUsing(
def run(sqlContext: SQLContext) = {
val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options)
- new SchemaRDD(sqlContext, LogicalRelation(resolved.relation)).registerTempTable(tableName)
+ sqlContext.registerRDDAsTable(
+ new DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
Seq.empty
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
index f9c0822..2564c84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.test
import scala.language.implicitConversions
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext}
+import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
/** A SQLContext that can be used for local testing. */
@@ -40,8 +40,8 @@ object TestSQLContext
* Turn a logical plan into a SchemaRDD. This should be removed once we have an easier way to
* construct SchemaRDD directly out of local data without relying on implicits.
*/
- protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): SchemaRDD = {
- new SchemaRDD(this, plan)
+ protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
+ new DataFrame(this, plan)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
index 9ff4047..e558893 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
@@ -61,7 +61,7 @@ public class JavaAPISuite implements Serializable {
}
}, DataTypes.IntegerType);
- Row result = sqlContext.sql("SELECT stringLengthTest('test')").first();
+ Row result = sqlContext.sql("SELECT stringLengthTest('test')").head();
assert(result.getInt(0) == 4);
}
@@ -81,7 +81,7 @@ public class JavaAPISuite implements Serializable {
}
}, DataTypes.IntegerType);
- Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").first();
+ Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head();
assert(result.getInt(0) == 9);
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
index 9e96738..badd00d 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
@@ -98,8 +98,8 @@ public class JavaApplySchemaSuite implements Serializable {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- SchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD.rdd(), schema);
- schemaRDD.registerTempTable("people");
+ DataFrame df = javaSqlCtx.applySchema(rowRDD.rdd(), schema);
+ df.registerTempTable("people");
Row[] actual = javaSqlCtx.sql("SELECT * FROM people").collect();
List<Row> expected = new ArrayList<Row>(2);
@@ -147,17 +147,17 @@ public class JavaApplySchemaSuite implements Serializable {
null,
"this is another simple string."));
- SchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD.rdd());
- StructType actualSchema1 = schemaRDD1.schema();
+ DataFrame df1 = javaSqlCtx.jsonRDD(jsonRDD.rdd());
+ StructType actualSchema1 = df1.schema();
Assert.assertEquals(expectedSchema, actualSchema1);
- schemaRDD1.registerTempTable("jsonTable1");
+ df1.registerTempTable("jsonTable1");
List<Row> actual1 = javaSqlCtx.sql("select * from jsonTable1").collectAsList();
Assert.assertEquals(expectedResult, actual1);
- SchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema);
- StructType actualSchema2 = schemaRDD2.schema();
+ DataFrame df2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema);
+ StructType actualSchema2 = df2.schema();
Assert.assertEquals(expectedSchema, actualSchema2);
- schemaRDD2.registerTempTable("jsonTable2");
+ df2.registerTempTable("jsonTable2");
List<Row> actual2 = javaSqlCtx.sql("select * from jsonTable2").collectAsList();
Assert.assertEquals(expectedResult, actual2);
}
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index cfc037c..3476315 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar._
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.storage.{StorageLevel, RDDBlockId}
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index afbfe21..a5848f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -17,12 +17,10 @@
package org.apache.spark.sql
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.types._
/* Implicits */
-import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.test.TestSQLContext._
import scala.language.postfixOps
@@ -44,46 +42,46 @@ class DslQuerySuite extends QueryTest {
test("agg") {
checkAnswer(
- testData2.groupBy('a)('a, sum('b)),
+ testData2.groupBy("a").agg($"a", sum($"b")),
Seq(Row(1,3), Row(2,3), Row(3,3))
)
checkAnswer(
- testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)),
+ testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)),
Row(9)
)
checkAnswer(
- testData2.aggregate(sum('b)),
+ testData2.agg(sum('b)),
Row(9)
)
}
test("convert $\"attribute name\" into unresolved attribute") {
checkAnswer(
- testData.where($"key" === 1).select($"value"),
+ testData.where($"key" === Literal(1)).select($"value"),
Row("1"))
}
test("convert Scala Symbol 'attrname into unresolved attribute") {
checkAnswer(
- testData.where('key === 1).select('value),
+ testData.where('key === Literal(1)).select('value),
Row("1"))
}
test("select *") {
checkAnswer(
- testData.select(Star(None)),
+ testData.select($"*"),
testData.collect().toSeq)
}
test("simple select") {
checkAnswer(
- testData.where('key === 1).select('value),
+ testData.where('key === Literal(1)).select('value),
Row("1"))
}
test("select with functions") {
checkAnswer(
- testData.select(sum('value), avg('value), count(1)),
+ testData.select(sum('value), avg('value), count(Literal(1))),
Row(5050.0, 50.5, 100))
checkAnswer(
@@ -120,46 +118,19 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
arrayData.orderBy('data.getItem(0).asc),
- arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
checkAnswer(
arrayData.orderBy('data.getItem(0).desc),
- arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
checkAnswer(
arrayData.orderBy('data.getItem(1).asc),
- arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
checkAnswer(
arrayData.orderBy('data.getItem(1).desc),
- arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
- }
-
- test("partition wide sorting") {
- // 2 partitions totally, and
- // Partition #1 with values:
- // (1, 1)
- // (1, 2)
- // (2, 1)
- // Partition #2 with values:
- // (2, 2)
- // (3, 1)
- // (3, 2)
- checkAnswer(
- testData2.sortBy('a.asc, 'b.asc),
- Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
-
- checkAnswer(
- testData2.sortBy('a.asc, 'b.desc),
- Seq(Row(1,2), Row(1,1), Row(2,1), Row(2,2), Row(3,2), Row(3,1)))
-
- checkAnswer(
- testData2.sortBy('a.desc, 'b.desc),
- Seq(Row(2,1), Row(1,2), Row(1,1), Row(3,2), Row(3,1), Row(2,2)))
-
- checkAnswer(
- testData2.sortBy('a.desc, 'b.asc),
- Seq(Row(2,1), Row(1,1), Row(1,2), Row(3,1), Row(3,2), Row(2,2)))
+ arrayData.toDF.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
}
test("limit") {
@@ -176,71 +147,51 @@ class DslQuerySuite extends QueryTest {
mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
}
- test("SPARK-3395 limit distinct") {
- val filtered = TestData.testData2
- .distinct()
- .orderBy(SortOrder('a, Ascending), SortOrder('b, Ascending))
- .limit(1)
- .registerTempTable("onerow")
- checkAnswer(
- sql("select * from onerow inner join testData2 on onerow.a = testData2.a"),
- Row(1, 1, 1, 1) ::
- Row(1, 1, 1, 2) :: Nil)
- }
-
- test("SPARK-3858 generator qualifiers are discarded") {
- checkAnswer(
- arrayData.as('ad)
- .generate(Explode("data" :: Nil, 'data), alias = Some("ex"))
- .select("ex.data".attr),
- Seq(1, 2, 3, 2, 3, 4).map(Row(_)))
- }
-
test("average") {
checkAnswer(
- testData2.aggregate(avg('a)),
+ testData2.agg(avg('a)),
Row(2.0))
checkAnswer(
- testData2.aggregate(avg('a), sumDistinct('a)), // non-partial
+ testData2.agg(avg('a), sumDistinct('a)), // non-partial
Row(2.0, 6.0) :: Nil)
checkAnswer(
- decimalData.aggregate(avg('a)),
+ decimalData.agg(avg('a)),
Row(new java.math.BigDecimal(2.0)))
checkAnswer(
- decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial
+ decimalData.agg(avg('a), sumDistinct('a)), // non-partial
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
checkAnswer(
- decimalData.aggregate(avg('a cast DecimalType(10, 2))),
+ decimalData.agg(avg('a cast DecimalType(10, 2))),
Row(new java.math.BigDecimal(2.0)))
checkAnswer(
- decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
+ decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
}
test("null average") {
checkAnswer(
- testData3.aggregate(avg('b)),
+ testData3.agg(avg('b)),
Row(2.0))
checkAnswer(
- testData3.aggregate(avg('b), countDistinct('b)),
+ testData3.agg(avg('b), countDistinct('b)),
Row(2.0, 1))
checkAnswer(
- testData3.aggregate(avg('b), sumDistinct('b)), // non-partial
+ testData3.agg(avg('b), sumDistinct('b)), // non-partial
Row(2.0, 2.0))
}
test("zero average") {
checkAnswer(
- emptyTableData.aggregate(avg('a)),
+ emptyTableData.agg(avg('a)),
Row(null))
checkAnswer(
- emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial
+ emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial
Row(null, null))
}
@@ -248,28 +199,28 @@ class DslQuerySuite extends QueryTest {
assert(testData2.count() === testData2.map(_ => 1).count())
checkAnswer(
- testData2.aggregate(count('a), sumDistinct('a)), // non-partial
+ testData2.agg(count('a), sumDistinct('a)), // non-partial
Row(6, 6.0))
}
test("null count") {
checkAnswer(
- testData3.groupBy('a)('a, count('b)),
+ testData3.groupBy('a).agg('a, count('b)),
Seq(Row(1,0), Row(2, 1))
)
checkAnswer(
- testData3.groupBy('a)('a, count('a + 'b)),
+ testData3.groupBy('a).agg('a, count('a + 'b)),
Seq(Row(1,0), Row(2, 1))
)
checkAnswer(
- testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
+ testData3.agg(count('a), count('b), count(Literal(1)), countDistinct('a), countDistinct('b)),
Row(2, 1, 2, 2, 1)
)
checkAnswer(
- testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial
+ testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial
Row(1, 1, 2)
)
}
@@ -278,19 +229,19 @@ class DslQuerySuite extends QueryTest {
assert(emptyTableData.count() === 0)
checkAnswer(
- emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial
+ emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
Row(0, null))
}
test("zero sum") {
checkAnswer(
- emptyTableData.aggregate(sum('a)),
+ emptyTableData.agg(sum('a)),
Row(null))
}
test("zero sum distinct") {
checkAnswer(
- emptyTableData.aggregate(sumDistinct('a)),
+ emptyTableData.agg(sumDistinct('a)),
Row(null))
}
@@ -320,7 +271,7 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
// SELECT *, foo(key, value) FROM testData
- testData.select(Star(None), foo.call('key, 'value)).limit(3),
+ testData.select($"*", callUDF(foo, 'key, 'value)).limit(3),
Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil
)
}
@@ -362,7 +313,7 @@ class DslQuerySuite extends QueryTest {
test("upper") {
checkAnswer(
lowerCaseData.select(upper('l)),
- ('a' to 'd').map(c => Row(c.toString.toUpperCase()))
+ ('a' to 'd').map(c => Row(c.toString.toUpperCase))
)
checkAnswer(
@@ -379,7 +330,7 @@ class DslQuerySuite extends QueryTest {
test("lower") {
checkAnswer(
upperCaseData.select(lower('L)),
- ('A' to 'F').map(c => Row(c.toString.toLowerCase()))
+ ('A' to 'F').map(c => Row(c.toString.toLowerCase))
)
checkAnswer(
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index cd36da7..7971372 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -20,19 +20,20 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.test.TestSQLContext._
+
class JoinSuite extends QueryTest with BeforeAndAfterEach {
// Ensures tables are loaded.
TestData
test("equi-join is hash-join") {
- val x = testData2.as('x)
- val y = testData2.as('y)
- val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed
+ val x = testData2.as("x")
+ val y = testData2.as("y")
+ val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.analyzed
val planned = planner.HashJoin(join)
assert(planned.size === 1)
}
@@ -105,17 +106,16 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("multiple-key equi-join is hash-join") {
- val x = testData2.as('x)
- val y = testData2.as('y)
- val join = x.join(y, Inner,
- Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed
+ val x = testData2.as("x")
+ val y = testData2.as("y")
+ val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.analyzed
val planned = planner.HashJoin(join)
assert(planned.size === 1)
}
test("inner join where, one match per row") {
checkAnswer(
- upperCaseData.join(lowerCaseData, Inner).where('n === 'N),
+ upperCaseData.join(lowerCaseData).where('n === 'N),
Seq(
Row(1, "A", 1, "a"),
Row(2, "B", 2, "b"),
@@ -126,7 +126,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("inner join ON, one match per row") {
checkAnswer(
- upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N"),
Seq(
Row(1, "A", 1, "a"),
Row(2, "B", 2, "b"),
@@ -136,10 +136,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("inner join, where, multiple matches") {
- val x = testData2.where('a === 1).as('x)
- val y = testData2.where('a === 1).as('y)
+ val x = testData2.where($"a" === Literal(1)).as("x")
+ val y = testData2.where($"a" === Literal(1)).as("y")
checkAnswer(
- x.join(y).where("x.a".attr === "y.a".attr),
+ x.join(y).where($"x.a" === $"y.a"),
Row(1,1,1,1) ::
Row(1,1,1,2) ::
Row(1,2,1,1) ::
@@ -148,22 +148,21 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("inner join, no matches") {
- val x = testData2.where('a === 1).as('x)
- val y = testData2.where('a === 2).as('y)
+ val x = testData2.where($"a" === Literal(1)).as("x")
+ val y = testData2.where($"a" === Literal(2)).as("y")
checkAnswer(
- x.join(y).where("x.a".attr === "y.a".attr),
+ x.join(y).where($"x.a" === $"y.a"),
Nil)
}
test("big inner join, 4 matches per row") {
val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
- val bigDataX = bigData.as('x)
- val bigDataY = bigData.as('y)
+ val bigDataX = bigData.as("x")
+ val bigDataY = bigData.as("y")
checkAnswer(
- bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr),
- testData.flatMap(
- row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
+ bigDataX.join(bigDataY).where($"x.key" === $"y.key"),
+ testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
test("cartisian product join") {
@@ -177,7 +176,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("left outer join") {
checkAnswer(
- upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N", "left"),
Row(1, "A", 1, "a") ::
Row(2, "B", 2, "b") ::
Row(3, "C", 3, "c") ::
@@ -186,7 +185,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(6, "F", null, null) :: Nil)
checkAnswer(
- upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > Literal(1), "left"),
Row(1, "A", null, null) ::
Row(2, "B", 2, "b") ::
Row(3, "C", 3, "c") ::
@@ -195,7 +194,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(6, "F", null, null) :: Nil)
checkAnswer(
- upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > Literal(1), "left"),
Row(1, "A", null, null) ::
Row(2, "B", 2, "b") ::
Row(3, "C", 3, "c") ::
@@ -204,7 +203,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(6, "F", null, null) :: Nil)
checkAnswer(
- upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)),
+ upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left"),
Row(1, "A", 1, "a") ::
Row(2, "B", 2, "b") ::
Row(3, "C", 3, "c") ::
@@ -240,7 +239,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
test("right outer join") {
checkAnswer(
- lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)),
+ lowerCaseData.join(upperCaseData, $"n" === $"N", "right"),
Row(1, "a", 1, "A") ::
Row(2, "b", 2, "B") ::
Row(3, "c", 3, "C") ::
@@ -248,7 +247,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 5, "E") ::
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)),
+ lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > Literal(1), "right"),
Row(null, null, 1, "A") ::
Row(2, "b", 2, "B") ::
Row(3, "c", 3, "C") ::
@@ -256,7 +255,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 5, "E") ::
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)),
+ lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > Literal(1), "right"),
Row(null, null, 1, "A") ::
Row(2, "b", 2, "B") ::
Row(3, "c", 3, "C") ::
@@ -264,7 +263,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 5, "E") ::
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)),
+ lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right"),
Row(1, "a", 1, "A") ::
Row(2, "b", 2, "B") ::
Row(3, "c", 3, "C") ::
@@ -299,14 +298,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("full outer join") {
- upperCaseData.where('N <= 4).registerTempTable("left")
- upperCaseData.where('N >= 3).registerTempTable("right")
+ upperCaseData.where('N <= Literal(4)).registerTempTable("left")
+ upperCaseData.where('N >= Literal(3)).registerTempTable("right")
val left = UnresolvedRelation(Seq("left"), None)
val right = UnresolvedRelation(Seq("right"), None)
checkAnswer(
- left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
+ left.join(right, $"left.N" === $"right.N", "full"),
Row(1, "A", null, null) ::
Row(2, "B", null, null) ::
Row(3, "C", 3, "C") ::
@@ -315,7 +314,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))),
+ left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== Literal(3)), "full"),
Row(1, "A", null, null) ::
Row(2, "B", null, null) ::
Row(3, "C", null, null) ::
@@ -325,7 +324,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
Row(null, null, 6, "F") :: Nil)
checkAnswer(
- left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))),
+ left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== Literal(3)), "full"),
Row(1, "A", null, null) ::
Row(2, "B", null, null) ::
Row(3, "C", null, null) ::
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 42a21c1..07c52de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -26,12 +26,12 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer contains all of the keywords, or that
* none of the keywords are listed in the answer
- * @param rdd the [[SchemaRDD]] to be executed
+ * @param rdd the [[DataFrame]] to be executed
* @param exists true to make sure the keywords are listed in the output; otherwise,
* makes sure none of the keywords are listed in the output
* @param keywords the keywords to look for, as a string array
*/
- def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) {
+ def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) {
val outputs = rdd.collect().map(_.mkString).mkString
for (key <- keywords) {
if (exists) {
@@ -44,10 +44,10 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer matches the expected result.
- * @param rdd the [[SchemaRDD]] to be executed
+ * @param rdd the [[DataFrame]] to be executed
* @param expectedAnswer the expected result, which can be an Any, a Seq[Product], or a Seq[Seq[Any]].
*/
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = {
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = {
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
@@ -91,7 +91,7 @@ class QueryTest extends PlanTest {
}
}
- protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = {
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
checkAnswer(rdd, Seq(expectedAnswer))
}
@@ -102,7 +102,7 @@ class QueryTest extends PlanTest {
}
/** Asserts that a given DataFrame will be executed using the given number of cached results. */
- def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = {
+ def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
val planWithCaching = query.queryExecution.withCachedData
val cachedData = planWithCaching collect {
case cached: InMemoryRelation => cached
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 03b44ca..4fff99c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,6 +21,7 @@ import java.util.TimeZone
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._
@@ -29,6 +30,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext._
+
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
// Make sure the tables are loaded.
TestData
@@ -381,8 +383,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("big inner join, 4 matches per row") {
-
-
checkAnswer(
sql(
"""
@@ -396,7 +396,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
| SELECT * FROM testData UNION ALL
| SELECT * FROM testData) y
|WHERE x.key = y.key""".stripMargin),
- testData.flatMap(
+ testData.rdd.flatMap(
row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
@@ -742,7 +742,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
}
test("metadata is propagated correctly") {
- val person = sql("SELECT * FROM person")
+ val person: DataFrame = sql("SELECT * FROM person")
val schema = person.schema
val docKey = "doc"
val docValue = "first name"
@@ -751,14 +751,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.build()
val schemaWithMeta = new StructType(Array(
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
- val personWithMeta = applySchema(person, schemaWithMeta)
- def validateMetadata(rdd: SchemaRDD): Unit = {
+ val personWithMeta = applySchema(person.rdd, schemaWithMeta)
+ def validateMetadata(rdd: DataFrame): Unit = {
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
}
personWithMeta.registerTempTable("personWithMeta")
- validateMetadata(personWithMeta.select('name))
- validateMetadata(personWithMeta.select("name".attr))
- validateMetadata(personWithMeta.select('id, 'name))
+ validateMetadata(personWithMeta.select($"name"))
+ validateMetadata(personWithMeta.select($"name"))
+ validateMetadata(personWithMeta.select($"id", $"name"))
validateMetadata(sql("SELECT * FROM personWithMeta"))
validateMetadata(sql("SELECT id, name FROM personWithMeta"))
validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId"))
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 808ed52..fffa2b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.sql.Timestamp
import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.test._
/* Implicits */
@@ -29,11 +30,11 @@ case class TestData(key: Int, value: String)
object TestData {
val testData = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD
+ (1 to 100).map(i => TestData(i, i.toString))).toDF
testData.registerTempTable("testData")
val negativeData = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(-i, (-i).toString))).toSchemaRDD
+ (1 to 100).map(i => TestData(-i, (-i).toString))).toDF
negativeData.registerTempTable("negativeData")
case class LargeAndSmallInts(a: Int, b: Int)
@@ -44,7 +45,7 @@ object TestData {
LargeAndSmallInts(2147483645, 1) ::
LargeAndSmallInts(2, 2) ::
LargeAndSmallInts(2147483646, 1) ::
- LargeAndSmallInts(3, 2) :: Nil).toSchemaRDD
+ LargeAndSmallInts(3, 2) :: Nil).toDF
largeAndSmallInts.registerTempTable("largeAndSmallInts")
case class TestData2(a: Int, b: Int)
@@ -55,7 +56,7 @@ object TestData {
TestData2(2, 1) ::
TestData2(2, 2) ::
TestData2(3, 1) ::
- TestData2(3, 2) :: Nil, 2).toSchemaRDD
+ TestData2(3, 2) :: Nil, 2).toDF
testData2.registerTempTable("testData2")
case class DecimalData(a: BigDecimal, b: BigDecimal)
@@ -67,7 +68,7 @@ object TestData {
DecimalData(2, 1) ::
DecimalData(2, 2) ::
DecimalData(3, 1) ::
- DecimalData(3, 2) :: Nil).toSchemaRDD
+ DecimalData(3, 2) :: Nil).toDF
decimalData.registerTempTable("decimalData")
case class BinaryData(a: Array[Byte], b: Int)
@@ -77,17 +78,17 @@ object TestData {
BinaryData("22".getBytes(), 5) ::
BinaryData("122".getBytes(), 3) ::
BinaryData("121".getBytes(), 2) ::
- BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD
+ BinaryData("123".getBytes(), 4) :: Nil).toDF
binaryData.registerTempTable("binaryData")
case class TestData3(a: Int, b: Option[Int])
val testData3 =
TestSQLContext.sparkContext.parallelize(
TestData3(1, None) ::
- TestData3(2, Some(2)) :: Nil).toSchemaRDD
+ TestData3(2, Some(2)) :: Nil).toDF
testData3.registerTempTable("testData3")
- val emptyTableData = logical.LocalRelation('a.int, 'b.int)
+ val emptyTableData = logical.LocalRelation($"a".int, $"b".int)
case class UpperCaseData(N: Int, L: String)
val upperCaseData =
@@ -97,7 +98,7 @@ object TestData {
UpperCaseData(3, "C") ::
UpperCaseData(4, "D") ::
UpperCaseData(5, "E") ::
- UpperCaseData(6, "F") :: Nil).toSchemaRDD
+ UpperCaseData(6, "F") :: Nil).toDF
upperCaseData.registerTempTable("upperCaseData")
case class LowerCaseData(n: Int, l: String)
@@ -106,7 +107,7 @@ object TestData {
LowerCaseData(1, "a") ::
LowerCaseData(2, "b") ::
LowerCaseData(3, "c") ::
- LowerCaseData(4, "d") :: Nil).toSchemaRDD
+ LowerCaseData(4, "d") :: Nil).toDF
lowerCaseData.registerTempTable("lowerCaseData")
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
@@ -200,6 +201,6 @@ object TestData {
TestSQLContext.sparkContext.parallelize(
ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true)
:: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false)
- :: Nil).toSchemaRDD
+ :: Nil).toDF
complexData.registerTempTable("complexData")
}
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 0c98120..5abd7b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.dsl.StringToColumn
import org.apache.spark.sql.test._
/* Implicits */
@@ -28,17 +29,17 @@ class UDFSuite extends QueryTest {
test("Simple UDF") {
udf.register("strLenScala", (_: String).length)
- assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4)
+ assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
}
test("ZeroArgument UDF") {
udf.register("random0", () => { Math.random()})
- assert(sql("SELECT random0()").first().getDouble(0) >= 0.0)
+ assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)
}
test("TwoArgument UDF") {
udf.register("strLenScala", (_: String).length + (_:Int))
- assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5)
+ assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
}
test("struct UDF") {
@@ -46,7 +47,7 @@ class UDFSuite extends QueryTest {
val result=
sql("SELECT returnStruct('test', 'test2') as ret")
- .select("ret.f1".attr).first().getString(0)
- assert(result == "test")
+ .select($"ret.f1").head().getString(0)
+ assert(result === "test")
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index fbc8704..62b2e89 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql
import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types._
+
@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
override def equals(other: Any): Boolean = other match {
@@ -66,14 +68,14 @@ class UserDefinedTypeSuite extends QueryTest {
test("register user type: MyDenseVector for MyLabeledPoint") {
- val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v }
+ val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v }
val labelsArrays: Array[Double] = labels.collect()
assert(labelsArrays.size === 2)
assert(labelsArrays.contains(1.0))
assert(labelsArrays.contains(0.0))
val features: RDD[MyDenseVector] =
- pointsRDD.select('features).map { case Row(v: MyDenseVector) => v }
+ pointsRDD.select('features).rdd.map { case Row(v: MyDenseVector) => v }
val featuresArrays: Array[MyDenseVector] = features.collect()
assert(featuresArrays.size === 2)
assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0))))
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index e61f3c3..6f051df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.columnar
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext._
http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 67007b8..be5e63c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.scalatest.FunSuite
import org.apache.spark.sql.{SQLConf, execution}
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -28,6 +29,7 @@ import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.planner._
import org.apache.spark.sql.types._
+
class PlannerSuite extends FunSuite {
test("unions are collapsed") {
val query = testData.unionAll(testData).unionAll(testData).logicalPlan
@@ -40,7 +42,7 @@ class PlannerSuite extends FunSuite {
}
test("count is partially aggregated") {
- val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed
+ val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
val planned = HashAggregation(query).head
val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
@@ -48,14 +50,14 @@ class PlannerSuite extends FunSuite {
}
test("count distinct is partially aggregated") {
- val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed
+ val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
val planned = HashAggregation(query)
assert(planned.nonEmpty)
}
test("mixed aggregates are partially aggregated") {
val query =
- testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed
+ testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
val planned = HashAggregation(query)
assert(planned.nonEmpty)
}
@@ -128,9 +130,9 @@ class PlannerSuite extends FunSuite {
testData.limit(3).registerTempTable("tiny")
sql("CACHE TABLE tiny")
- val a = testData.as('a)
- val b = table("tiny").as('b)
- val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan
+ val a = testData.as("a")
+ val b = table("tiny").as("b")
+ val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan
val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org