You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yh...@apache.org on 2015/12/18 18:54:38 UTC
spark git commit: [SPARK-11619][SQL] cannot use UDTF in
DataFrame.selectExpr
Repository: spark
Updated Branches:
refs/heads/master 278281828 -> ee444fe4b
[SPARK-11619][SQL] cannot use UDTF in DataFrame.selectExpr
Description of the problem from cloud-fan
Actually this line: https://github.com/apache/spark/blob/branch-1.5/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala#L689
When we use `selectExpr`, we pass in `UnresolvedFunction` to `DataFrame.select` and fall in the last case. A workaround is to do special handling for UDTF like we did for `explode`(and `json_tuple` in 1.6), wrap it with `MultiAlias`.
Another workaround is using `expr`, for example, `df.select(expr("explode(a)").as(Nil))`, I think `selectExpr` is no longer needed after we have the `expr` function....
Author: Dilip Biswal <db...@us.ibm.com>
Closes #9981 from dilipbiswal/spark-11619.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ee444fe4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ee444fe4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ee444fe4
Branch: refs/heads/master
Commit: ee444fe4b8c9f382524e1fa346c67ba6da8104d8
Parents: 2782818
Author: Dilip Biswal <db...@us.ibm.com>
Authored: Fri Dec 18 09:54:30 2015 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Fri Dec 18 09:54:30 2015 -0800
----------------------------------------------------------------------
.../apache/spark/sql/catalyst/analysis/Analyzer.scala | 12 ++++++------
.../apache/spark/sql/catalyst/analysis/unresolved.scala | 6 +++++-
.../src/main/scala/org/apache/spark/sql/Column.scala | 12 +++++++-----
.../src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +-
.../scala/org/apache/spark/sql/DataFrameSuite.scala | 7 +++++++
.../scala/org/apache/spark/sql/JsonFunctionsSuite.scala | 4 ++++
.../main/scala/org/apache/spark/sql/hive/HiveQl.scala | 2 +-
7 files changed, 31 insertions(+), 14 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/ee444fe4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 64dd83a..c396546 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -149,12 +149,12 @@ class Analyzer(
exprs.zipWithIndex.map {
case (expr, i) =>
expr transform {
- case u @ UnresolvedAlias(child) => child match {
+ case u @ UnresolvedAlias(child, optionalAliasName) => child match {
case ne: NamedExpression => ne
case e if !e.resolved => u
case g: Generator => MultiAlias(g, Nil)
case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)()
- case other => Alias(other, s"_c$i")()
+ case other => Alias(other, optionalAliasName.getOrElse(s"_c$i"))()
}
}
}.asInstanceOf[Seq[NamedExpression]]
@@ -287,7 +287,7 @@ class Analyzer(
}
}
val newGroupByExprs = groupByExprs.map {
- case UnresolvedAlias(e) => e
+ case UnresolvedAlias(e, _) => e
case e => e
}
Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
@@ -352,19 +352,19 @@ class Analyzer(
Project(
projectList.flatMap {
case s: Star => s.expand(child, resolver)
- case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) =>
+ case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) =>
val newChildren = expandStarExpressions(args, child)
UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil
case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) =>
val newChildren = expandStarExpressions(args, child)
Alias(child = f.copy(children = newChildren), name)() :: Nil
- case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) =>
+ case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
}
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
- case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) =>
+ case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
http://git-wip-us.apache.org/repos/asf/spark/blob/ee444fe4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 4f89b46..64cad6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -284,8 +284,12 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
/**
* Holds the expression that has yet to be aliased.
+ *
+ * @param child The computation that needs to be resolved during analysis.
+ * @param aliasName The name, if specified, to be associated with the result of computing [[child]]
+ *
*/
-case class UnresolvedAlias(child: Expression)
+case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
extends UnaryExpression with NamedExpression with Unevaluable {
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
http://git-wip-us.apache.org/repos/asf/spark/blob/ee444fe4/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 297ef22..5026c0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,20 +17,19 @@
package org.apache.spark.sql
-import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
-
import scala.language.implicitConversions
-import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.SqlParser._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DataTypeParser
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
-
private[sql] object Column {
def apply(colName: String): Column = new Column(colName)
@@ -130,8 +129,11 @@ class Column(protected[sql] val expr: Expression) extends Logging {
// Leave an unaliased generator with an empty list of names since the analyzer will generate
// the correct defaults after the nested expression's type has been resolved.
case explode: Explode => MultiAlias(explode, Nil)
+
case jt: JsonTuple => MultiAlias(jt, Nil)
+ case func: UnresolvedFunction => UnresolvedAlias(func, Some(func.prettyString))
+
case expr: Expression => Alias(expr, expr.prettyString)()
}
http://git-wip-us.apache.org/repos/asf/spark/blob/ee444fe4/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 79b4244..d201d65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -450,7 +450,7 @@ class Dataset[T] private[sql](
*/
@scala.annotation.varargs
def groupBy(cols: Column*): GroupedDataset[Row, T] = {
- val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias)
+ val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_))
val withKey = Project(withKeyColumns, logicalPlan)
val executed = sqlContext.executePlan(withKey)
http://git-wip-us.apache.org/repos/asf/spark/blob/ee444fe4/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 0644bda..4c3e12a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -176,6 +176,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
testData.select("key").collect().toSeq)
}
+ test("selectExpr with udtf") {
+ val df = Seq((Map("1" -> 1), 1)).toDF("a", "b")
+ checkAnswer(
+ df.selectExpr("explode(a)"),
+ Row("1", 1) :: Nil)
+ }
+
test("filterExpr") {
val res = testData.collect().filter(_.getInt(0) > 90).toSeq
checkAnswer(testData.filter("key > 90"), res)
http://git-wip-us.apache.org/repos/asf/spark/blob/ee444fe4/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 1f384ed..1391c9d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -73,6 +73,10 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
checkAnswer(
df.select($"key", functions.json_tuple($"jstring", "f1", "f2", "f3", "f4", "f5")),
expected)
+
+ checkAnswer(
+ df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2', 'f3', 'f4', 'f5')"),
+ expected)
}
test("json_tuple filter and group") {
http://git-wip-us.apache.org/repos/asf/spark/blob/ee444fe4/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index da41b65..0e89928 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1107,7 +1107,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
// (if there is a group by) or a script transformation.
val withProject: LogicalPlan = transformation.getOrElse {
val selectExpressions =
- select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias)
+ select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_))
Seq(
groupByClause.map(e => e match {
case Token("TOK_GROUPBY", children) =>
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org