Posted to commits@spark.apache.org by ma...@apache.org on 2015/10/21 22:22:40 UTC

spark git commit: [SPARK-10743][SQL] keep the name of the expression if possible when casting

Repository: spark
Updated Branches:
  refs/heads/master 49ea0e9d7 -> 7c74ebca0


[SPARK-10743][SQL] keep the name of the expression if possible when casting

Author: Wenchen Fan <cl...@163.com>

Closes #8859 from cloud-fan/cast.
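
For readers skimming the diff below: the user-visible effect is that casting a named column no longer discards its name. A minimal sketch of the new behavior at the DataFrame level, mirroring the test added to DataFrameSuite in this patch (a SQLContext with its implicits in scope is assumed; Spark 1.5-era API):

    import org.apache.spark.sql.types.StringType
    import sqlContext.implicits._

    val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src")

    // Before this patch the projected column fell through to the analyzer's
    // positional fallback (Alias(other, s"_c$i")); now the attribute's name is kept.
    df.select($"src.i".cast(StringType)).columns.head              // "i"

    // An explicit alias still wins over the child's name, so the
    // col.as("name").cast(...) pattern from the removed ColumnExpressionSuite
    // test keeps working through the same analyzer path.
    df.select($"src.i".as("j").cast(StringType)).columns.head      // "j"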


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7c74ebca
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7c74ebca
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7c74ebca

Branch: refs/heads/master
Commit: 7c74ebca05f40a2d8fe8f10f24a10486ce4f76c0
Parents: 49ea0e9
Author: Wenchen Fan <cl...@163.com>
Authored: Wed Oct 21 13:22:35 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Wed Oct 21 13:22:35 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 33 ++++++++++----------
 .../scala/org/apache/spark/sql/Column.scala     |  4 +--
 .../spark/sql/ColumnExpressionSuite.scala       |  6 ----
 .../org/apache/spark/sql/DataFrameSuite.scala   |  5 +++
 4 files changed, 23 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7c74ebca/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9237f2f..016dc29 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -141,32 +141,31 @@ class Analyzer(
    */
   object ResolveAliases extends Rule[LogicalPlan] {
     private def assignAliases(exprs: Seq[NamedExpression]) = {
-      // The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need
-      // to traverse the whole tree.
       exprs.zipWithIndex.map {
-        case (u @ UnresolvedAlias(child), i) =>
-          child match {
-            case _: UnresolvedAttribute => u
-            case ne: NamedExpression => ne
-            case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
-            case e if !e.resolved => u
-            case other => Alias(other, s"_c$i")()
+        case (expr, i) =>
+          expr transform {
+            case u @ UnresolvedAlias(child) => child match {
+              case ne: NamedExpression => ne
+              case e if !e.resolved => u
+              case g: Generator if g.elementTypes.size > 1 => MultiAlias(g, Nil)
+              case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)()
+              case other => Alias(other, s"_c$i")()
+            }
           }
-        case (other, _) => other
-      }
+      }.asInstanceOf[Seq[NamedExpression]]
     }
 
+    private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) =
+      exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined)
+
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case Aggregate(groups, aggs, child)
-        if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) =>
+      case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
         Aggregate(groups, assignAliases(aggs), child)
 
-      case g: GroupingAnalytics
-        if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) =>
+      case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
         g.withNewAggs(assignAliases(g.aggregations))
 
-      case Project(projectList, child)
-        if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) =>
+      case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
         Project(assignAliases(projectList), child)
     }
   }
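
The structural change in this rule: UnresolvedAlias is no longer assumed to sit only at the root of each expression, so assignAliases walks every tree with transform, and a Cast over a named child is now aliased with that child's name. A hypothetical, heavily simplified model of the naming decision (the case classes below are illustrative stand-ins, not the Catalyst types):

    sealed trait Expr
    case class Attr(name: String) extends Expr                   // stands in for a resolved NamedExpression
    case class CastTo(child: Expr, to: String) extends Expr      // stands in for Cast
    case class Aliased(child: Expr, name: String) extends Expr   // stands in for Alias

    def assignName(e: Expr, i: Int): Expr = e match {
      case a: Attr                => a                           // already named, leave as-is
      case c @ CastTo(a: Attr, _) => Aliased(c, a.name)          // new case: keep the child's name
      case other                  => Aliased(other, s"_c$i")     // positional fallback, as before
    }

    // assignName(CastTo(Attr("i"), "string"), 0) == Aliased(CastTo(Attr("i"), "string"), "i")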

http://git-wip-us.apache.org/repos/asf/spark/blob/7c74ebca/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 1f82688..37d559c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -835,8 +835,8 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    * @since 1.3.0
    */
   def cast(to: DataType): Column = expr match {
-    // Lift alias out of cast so we can support col.as("name").cast(IntegerType)
-    case Alias(childExpr, name) => Alias(Cast(childExpr, to), name)()
+    // Keep the name of the expression if possible when casting.
+    case ne: NamedExpression => UnresolvedAlias(Cast(expr, to))
     case _ => Cast(expr, to)
   }
 

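For context on the new match in Column.cast above: only expressions that already carry a name (NamedExpression, e.g. an attribute or an explicit alias) get wrapped in UnresolvedAlias so the analyzer can reuse that name; anything else, such as a literal, still becomes a bare Cast and later receives the analyzer's fallback name. A small illustrative sketch (same assumed SQLContext/implicits as above):

    import org.apache.spark.sql.functions.lit
    import org.apache.spark.sql.types.StringType
    import sqlContext.implicits._

    val df = (1 to 10).map(Tuple1.apply).toDF("i")
    df.select($"i".cast(StringType)).columns.head     // "i"   (attribute: name is kept)
    df.select(lit(1).cast(StringType)).columns.head   // "_c0" (literal: analyzer fallback name)
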
http://git-wip-us.apache.org/repos/asf/spark/blob/7c74ebca/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 4e988f0..fa559c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -588,12 +588,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("lift alias out of cast") {
-    compareExpressions(
-      col("1234").as("name").cast("int").expr,
-      col("1234").cast("int").as("name").expr)
-  }
-
   test("columns can be compared") {
     assert('key.desc == 'key.desc)
     assert('key.desc != 'key.asc)

http://git-wip-us.apache.org/repos/asf/spark/blob/7c74ebca/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 832ea02..6424f1f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -975,4 +975,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       expected(except)
     )
   }
+
+  test("SPARK-10743: keep the name of expression if possible when do cast") {
+    val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src")
+    assert(df.select($"src.i".cast(StringType)).columns.head === "i")
+  }
 }

