Posted to commits@spark.apache.org by rx...@apache.org on 2016/04/22 20:10:38 UTC
[2/2] spark git commit: [SPARK-14841][SQL] Move SQLBuilder into sql/core
[SPARK-14841][SQL] Move SQLBuilder into sql/core
## What changes were proposed in this pull request?
This patch moves SQLBuilder into sql/core so that, in the future, view generation can also be moved into sql/core.
## How was this patch tested?
The existing unit tests were moved along with the code.
Author: Reynold Xin <rx...@databricks.com>
Author: Wenchen Fan <we...@databricks.com>
Closes #12602 from rxin/SPARK-14841.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/aeb52bea
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/aeb52bea
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/aeb52bea
Branch: refs/heads/master
Commit: aeb52bea56d0409f7d039ace366b3f7ef9d24dcb
Parents: 8098f15
Author: Reynold Xin <rx...@databricks.com>
Authored: Fri Apr 22 11:10:31 2016 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Fri Apr 22 11:10:31 2016 -0700
----------------------------------------------------------------------
.../apache/spark/sql/catalyst/SQLBuilder.scala | 533 +++++++++++++
.../scala/org/apache/spark/sql/QueryTest.scala | 1 +
.../org/apache/spark/sql/hive/SQLBuilder.scala | 537 -------------
.../sql/hive/execution/CreateViewAsSelect.scala | 5 +-
.../catalyst/ExpressionSQLBuilderSuite.scala | 82 ++
.../sql/catalyst/ExpressionToSQLSuite.scala | 280 +++++++
.../sql/catalyst/LogicalPlanToSQLSuite.scala | 744 +++++++++++++++++++
.../spark/sql/catalyst/SQLBuilderTest.scala | 73 ++
.../sql/hive/ExpressionSQLBuilderSuite.scala | 81 --
.../spark/sql/hive/ExpressionToSQLSuite.scala | 280 -------
.../spark/sql/hive/LogicalPlanToSQLSuite.scala | 744 -------------------
.../apache/spark/sql/hive/SQLBuilderTest.scala | 72 --
.../sql/hive/execution/HiveComparisonTest.scala | 4 +-
13 files changed, 1718 insertions(+), 1718 deletions(-)
----------------------------------------------------------------------
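For context, the class being moved converts an analyzed logical plan (or Dataset) back into a SQL query string. A minimal usage sketch, assuming a SQLContext with a table `t1` registered; the constructor signatures are taken from the diff below:

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.catalyst.SQLBuilder

    // Regenerate SQL text from a DataFrame's analyzed plan. After this patch,
    // SQLBuilder lives in sql/core and no longer needs a SQLContext argument.
    def regenerateSql(sqlContext: SQLContext): String = {
      val df = sqlContext.sql("SELECT key, max(value) FROM t1 GROUP BY key")
      new SQLBuilder(df).toSQL
    }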
http://git-wip-us.apache.org/repos/asf/spark/blob/aeb52bea/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
new file mode 100644
index 0000000..d65b3cb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
@@ -0,0 +1,533 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.catalyst.catalog.CatalogRelation
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer.{CollapseProject, CombineUnions}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.catalyst.util.quoteIdentifier
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, NullType}
+
+/**
+ * A builder class used to convert a resolved logical plan into a SQL query string. Note that not
+ * all resolved logical plans are convertible: they either don't have corresponding SQL
+ * representations (e.g. logical plans that operate on local Scala collections) or are simply not
+ * supported by this builder (yet).
+ */
+class SQLBuilder(logicalPlan: LogicalPlan) extends Logging {
+ require(logicalPlan.resolved, "SQLBuilder only supports resolved logical query plans")
+
+ def this(df: Dataset[_]) = this(df.queryExecution.analyzed)
+
+ private val nextSubqueryId = new AtomicLong(0)
+ private def newSubqueryName(): String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}"
+
+ def toSQL: String = {
+ val canonicalizedPlan = Canonicalizer.execute(logicalPlan)
+ val outputNames = logicalPlan.output.map(_.name)
+ val qualifiers = logicalPlan.output.flatMap(_.qualifier).distinct
+
+ // Keep the qualifier information by using it as sub-query name, if there is only one qualifier
+ // present.
+ val finalName = if (qualifiers.length == 1) {
+ qualifiers.head
+ } else {
+ newSubqueryName()
+ }
+
+ // The Canonicalizer removes all naming information, so we add it back with an extra
+ // Project that aliases the outputs.
+ val aliasedOutput = canonicalizedPlan.output.zip(outputNames).map {
+ case (attr, name) => Alias(attr.withQualifier(None), name)()
+ }
+ val finalPlan = Project(aliasedOutput, SubqueryAlias(finalName, canonicalizedPlan))
+
+ try {
+ val replaced = finalPlan.transformAllExpressions {
+ case e: SubqueryExpression =>
+ SubqueryHolder(new SQLBuilder(e.query).toSQL)
+ case e: NonSQLExpression =>
+ throw new UnsupportedOperationException(
+ s"Expression $e doesn't have a SQL representation"
+ )
+ case e => e
+ }
+
+ val generatedSQL = toSQL(replaced)
+ logDebug(
+ s"""Built SQL query string successfully from given logical plan:
+ |
+ |# Original logical plan:
+ |${logicalPlan.treeString}
+ |# Canonicalized logical plan:
+ |${replaced.treeString}
+ |# Generated SQL:
+ |$generatedSQL
+ """.stripMargin)
+ generatedSQL
+ } catch { case NonFatal(e) =>
+ logDebug(
+ s"""Failed to build SQL query string from given logical plan:
+ |
+ |# Original logical plan:
+ |${logicalPlan.treeString}
+ |# Canonicalized logical plan:
+ |${canonicalizedPlan.treeString}
+ """.stripMargin)
+ throw e
+ }
+ }
+
+ private def toSQL(node: LogicalPlan): String = node match {
+ case Distinct(p: Project) =>
+ projectToSQL(p, isDistinct = true)
+
+ case p: Project =>
+ projectToSQL(p, isDistinct = false)
+
+ case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) =>
+ groupingSetToSQL(a, e, p)
+
+ case p: Aggregate =>
+ aggregateToSQL(p)
+
+ case w: Window =>
+ windowToSQL(w)
+
+ case g: Generate =>
+ generateToSQL(g)
+
+ case Limit(limitExpr, child) =>
+ s"${toSQL(child)} LIMIT ${limitExpr.sql}"
+
+ case Filter(condition, child) =>
+ val whereOrHaving = child match {
+ case _: Aggregate => "HAVING"
+ case _ => "WHERE"
+ }
+ build(toSQL(child), whereOrHaving, condition.sql)
+
+ case p @ Distinct(u: Union) if u.children.length > 1 =>
+ val childrenSql = u.children.map(c => s"(${toSQL(c)})")
+ childrenSql.mkString(" UNION DISTINCT ")
+
+ case p: Union if p.children.length > 1 =>
+ val childrenSql = p.children.map(c => s"(${toSQL(c)})")
+ childrenSql.mkString(" UNION ALL ")
+
+ case p: Intersect =>
+ build("(" + toSQL(p.left), ") INTERSECT (", toSQL(p.right) + ")")
+
+ case p: Except =>
+ build("(" + toSQL(p.left), ") EXCEPT (", toSQL(p.right) + ")")
+
+ case p: SubqueryAlias => build("(" + toSQL(p.child) + ")", "AS", p.alias)
+
+ case p: Join =>
+ build(
+ toSQL(p.left),
+ p.joinType.sql,
+ "JOIN",
+ toSQL(p.right),
+ p.condition.map(" ON " + _.sql).getOrElse(""))
+
+ case SQLTable(database, table, _, sample) =>
+ val qualifiedName = s"${quoteIdentifier(database)}.${quoteIdentifier(table)}"
+ sample.map { case (lowerBound, upperBound) =>
+ val fraction = math.min(100, math.max(0, (upperBound - lowerBound) * 100))
+ qualifiedName + " TABLESAMPLE(" + fraction + " PERCENT)"
+ }.getOrElse(qualifiedName)
+
+ case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
+ if orders.map(_.child) == partitionExprs =>
+ build(toSQL(child), "CLUSTER BY", partitionExprs.map(_.sql).mkString(", "))
+
+ case p: Sort =>
+ build(
+ toSQL(p.child),
+ if (p.global) "ORDER BY" else "SORT BY",
+ p.order.map(_.sql).mkString(", ")
+ )
+
+ case p: RepartitionByExpression =>
+ build(
+ toSQL(p.child),
+ "DISTRIBUTE BY",
+ p.partitionExpressions.map(_.sql).mkString(", ")
+ )
+
+ case p: ScriptTransformation =>
+ scriptTransformationToSQL(p)
+
+ case OneRowRelation =>
+ ""
+
+ case _ =>
+ throw new UnsupportedOperationException(s"unsupported plan $node")
+ }
+
+ /**
+ * Concatenates the given string segments into a single string, separating them with single
+ * spaces. Each segment is trimmed and empty segments are dropped, so exactly one space
+ * appears between parts. For example, `build("a", " b ", " c")` becomes "a b c".
+ */
+ private def build(segments: String*): String =
+ segments.map(_.trim).filter(_.nonEmpty).mkString(" ")
+
+ private def projectToSQL(plan: Project, isDistinct: Boolean): String = {
+ build(
+ "SELECT",
+ if (isDistinct) "DISTINCT" else "",
+ plan.projectList.map(_.sql).mkString(", "),
+ if (plan.child == OneRowRelation) "" else "FROM",
+ toSQL(plan.child)
+ )
+ }
+
+ private def scriptTransformationToSQL(plan: ScriptTransformation): String = {
+ val inputRowFormatSQL = plan.ioschema.inputRowFormatSQL.getOrElse(
+ throw new UnsupportedOperationException(
+ s"unsupported row format ${plan.ioschema.inputRowFormat}"))
+ val outputRowFormatSQL = plan.ioschema.outputRowFormatSQL.getOrElse(
+ throw new UnsupportedOperationException(
+ s"unsupported row format ${plan.ioschema.outputRowFormat}"))
+
+ val outputSchema = plan.output.map { attr =>
+ s"${attr.sql} ${attr.dataType.simpleString}"
+ }.mkString(", ")
+
+ build(
+ "SELECT TRANSFORM",
+ "(" + plan.input.map(_.sql).mkString(", ") + ")",
+ inputRowFormatSQL,
+ s"USING \'${plan.script}\'",
+ "AS (" + outputSchema + ")",
+ outputRowFormatSQL,
+ if (plan.child == OneRowRelation) "" else "FROM",
+ toSQL(plan.child)
+ )
+ }
+
+ private def aggregateToSQL(plan: Aggregate): String = {
+ val groupingSQL = plan.groupingExpressions.map(_.sql).mkString(", ")
+ build(
+ "SELECT",
+ plan.aggregateExpressions.map(_.sql).mkString(", "),
+ if (plan.child == OneRowRelation) "" else "FROM",
+ toSQL(plan.child),
+ if (groupingSQL.isEmpty) "" else "GROUP BY",
+ groupingSQL
+ )
+ }
+
+ private def generateToSQL(g: Generate): String = {
+ val columnAliases = g.generatorOutput.map(_.sql).mkString(", ")
+
+ val childSQL = if (g.child == OneRowRelation) {
+ // This only happens when we put a UDTF in the project list and there is no FROM clause.
+ // Because we always generate LATERAL VIEW for `Generate`, we use a trick here: put a dummy
+ // sub-query after the FROM clause, so that we can generate a valid LATERAL VIEW SQL string.
+ // For example, if the original SQL is "SELECT EXPLODE(ARRAY(1, 2))", we convert it to
+ // LATERAL VIEW form and generate:
+ // SELECT col FROM (SELECT 1) sub_q0 LATERAL VIEW EXPLODE(ARRAY(1, 2)) sub_q1 AS col
+ s"(SELECT 1) ${newSubqueryName()}"
+ } else {
+ toSQL(g.child)
+ }
+
+ // The final SQL string for Generate contains 7 parts:
+ // 1. the SQL of child, can be a table or sub-query
+ // 2. the LATERAL VIEW keyword
+ // 3. an optional OUTER keyword
+ // 4. the SQL of generator, e.g. EXPLODE(array_col)
+ // 5. the table alias for output columns of generator.
+ // 6. the AS keyword
+ // 7. the column alias, can be more than one, e.g. AS key, value
+ // A concrete example: "tbl LATERAL VIEW EXPLODE(map_col) sub_q AS key, value"; the builder
+ // will put it in the FROM clause later.
+ build(
+ childSQL,
+ "LATERAL VIEW",
+ if (g.outer) "OUTER" else "",
+ g.generator.sql,
+ newSubqueryName(),
+ "AS",
+ columnAliases
+ )
+ }
+
+ private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
+ output1.size == output2.size &&
+ output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
+
+ private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = {
+ assert(a.child == e && e.child == p)
+ a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && sameOutput(
+ e.output.drop(p.child.output.length),
+ a.groupingExpressions.map(_.asInstanceOf[Attribute]))
+ }
+
+ private def groupingSetToSQL(agg: Aggregate, expand: Expand, project: Project): String = {
+ assert(agg.groupingExpressions.length > 1)
+
+ // The last column of Expand is always the grouping ID
+ val gid = expand.output.last
+
+ val numOriginalOutput = project.child.output.length
+ // Assumption: Aggregate's groupingExpressions is composed of
+ // 1) the grouping attributes
+ // 2) gid, which is always the last one
+ val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
+ // Assumption: Project's projectList is composed of
+ // 1) the original output (Project's child.output),
+ // 2) the aliased group by expressions.
+ val expandedAttributes = project.output.drop(numOriginalOutput)
+ val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
+ val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
+
+ // a map from group by attributes to the original group by expressions.
+ val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
+ // a map from expanded attributes to the original group by expressions.
+ val expandedAttrMap = AttributeMap(expandedAttributes.zip(groupByExprs))
+
+ val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project =>
+ // Assumption: expand.projections is composed of
+ // 1) the original output (Project's child.output),
+ // 2) the expanded attributes (or null literals),
+ // 3) gid, which is always the last one in each projection in Expand
+ project.drop(numOriginalOutput).dropRight(1).collect {
+ case attr: Attribute if expandedAttrMap.contains(attr) => expandedAttrMap(attr)
+ }
+ }
+ val groupingSetSQL = "GROUPING SETS(" +
+ groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")"
+
+ val aggExprs = agg.aggregateExpressions.map { case aggExpr =>
+ val originalAggExpr = aggExpr.transformDown {
+ // grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back.
+ case ar: AttributeReference if ar == gid => GroupingID(Nil)
+ case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar)
+ case a @ Cast(BitwiseAnd(
+ ShiftRight(ar: AttributeReference, Literal(value: Any, IntegerType)),
+ Literal(1, IntegerType)), ByteType) if ar == gid =>
+ // Convert the expression back to its original SQL form, i.e. grouping(col)
+ val idx = groupByExprs.length - 1 - value.asInstanceOf[Int]
+ groupByExprs.lift(idx).map(Grouping).getOrElse(a)
+ }
+
+ originalAggExpr match {
+ // Ancestor operators may reference the output of this grouping set, and we use exprId to
+ // generate a unique name for each attribute, so we should make sure the transformed
+ // aggregate expression won't change the output, i.e. exprId and alias name should remain
+ // the same.
+ case ne: NamedExpression if ne.exprId == aggExpr.exprId => ne
+ case e => Alias(e, normalizedName(aggExpr))(exprId = aggExpr.exprId)
+ }
+ }
+
+ build(
+ "SELECT",
+ aggExprs.map(_.sql).mkString(", "),
+ if (agg.child == OneRowRelation) "" else "FROM",
+ toSQL(project.child),
+ "GROUP BY",
+ groupingSQL,
+ groupingSetSQL
+ )
+ }
+
+ private def windowToSQL(w: Window): String = {
+ build(
+ "SELECT",
+ (w.child.output ++ w.windowExpressions).map(_.sql).mkString(", "),
+ if (w.child == OneRowRelation) "" else "FROM",
+ toSQL(w.child)
+ )
+ }
+
+ private def normalizedName(n: NamedExpression): String = "gen_attr_" + n.exprId.id
+
+ object Canonicalizer extends RuleExecutor[LogicalPlan] {
+ override protected def batches: Seq[Batch] = Seq(
+ Batch("Prepare", FixedPoint(100),
+ // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over
+ // `Aggregate`s to perform type casting. This rule merges these `Project`s into
+ // `Aggregate`s.
+ CollapseProject,
+ // Parser is unable to parse the following query:
+ // SELECT `u_1`.`id`
+ // FROM (((SELECT `t0`.`id` FROM `default`.`t0`)
+ // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`))
+ // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1
+ // This rule combines adjacent Unions so that we can generate a flat UNION ALL SQL string.
+ CombineUnions),
+ Batch("Recover Scoping Info", Once,
+ // A logical plan is allowed to have same-name outputs with different qualifiers (e.g. the
+ // `Join` operator). However, such a plan can't be put under a sub-query, as we will erase
+ // and assign a new qualifier to all outputs, making it impossible to distinguish
+ // same-name outputs. This rule renames all attributes to guarantee that different
+ // attributes (with different exprIds) always have different names. It also removes all
+ // qualifiers, as attributes now have unique names and we no longer need qualifiers to
+ // resolve ambiguity.
+ NormalizedAttribute,
+ // Our analyzer will add one or more sub-queries above a table relation; this rule removes
+ // them so that the next rule can combine adjacent table relations and samples into
+ // SQLTable.
+ RemoveSubqueriesAboveSQLTable,
+ // Finds the table relations and wraps them with `SQLTable`s. If there are any `Sample`
+ // operators on top of a table relation, merges the sample information into the `SQLTable`
+ // of that relation, as we can only convert a table sample to a standard SQL string.
+ ResolveSQLTable,
+ // Inserts sub-queries on top of operators that need to appear after the FROM clause.
+ AddSubquery
+ )
+ )
+
+ object NormalizedAttribute extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
+ case a: AttributeReference =>
+ AttributeReference(normalizedName(a), a.dataType)(exprId = a.exprId, qualifier = None)
+ case a: Alias =>
+ Alias(a.child, normalizedName(a))(exprId = a.exprId, qualifier = None)
+ }
+ }
+
+ object RemoveSubqueriesAboveSQLTable extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case SubqueryAlias(_, t @ ExtractSQLTable(_)) => t
+ }
+ }
+
+ object ResolveSQLTable extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
+ case Sample(lowerBound, upperBound, _, _, ExtractSQLTable(table)) =>
+ aliasColumns(table.withSample(lowerBound, upperBound))
+ case ExtractSQLTable(table) =>
+ aliasColumns(table)
+ }
+
+ /**
+ * Aliases the table columns to the generated attribute names. We use exprId to generate a
+ * unique name for each attribute when normalizing attributes, so we can't reference table
+ * columns by their real names.
+ */
+ private def aliasColumns(table: SQLTable): LogicalPlan = {
+ val aliasedOutput = table.output.map { attr =>
+ Alias(attr, normalizedName(attr))(exprId = attr.exprId)
+ }
+ addSubquery(Project(aliasedOutput, table))
+ }
+ }
+
+ object AddSubquery extends Rule[LogicalPlan] {
+ override def apply(tree: LogicalPlan): LogicalPlan = tree transformUp {
+ // This branch handles aggregate functions within HAVING clauses. For example:
+ //
+ // SELECT key FROM src GROUP BY key HAVING max(value) > "val_255"
+ //
+ // This kind of query results in query plans of the following form because of analysis rule
+ // `ResolveAggregateFunctions`:
+ //
+ // Project ...
+ // +- Filter ...
+ // +- Aggregate ...
+ // +- MetastoreRelation default, src, None
+ case p @ Project(_, f @ Filter(_, _: Aggregate)) => p.copy(child = addSubquery(f))
+
+ case w @ Window(_, _, _, f @ Filter(_, _: Aggregate)) => w.copy(child = addSubquery(f))
+
+ case p: Project => p.copy(child = addSubqueryIfNeeded(p.child))
+
+ // We will generate "SELECT ... FROM ..." for a Window operator, so its child must be an
+ // operator that can go in the FROM clause; otherwise we wrap it in a sub-query.
+ case w: Window => w.copy(child = addSubqueryIfNeeded(w.child))
+
+ case j: Join => j.copy(
+ left = addSubqueryIfNeeded(j.left),
+ right = addSubqueryIfNeeded(j.right))
+
+ // A special case for Generate. When we put a UDTF in the project list followed by a WHERE
+ // clause, e.g. SELECT EXPLODE(arr) FROM tbl WHERE id > 1, the Filter operator will be under
+ // the Generate operator, and we need to add a sub-query between them, as a WHERE clause is
+ // not allowed before LATERAL VIEW: "... FROM tbl WHERE id > 2 EXPLODE(arr) ..." is illegal.
+ case g @ Generate(_, _, _, _, _, f: Filter) =>
+ // Add an extra `Project` to make sure we can generate legal SQL string for sub-query,
+ // for example, Subquery -> Filter -> Table will generate "(tbl WHERE ...) AS name", which
+ // misses the SELECT part.
+ val proj = Project(f.output, f)
+ g.copy(child = addSubquery(proj))
+ }
+ }
+
+ private def addSubquery(plan: LogicalPlan): SubqueryAlias = {
+ SubqueryAlias(newSubqueryName(), plan)
+ }
+
+ private def addSubqueryIfNeeded(plan: LogicalPlan): LogicalPlan = plan match {
+ case _: SubqueryAlias => plan
+ case _: Filter => plan
+ case _: Join => plan
+ case _: LocalLimit => plan
+ case _: GlobalLimit => plan
+ case _: SQLTable => plan
+ case _: Generate => plan
+ case OneRowRelation => plan
+ case _ => addSubquery(plan)
+ }
+ }
+
+ case class SQLTable(
+ database: String,
+ table: String,
+ output: Seq[Attribute],
+ sample: Option[(Double, Double)] = None) extends LeafNode {
+ def withSample(lowerBound: Double, upperBound: Double): SQLTable =
+ this.copy(sample = Some(lowerBound -> upperBound))
+ }
+
+ object ExtractSQLTable {
+ def unapply(plan: LogicalPlan): Option[SQLTable] = plan match {
+ case l @ LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) =>
+ Some(SQLTable(database, table, l.output.map(_.withQualifier(None))))
+
+ case relation: CatalogRelation =>
+ val m = relation.catalogTable
+ Some(SQLTable(m.database, m.identifier.table, relation.output.map(_.withQualifier(None))))
+
+ case _ => None
+ }
+ }
+
+ /**
+ * A placeholder for the SQL generated for a subquery expression.
+ */
+ case class SubqueryHolder(query: String) extends LeafExpression with Unevaluable {
+ override def dataType: DataType = NullType
+ override def nullable: Boolean = true
+ override def sql: String = s"($query)"
+ }
+}
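The `build` helper above does most of the string assembly: segments are trimmed and empty segments are dropped, which is why callers can pass conditionally empty pieces such as `if (isDistinct) "DISTINCT" else ""`. A standalone restatement of that contract (not the Spark source itself):

    // Joins segments with single spaces, trimming each and dropping empty ones.
    def build(segments: String*): String =
      segments.map(_.trim).filter(_.nonEmpty).mkString(" ")

    // The empty DISTINCT slot and the stray whitespace both disappear:
    assert(build("SELECT", "", "a, b", " FROM ", "t") == "SELECT a, b FROM t")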
http://git-wip-us.apache.org/repos/asf/spark/blob/aeb52bea/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index b0d7b05..d9b374b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -212,6 +212,7 @@ abstract class QueryTest extends PlanTest {
case _: ObjectProducer => return
case _: AppendColumns => return
case _: LogicalRelation => return
+ case p if p.getClass.getSimpleName == "MetastoreRelation" => return
case _: MemoryPlan => return
}.transformAllExpressions {
case a: ImperativeAggregate => return
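The new case matches `MetastoreRelation` by its simple class name rather than by type, because sql/core tests can no longer reference that Hive-module class at compile time. A hedged sketch of the same pattern:

    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // Matching on the runtime class name avoids a compile-time dependency on
    // the module that defines the class (MetastoreRelation lives in sql/hive).
    def isMetastoreRelation(plan: LogicalPlan): Boolean =
      plan.getClass.getSimpleName == "MetastoreRelation"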
http://git-wip-us.apache.org/repos/asf/spark/blob/aeb52bea/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
deleted file mode 100644
index 3a0e22c..0000000
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ /dev/null
@@ -1,537 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import java.util.concurrent.atomic.AtomicLong
-
-import scala.util.control.NonFatal
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Dataset, SQLContext}
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.CatalogRelation
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.optimizer.{CollapseProject, CombineUnions}
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
-import org.apache.spark.sql.catalyst.util.quoteIdentifier
-import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, NullType}
-
-/**
- * A builder class used to convert a resolved logical plan into a SQL query string. Note that not
- * all resolved logical plan are convertible. They either don't have corresponding SQL
- * representations (e.g. logical plans that operate on local Scala collections), or are simply not
- * supported by this builder (yet).
- */
-class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
- require(logicalPlan.resolved, "SQLBuilder only supports resolved logical query plans")
-
- def this(df: Dataset[_]) = this(df.queryExecution.analyzed, df.sqlContext)
-
- private val nextSubqueryId = new AtomicLong(0)
- private def newSubqueryName(): String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}"
-
- def toSQL: String = {
- val canonicalizedPlan = Canonicalizer.execute(logicalPlan)
- val outputNames = logicalPlan.output.map(_.name)
- val qualifiers = logicalPlan.output.flatMap(_.qualifier).distinct
-
- // Keep the qualifier information by using it as sub-query name, if there is only one qualifier
- // present.
- val finalName = if (qualifiers.length == 1) {
- qualifiers.head
- } else {
- newSubqueryName()
- }
-
- // Canonicalizer will remove all naming information, we should add it back by adding an extra
- // Project and alias the outputs.
- val aliasedOutput = canonicalizedPlan.output.zip(outputNames).map {
- case (attr, name) => Alias(attr.withQualifier(None), name)()
- }
- val finalPlan = Project(aliasedOutput, SubqueryAlias(finalName, canonicalizedPlan))
-
- try {
- val replaced = finalPlan.transformAllExpressions {
- case e: SubqueryExpression =>
- SubqueryHolder(new SQLBuilder(e.query, sqlContext).toSQL)
- case e: NonSQLExpression =>
- throw new UnsupportedOperationException(
- s"Expression $e doesn't have a SQL representation"
- )
- case e => e
- }
-
- val generatedSQL = toSQL(replaced)
- logDebug(
- s"""Built SQL query string successfully from given logical plan:
- |
- |# Original logical plan:
- |${logicalPlan.treeString}
- |# Canonicalized logical plan:
- |${replaced.treeString}
- |# Generated SQL:
- |$generatedSQL
- """.stripMargin)
- generatedSQL
- } catch { case NonFatal(e) =>
- logDebug(
- s"""Failed to build SQL query string from given logical plan:
- |
- |# Original logical plan:
- |${logicalPlan.treeString}
- |# Canonicalized logical plan:
- |${canonicalizedPlan.treeString}
- """.stripMargin)
- throw e
- }
- }
-
- private def toSQL(node: LogicalPlan): String = node match {
- case Distinct(p: Project) =>
- projectToSQL(p, isDistinct = true)
-
- case p: Project =>
- projectToSQL(p, isDistinct = false)
-
- case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) =>
- groupingSetToSQL(a, e, p)
-
- case p: Aggregate =>
- aggregateToSQL(p)
-
- case w: Window =>
- windowToSQL(w)
-
- case g: Generate =>
- generateToSQL(g)
-
- case Limit(limitExpr, child) =>
- s"${toSQL(child)} LIMIT ${limitExpr.sql}"
-
- case Filter(condition, child) =>
- val whereOrHaving = child match {
- case _: Aggregate => "HAVING"
- case _ => "WHERE"
- }
- build(toSQL(child), whereOrHaving, condition.sql)
-
- case p @ Distinct(u: Union) if u.children.length > 1 =>
- val childrenSql = u.children.map(c => s"(${toSQL(c)})")
- childrenSql.mkString(" UNION DISTINCT ")
-
- case p: Union if p.children.length > 1 =>
- val childrenSql = p.children.map(c => s"(${toSQL(c)})")
- childrenSql.mkString(" UNION ALL ")
-
- case p: Intersect =>
- build("(" + toSQL(p.left), ") INTERSECT (", toSQL(p.right) + ")")
-
- case p: Except =>
- build("(" + toSQL(p.left), ") EXCEPT (", toSQL(p.right) + ")")
-
- case p: SubqueryAlias => build("(" + toSQL(p.child) + ")", "AS", p.alias)
-
- case p: Join =>
- build(
- toSQL(p.left),
- p.joinType.sql,
- "JOIN",
- toSQL(p.right),
- p.condition.map(" ON " + _.sql).getOrElse(""))
-
- case SQLTable(database, table, _, sample) =>
- val qualifiedName = s"${quoteIdentifier(database)}.${quoteIdentifier(table)}"
- sample.map { case (lowerBound, upperBound) =>
- val fraction = math.min(100, math.max(0, (upperBound - lowerBound) * 100))
- qualifiedName + " TABLESAMPLE(" + fraction + " PERCENT)"
- }.getOrElse(qualifiedName)
-
- case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
- if orders.map(_.child) == partitionExprs =>
- build(toSQL(child), "CLUSTER BY", partitionExprs.map(_.sql).mkString(", "))
-
- case p: Sort =>
- build(
- toSQL(p.child),
- if (p.global) "ORDER BY" else "SORT BY",
- p.order.map(_.sql).mkString(", ")
- )
-
- case p: RepartitionByExpression =>
- build(
- toSQL(p.child),
- "DISTRIBUTE BY",
- p.partitionExpressions.map(_.sql).mkString(", ")
- )
-
- case p: ScriptTransformation =>
- scriptTransformationToSQL(p)
-
- case OneRowRelation =>
- ""
-
- case _ =>
- throw new UnsupportedOperationException(s"unsupported plan $node")
- }
-
- /**
- * Turns a bunch of string segments into a single string and separate each segment by a space.
- * The segments are trimmed so only a single space appears in the separation.
- * For example, `build("a", " b ", " c")` becomes "a b c".
- */
- private def build(segments: String*): String =
- segments.map(_.trim).filter(_.nonEmpty).mkString(" ")
-
- private def projectToSQL(plan: Project, isDistinct: Boolean): String = {
- build(
- "SELECT",
- if (isDistinct) "DISTINCT" else "",
- plan.projectList.map(_.sql).mkString(", "),
- if (plan.child == OneRowRelation) "" else "FROM",
- toSQL(plan.child)
- )
- }
-
- private def scriptTransformationToSQL(plan: ScriptTransformation): String = {
- val inputRowFormatSQL = plan.ioschema.inputRowFormatSQL.getOrElse(
- throw new UnsupportedOperationException(
- s"unsupported row format ${plan.ioschema.inputRowFormat}"))
- val outputRowFormatSQL = plan.ioschema.outputRowFormatSQL.getOrElse(
- throw new UnsupportedOperationException(
- s"unsupported row format ${plan.ioschema.outputRowFormat}"))
-
- val outputSchema = plan.output.map { attr =>
- s"${attr.sql} ${attr.dataType.simpleString}"
- }.mkString(", ")
-
- build(
- "SELECT TRANSFORM",
- "(" + plan.input.map(_.sql).mkString(", ") + ")",
- inputRowFormatSQL,
- s"USING \'${plan.script}\'",
- "AS (" + outputSchema + ")",
- outputRowFormatSQL,
- if (plan.child == OneRowRelation) "" else "FROM",
- toSQL(plan.child)
- )
- }
-
- private def aggregateToSQL(plan: Aggregate): String = {
- val groupingSQL = plan.groupingExpressions.map(_.sql).mkString(", ")
- build(
- "SELECT",
- plan.aggregateExpressions.map(_.sql).mkString(", "),
- if (plan.child == OneRowRelation) "" else "FROM",
- toSQL(plan.child),
- if (groupingSQL.isEmpty) "" else "GROUP BY",
- groupingSQL
- )
- }
-
- private def generateToSQL(g: Generate): String = {
- val columnAliases = g.generatorOutput.map(_.sql).mkString(", ")
-
- val childSQL = if (g.child == OneRowRelation) {
- // This only happens when we put UDTF in project list and there is no FROM clause. Because we
- // always generate LATERAL VIEW for `Generate`, here we use a trick to put a dummy sub-query
- // after FROM clause, so that we can generate a valid LATERAL VIEW SQL string.
- // For example, if the original SQL is: "SELECT EXPLODE(ARRAY(1, 2))", we will convert in to
- // LATERAL VIEW format, and generate:
- // SELECT col FROM (SELECT 1) sub_q0 LATERAL VIEW EXPLODE(ARRAY(1, 2)) sub_q1 AS col
- s"(SELECT 1) ${newSubqueryName()}"
- } else {
- toSQL(g.child)
- }
-
- // The final SQL string for Generate contains 7 parts:
- // 1. the SQL of child, can be a table or sub-query
- // 2. the LATERAL VIEW keyword
- // 3. an optional OUTER keyword
- // 4. the SQL of generator, e.g. EXPLODE(array_col)
- // 5. the table alias for output columns of generator.
- // 6. the AS keyword
- // 7. the column alias, can be more than one, e.g. AS key, value
- // An concrete example: "tbl LATERAL VIEW EXPLODE(map_col) sub_q AS key, value", and the builder
- // will put it in FROM clause later.
- build(
- childSQL,
- "LATERAL VIEW",
- if (g.outer) "OUTER" else "",
- g.generator.sql,
- newSubqueryName(),
- "AS",
- columnAliases
- )
- }
-
- private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
- output1.size == output2.size &&
- output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
-
- private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = {
- assert(a.child == e && e.child == p)
- a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && sameOutput(
- e.output.drop(p.child.output.length),
- a.groupingExpressions.map(_.asInstanceOf[Attribute]))
- }
-
- private def groupingSetToSQL(
- agg: Aggregate,
- expand: Expand,
- project: Project): String = {
- assert(agg.groupingExpressions.length > 1)
-
- // The last column of Expand is always grouping ID
- val gid = expand.output.last
-
- val numOriginalOutput = project.child.output.length
- // Assumption: Aggregate's groupingExpressions is composed of
- // 1) the grouping attributes
- // 2) gid, which is always the last one
- val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
- // Assumption: Project's projectList is composed of
- // 1) the original output (Project's child.output),
- // 2) the aliased group by expressions.
- val expandedAttributes = project.output.drop(numOriginalOutput)
- val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
- val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
-
- // a map from group by attributes to the original group by expressions.
- val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
- // a map from expanded attributes to the original group by expressions.
- val expandedAttrMap = AttributeMap(expandedAttributes.zip(groupByExprs))
-
- val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project =>
- // Assumption: expand.projections is composed of
- // 1) the original output (Project's child.output),
- // 2) expanded attributes(or null literal)
- // 3) gid, which is always the last one in each project in Expand
- project.drop(numOriginalOutput).dropRight(1).collect {
- case attr: Attribute if expandedAttrMap.contains(attr) => expandedAttrMap(attr)
- }
- }
- val groupingSetSQL = "GROUPING SETS(" +
- groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")"
-
- val aggExprs = agg.aggregateExpressions.map { case aggExpr =>
- val originalAggExpr = aggExpr.transformDown {
- // grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back.
- case ar: AttributeReference if ar == gid => GroupingID(Nil)
- case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar)
- case a @ Cast(BitwiseAnd(
- ShiftRight(ar: AttributeReference, Literal(value: Any, IntegerType)),
- Literal(1, IntegerType)), ByteType) if ar == gid =>
- // for converting an expression to its original SQL format grouping(col)
- val idx = groupByExprs.length - 1 - value.asInstanceOf[Int]
- groupByExprs.lift(idx).map(Grouping).getOrElse(a)
- }
-
- originalAggExpr match {
- // Ancestor operators may reference the output of this grouping set, and we use exprId to
- // generate a unique name for each attribute, so we should make sure the transformed
- // aggregate expression won't change the output, i.e. exprId and alias name should remain
- // the same.
- case ne: NamedExpression if ne.exprId == aggExpr.exprId => ne
- case e => Alias(e, normalizedName(aggExpr))(exprId = aggExpr.exprId)
- }
- }
-
- build(
- "SELECT",
- aggExprs.map(_.sql).mkString(", "),
- if (agg.child == OneRowRelation) "" else "FROM",
- toSQL(project.child),
- "GROUP BY",
- groupingSQL,
- groupingSetSQL
- )
- }
-
- private def windowToSQL(w: Window): String = {
- build(
- "SELECT",
- (w.child.output ++ w.windowExpressions).map(_.sql).mkString(", "),
- if (w.child == OneRowRelation) "" else "FROM",
- toSQL(w.child)
- )
- }
-
- private def normalizedName(n: NamedExpression): String = "gen_attr_" + n.exprId.id
-
- object Canonicalizer extends RuleExecutor[LogicalPlan] {
- override protected def batches: Seq[Batch] = Seq(
- Batch("Prepare", FixedPoint(100),
- // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over
- // `Aggregate`s to perform type casting. This rule merges these `Project`s into
- // `Aggregate`s.
- CollapseProject,
- // Parser is unable to parse the following query:
- // SELECT `u_1`.`id`
- // FROM (((SELECT `t0`.`id` FROM `default`.`t0`)
- // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`))
- // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1
- // This rule combine adjacent Unions together so we can generate flat UNION ALL SQL string.
- CombineUnions),
- Batch("Recover Scoping Info", Once,
- // A logical plan is allowed to have same-name outputs with different qualifiers(e.g. the
- // `Join` operator). However, this kind of plan can't be put under a sub query as we will
- // erase and assign a new qualifier to all outputs and make it impossible to distinguish
- // same-name outputs. This rule renames all attributes, to guarantee different
- // attributes(with different exprId) always have different names. It also removes all
- // qualifiers, as attributes have unique names now and we don't need qualifiers to resolve
- // ambiguity.
- NormalizedAttribute,
- // Our analyzer will add one or more sub-queries above table relation, this rule removes
- // these sub-queries so that next rule can combine adjacent table relation and sample to
- // SQLTable.
- RemoveSubqueriesAboveSQLTable,
- // Finds the table relations and wrap them with `SQLTable`s. If there are any `Sample`
- // operators on top of a table relation, merge the sample information into `SQLTable` of
- // that table relation, as we can only convert table sample to standard SQL string.
- ResolveSQLTable,
- // Insert sub queries on top of operators that need to appear after FROM clause.
- AddSubquery
- )
- )
-
- object NormalizedAttribute extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
- case a: AttributeReference =>
- AttributeReference(normalizedName(a), a.dataType)(exprId = a.exprId, qualifier = None)
- case a: Alias =>
- Alias(a.child, normalizedName(a))(exprId = a.exprId, qualifier = None)
- }
- }
-
- object RemoveSubqueriesAboveSQLTable extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case SubqueryAlias(_, t @ ExtractSQLTable(_)) => t
- }
- }
-
- object ResolveSQLTable extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
- case Sample(lowerBound, upperBound, _, _, ExtractSQLTable(table)) =>
- aliasColumns(table.withSample(lowerBound, upperBound))
- case ExtractSQLTable(table) =>
- aliasColumns(table)
- }
-
- /**
- * Aliases the table columns to the generated attribute names, as we use exprId to generate
- * unique name for each attribute when normalize attributes, and we can't reference table
- * columns with their real names.
- */
- private def aliasColumns(table: SQLTable): LogicalPlan = {
- val aliasedOutput = table.output.map { attr =>
- Alias(attr, normalizedName(attr))(exprId = attr.exprId)
- }
- addSubquery(Project(aliasedOutput, table))
- }
- }
-
- object AddSubquery extends Rule[LogicalPlan] {
- override def apply(tree: LogicalPlan): LogicalPlan = tree transformUp {
- // This branch handles aggregate functions within HAVING clauses. For example:
- //
- // SELECT key FROM src GROUP BY key HAVING max(value) > "val_255"
- //
- // This kind of query results in query plans of the following form because of analysis rule
- // `ResolveAggregateFunctions`:
- //
- // Project ...
- // +- Filter ...
- // +- Aggregate ...
- // +- MetastoreRelation default, src, None
- case p @ Project(_, f @ Filter(_, _: Aggregate)) => p.copy(child = addSubquery(f))
-
- case w @ Window(_, _, _, f @ Filter(_, _: Aggregate)) => w.copy(child = addSubquery(f))
-
- case p: Project => p.copy(child = addSubqueryIfNeeded(p.child))
-
- // We will generate "SELECT ... FROM ..." for Window operator, so its child operator should
- // be able to put in the FROM clause, or we wrap it with a subquery.
- case w: Window => w.copy(child = addSubqueryIfNeeded(w.child))
-
- case j: Join => j.copy(
- left = addSubqueryIfNeeded(j.left),
- right = addSubqueryIfNeeded(j.right))
-
- // A special case for Generate. When we put UDTF in project list, followed by WHERE, e.g.
- // SELECT EXPLODE(arr) FROM tbl WHERE id > 1, the Filter operator will be under Generate
- // operator and we need to add a sub-query between them, as it's not allowed to have a WHERE
- // before LATERAL VIEW, e.g. "... FROM tbl WHERE id > 2 EXPLODE(arr) ..." is illegal.
- case g @ Generate(_, _, _, _, _, f: Filter) =>
- // Add an extra `Project` to make sure we can generate legal SQL string for sub-query,
- // for example, Subquery -> Filter -> Table will generate "(tbl WHERE ...) AS name", which
- // misses the SELECT part.
- val proj = Project(f.output, f)
- g.copy(child = addSubquery(proj))
- }
- }
-
- private def addSubquery(plan: LogicalPlan): SubqueryAlias = {
- SubqueryAlias(newSubqueryName(), plan)
- }
-
- private def addSubqueryIfNeeded(plan: LogicalPlan): LogicalPlan = plan match {
- case _: SubqueryAlias => plan
- case _: Filter => plan
- case _: Join => plan
- case _: LocalLimit => plan
- case _: GlobalLimit => plan
- case _: SQLTable => plan
- case _: Generate => plan
- case OneRowRelation => plan
- case _ => addSubquery(plan)
- }
- }
-
- case class SQLTable(
- database: String,
- table: String,
- output: Seq[Attribute],
- sample: Option[(Double, Double)] = None) extends LeafNode {
- def withSample(lowerBound: Double, upperBound: Double): SQLTable =
- this.copy(sample = Some(lowerBound -> upperBound))
- }
-
- object ExtractSQLTable {
- def unapply(plan: LogicalPlan): Option[SQLTable] = plan match {
- case l @ LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) =>
- Some(SQLTable(database, table, l.output.map(_.withQualifier(None))))
-
- case relation: CatalogRelation =>
- val m = relation.catalogTable
- Some(SQLTable(m.database, m.identifier.table, relation.output.map(_.withQualifier(None))))
-
- case _ => None
- }
- }
-
- /**
- * A place holder for generated SQL for subquery expression.
- */
- case class SubqueryHolder(query: String) extends LeafExpression with Unevaluable {
- override def dataType: DataType = NullType
- override def nullable: Boolean = true
- override def sql: String = s"($query)"
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/aeb52bea/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala
index 1e234d8..fa830a1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.hive.execution
import scala.util.control.NonFatal
import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
+import org.apache.spark.sql.catalyst.SQLBuilder
import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.command.RunnableCommand
-import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveSessionState, SQLBuilder}
+import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveSessionState}
/**
* Create Hive view on non-hive-compatible tables by specifying schema ourselves instead of
@@ -128,7 +129,7 @@ private[hive] case class CreateViewAsSelect(
}
sqlContext.executePlan(Project(projectList, child)).analyzed
}
- new SQLBuilder(logicalPlan, sqlContext).toSQL
+ new SQLBuilder(logicalPlan).toSQL
}
// escape backtick with double-backtick in column name and wrap it with backtick.
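The call-site change above is the visible payoff of the move: `SQLBuilder`'s primary constructor no longer takes a `SQLContext`. A minimal sketch of generating a view definition after this patch; `viewDefinitionSql` is a hypothetical wrapper, only the SQLBuilder call mirrors the diff:

    import org.apache.spark.sql.catalyst.SQLBuilder
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // Build the SQL text for a view from an analyzed plan; no SQLContext needed.
    def viewDefinitionSql(analyzedPlan: LogicalPlan): String =
      new SQLBuilder(analyzedPlan).toSQL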
http://git-wip-us.apache.org/repos/asf/spark/blob/aeb52bea/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala
new file mode 100644
index 0000000..c8bf20d
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import java.sql.Timestamp
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{If, Literal}
+
+
+class ExpressionSQLBuilderSuite extends SQLBuilderTest {
+ test("literal") {
+ checkSQL(Literal("foo"), "\"foo\"")
+ checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"")
+ checkSQL(Literal(1: Byte), "1Y")
+ checkSQL(Literal(2: Short), "2S")
+ checkSQL(Literal(4: Int), "4")
+ checkSQL(Literal(8: Long), "8L")
+ checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)")
+ checkSQL(Literal(2.5D), "2.5D")
+ checkSQL(
+ Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')")
+ // TODO tests for decimals
+ }
+
+ test("attributes") {
+ checkSQL('a.int, "`a`")
+ checkSQL(Symbol("foo bar").int, "`foo bar`")
+ // Keyword
+ checkSQL('int.int, "`int`")
+ }
+
+ test("binary comparisons") {
+ checkSQL('a.int === 'b.int, "(`a` = `b`)")
+ checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)")
+ checkSQL('a.int =!= 'b.int, "(NOT (`a` = `b`))")
+
+ checkSQL('a.int < 'b.int, "(`a` < `b`)")
+ checkSQL('a.int <= 'b.int, "(`a` <= `b`)")
+ checkSQL('a.int > 'b.int, "(`a` > `b`)")
+ checkSQL('a.int >= 'b.int, "(`a` >= `b`)")
+
+ checkSQL('a.int in ('b.int, 'c.int), "(`a` IN (`b`, `c`))")
+ checkSQL('a.int in (1, 2), "(`a` IN (1, 2))")
+
+ checkSQL('a.int.isNull, "(`a` IS NULL)")
+ checkSQL('a.int.isNotNull, "(`a` IS NOT NULL)")
+ }
+
+ test("logical operators") {
+ checkSQL('a.boolean && 'b.boolean, "(`a` AND `b`)")
+ checkSQL('a.boolean || 'b.boolean, "(`a` OR `b`)")
+ checkSQL(!'a.boolean, "(NOT `a`)")
+ checkSQL(If('a.boolean, 'b.int, 'c.int), "(IF(`a`, `b`, `c`))")
+ }
+
+ test("arithmetic expressions") {
+ checkSQL('a.int + 'b.int, "(`a` + `b`)")
+ checkSQL('a.int - 'b.int, "(`a` - `b`)")
+ checkSQL('a.int * 'b.int, "(`a` * `b`)")
+ checkSQL('a.int / 'b.int, "(`a` / `b`)")
+ checkSQL('a.int % 'b.int, "(`a` % `b`)")
+
+ checkSQL(-'a.int, "(-`a`)")
+ checkSQL(-('a.int + 'b.int), "(-(`a` + `b`))")
+ }
+}
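These assertions exercise `Expression.sql` directly, and the same API can be called ad hoc. A small sketch whose expected strings mirror the test cases above:

    import org.apache.spark.sql.catalyst.expressions.Literal

    // Catalyst expressions render themselves as SQL fragments via `.sql`.
    assert(Literal(1: Byte).sql == "1Y")
    assert(Literal(8: Long).sql == "8L")
    assert(Literal(1.5F).sql == "CAST(1.5 AS FLOAT)")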
http://git-wip-us.apache.org/repos/asf/spark/blob/aeb52bea/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
new file mode 100644
index 0000000..a7782ab
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SQLTestUtils
+
+class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
+ import testImplicits._
+
+ protected override def beforeAll(): Unit = {
+ super.beforeAll()
+ sql("DROP TABLE IF EXISTS t0")
+ sql("DROP TABLE IF EXISTS t1")
+ sql("DROP TABLE IF EXISTS t2")
+
+ val bytes = Array[Byte](1, 2, 3, 4)
+ Seq((bytes, "AQIDBA==")).toDF("a", "b").write.saveAsTable("t0")
+
+ sqlContext
+ .range(10)
+ .select('id as 'key, concat(lit("val_"), 'id) as 'value)
+ .write
+ .saveAsTable("t1")
+
+ sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2")
+ }
+
+ override protected def afterAll(): Unit = {
+ try {
+ sql("DROP TABLE IF EXISTS t0")
+ sql("DROP TABLE IF EXISTS t1")
+ sql("DROP TABLE IF EXISTS t2")
+ } finally {
+ super.afterAll()
+ }
+ }
+
+ private def checkSqlGeneration(hiveQl: String): Unit = {
+ val df = sql(hiveQl)
+
+ val convertedSQL = try new SQLBuilder(df).toSQL catch {
+ case NonFatal(e) =>
+ fail(
+ s"""Cannot convert the following HiveQL query plan back to SQL query string:
+ |
+ |# Original HiveQL query string:
+ |$hiveQl
+ |
+ |# Resolved query plan:
+ |${df.queryExecution.analyzed.treeString}
+ """.stripMargin)
+ }
+
+ try {
+ checkAnswer(sql(convertedSQL), df)
+ } catch { case cause: Throwable =>
+ fail(
+ s"""Failed to execute converted SQL string or got wrong answer:
+ |
+ |# Converted SQL query string:
+ |$convertedSQL
+ |
+ |# Original HiveQL query string:
+ |$hiveQl
+ |
+ |# Resolved query plan:
+ |${df.queryExecution.analyzed.treeString}
+ """.stripMargin,
+ cause)
+ }
+ }
+
+ test("misc non-aggregate functions") {
+ checkSqlGeneration("SELECT abs(15), abs(-15)")
+ checkSqlGeneration("SELECT array(1,2,3)")
+ checkSqlGeneration("SELECT coalesce(null, 1, 2)")
+ checkSqlGeneration("SELECT explode(array(1,2,3))")
+ checkSqlGeneration("SELECT greatest(1,null,3)")
+ checkSqlGeneration("SELECT if(1==2, 'yes', 'no')")
+ checkSqlGeneration("SELECT isnan(15), isnan('invalid')")
+ checkSqlGeneration("SELECT isnull(null), isnull('a')")
+ checkSqlGeneration("SELECT isnotnull(null), isnotnull('a')")
+ checkSqlGeneration("SELECT least(1,null,3)")
+ checkSqlGeneration("SELECT map(1, 'a', 2, 'b')")
+ checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)")
+ checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2")
+ checkSqlGeneration("SELECT nvl(null, 1, 2)")
+ checkSqlGeneration("SELECT rand(1)")
+ checkSqlGeneration("SELECT randn(3)")
+ checkSqlGeneration("SELECT struct(1,2,3)")
+ }
+
+ test("math functions") {
+ checkSqlGeneration("SELECT acos(-1)")
+ checkSqlGeneration("SELECT asin(-1)")
+ checkSqlGeneration("SELECT atan(1)")
+ checkSqlGeneration("SELECT atan2(1, 1)")
+ checkSqlGeneration("SELECT bin(10)")
+ checkSqlGeneration("SELECT cbrt(1000.0)")
+ checkSqlGeneration("SELECT ceil(2.333)")
+ checkSqlGeneration("SELECT ceiling(2.333)")
+ checkSqlGeneration("SELECT cos(1.0)")
+ checkSqlGeneration("SELECT cosh(1.0)")
+ checkSqlGeneration("SELECT conv(15, 10, 16)")
+ checkSqlGeneration("SELECT degrees(pi())")
+ checkSqlGeneration("SELECT e()")
+ checkSqlGeneration("SELECT exp(1.0)")
+ checkSqlGeneration("SELECT expm1(1.0)")
+ checkSqlGeneration("SELECT floor(-2.333)")
+ checkSqlGeneration("SELECT factorial(5)")
+ checkSqlGeneration("SELECT hex(10)")
+ checkSqlGeneration("SELECT hypot(3, 4)")
+ checkSqlGeneration("SELECT log(10.0)")
+ checkSqlGeneration("SELECT log10(1000.0)")
+ checkSqlGeneration("SELECT log1p(0.0)")
+ checkSqlGeneration("SELECT log2(8.0)")
+ checkSqlGeneration("SELECT ln(10.0)")
+ checkSqlGeneration("SELECT negative(-1)")
+ checkSqlGeneration("SELECT pi()")
+ checkSqlGeneration("SELECT pmod(3, 2)")
+ checkSqlGeneration("SELECT positive(3)")
+ checkSqlGeneration("SELECT pow(2, 3)")
+ checkSqlGeneration("SELECT power(2, 3)")
+ checkSqlGeneration("SELECT radians(180.0)")
+ checkSqlGeneration("SELECT rint(1.63)")
+ checkSqlGeneration("SELECT round(31.415, -1)")
+ checkSqlGeneration("SELECT shiftleft(2, 3)")
+ checkSqlGeneration("SELECT shiftright(16, 3)")
+ checkSqlGeneration("SELECT shiftrightunsigned(16, 3)")
+ checkSqlGeneration("SELECT sign(-2.63)")
+ checkSqlGeneration("SELECT signum(-2.63)")
+ checkSqlGeneration("SELECT sin(1.0)")
+ checkSqlGeneration("SELECT sinh(1.0)")
+ checkSqlGeneration("SELECT sqrt(100.0)")
+ checkSqlGeneration("SELECT tan(1.0)")
+ checkSqlGeneration("SELECT tanh(1.0)")
+ }
+
+ test("aggregate functions") {
+ checkSqlGeneration("SELECT approx_count_distinct(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT avg(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT corr(value, key) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT count(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT covar_pop(value, key) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT covar_samp(value, key) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT first(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT first_value(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT kurtosis(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT last(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT last_value(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT max(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT mean(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT min(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT skewness(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT stddev(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT stddev_pop(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT stddev_samp(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT sum(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT variance(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT var_pop(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT var_samp(value) FROM t1 GROUP BY key")
+ }
+
+ test("string functions") {
+ checkSqlGeneration("SELECT ascii('SparkSql')")
+ checkSqlGeneration("SELECT base64(a) FROM t0")
+ checkSqlGeneration("SELECT concat('This ', 'is ', 'a ', 'test')")
+ checkSqlGeneration("SELECT concat_ws(' ', 'This', 'is', 'a', 'test')")
+ checkSqlGeneration("SELECT decode(a, 'UTF-8') FROM t0")
+ checkSqlGeneration("SELECT encode('SparkSql', 'UTF-8')")
+ checkSqlGeneration("SELECT find_in_set('ab', 'abc,b,ab,c,def')")
+ checkSqlGeneration("SELECT format_number(1234567.890, 2)")
+ checkSqlGeneration("SELECT format_string('aa%d%s',123, 'cc')")
+ checkSqlGeneration("SELECT get_json_object('{\"a\":\"bc\"}','$.a')")
+ checkSqlGeneration("SELECT initcap('This is a test')")
+ checkSqlGeneration("SELECT instr('This is a test', 'is')")
+ checkSqlGeneration("SELECT lcase('SparkSql')")
+ checkSqlGeneration("SELECT length('This is a test')")
+ checkSqlGeneration("SELECT levenshtein('This is a test', 'Another test')")
+ checkSqlGeneration("SELECT lower('SparkSql')")
+ checkSqlGeneration("SELECT locate('is', 'This is a test', 3)")
+ checkSqlGeneration("SELECT lpad('SparkSql', 16, 'Learning')")
+ checkSqlGeneration("SELECT ltrim(' SparkSql ')")
+ checkSqlGeneration("SELECT json_tuple('{\"f1\": \"value1\", \"f2\": \"value2\"}','f1')")
+ checkSqlGeneration("SELECT printf('aa%d%s', 123, 'cc')")
+ checkSqlGeneration("SELECT regexp_extract('100-200', '(\\d+)-(\\d+)', 1)")
+ checkSqlGeneration("SELECT regexp_replace('100-200', '(\\d+)', 'num')")
+ checkSqlGeneration("SELECT repeat('SparkSql', 3)")
+ checkSqlGeneration("SELECT reverse('SparkSql')")
+ checkSqlGeneration("SELECT rpad('SparkSql', 16, ' is Cool')")
+ checkSqlGeneration("SELECT rtrim(' SparkSql ')")
+ checkSqlGeneration("SELECT soundex('SparkSql')")
+ checkSqlGeneration("SELECT space(2)")
+ checkSqlGeneration("SELECT split('aa2bb3cc', '[1-9]+')")
+ checkSqlGeneration("SELECT space(2)")
+ checkSqlGeneration("SELECT substr('This is a test', 1)")
+ checkSqlGeneration("SELECT substring('This is a test', 1)")
+ checkSqlGeneration("SELECT substring_index('www.apache.org','.',1)")
+ checkSqlGeneration("SELECT translate('translate', 'rnlt', '123')")
+ checkSqlGeneration("SELECT trim(' SparkSql ')")
+ checkSqlGeneration("SELECT ucase('SparkSql')")
+ checkSqlGeneration("SELECT unbase64('SparkSql')")
+ checkSqlGeneration("SELECT unhex(41)")
+ checkSqlGeneration("SELECT upper('SparkSql')")
+ }
+
+ test("datetime functions") {
+ checkSqlGeneration("SELECT add_months('2001-03-31', 1)")
+ checkSqlGeneration("SELECT count(current_date())")
+ checkSqlGeneration("SELECT count(current_timestamp())")
+ checkSqlGeneration("SELECT datediff('2001-01-02', '2001-01-01')")
+ checkSqlGeneration("SELECT date_add('2001-01-02', 1)")
+ checkSqlGeneration("SELECT date_format('2001-05-02', 'yyyy-dd')")
+ checkSqlGeneration("SELECT date_sub('2001-01-02', 1)")
+ checkSqlGeneration("SELECT day('2001-05-02')")
+ checkSqlGeneration("SELECT dayofyear('2001-05-02')")
+ checkSqlGeneration("SELECT dayofmonth('2001-05-02')")
+ checkSqlGeneration("SELECT from_unixtime(1000, 'yyyy-MM-dd HH:mm:ss')")
+ checkSqlGeneration("SELECT from_utc_timestamp('2015-07-24 00:00:00', 'PST')")
+ checkSqlGeneration("SELECT hour('11:35:55')")
+ checkSqlGeneration("SELECT last_day('2001-01-01')")
+ checkSqlGeneration("SELECT minute('11:35:55')")
+ checkSqlGeneration("SELECT month('2001-05-02')")
+ checkSqlGeneration("SELECT months_between('2001-10-30 10:30:00', '1996-10-30')")
+ checkSqlGeneration("SELECT next_day('2001-05-02', 'TU')")
+ checkSqlGeneration("SELECT count(now())")
+ checkSqlGeneration("SELECT quarter('2001-05-02')")
+ checkSqlGeneration("SELECT second('11:35:55')")
+ checkSqlGeneration("SELECT to_date('2001-10-30 10:30:00')")
+ checkSqlGeneration("SELECT to_unix_timestamp('2015-07-24 00:00:00', 'yyyy-MM-dd HH:mm:ss')")
+ checkSqlGeneration("SELECT to_utc_timestamp('2015-07-24 00:00:00', 'PST')")
+ checkSqlGeneration("SELECT trunc('2001-10-30 10:30:00', 'YEAR')")
+ checkSqlGeneration("SELECT unix_timestamp('2001-10-30 10:30:00')")
+ checkSqlGeneration("SELECT weekofyear('2001-05-02')")
+ checkSqlGeneration("SELECT year('2001-05-02')")
+
+ checkSqlGeneration("SELECT interval 3 years - 3 month 7 week 123 microseconds as i")
+ }
+
+ test("collection functions") {
+ checkSqlGeneration("SELECT array_contains(array(2, 9, 8), 9)")
+ checkSqlGeneration("SELECT size(array('b', 'd', 'c', 'a'))")
+ checkSqlGeneration("SELECT sort_array(array('b', 'd', 'c', 'a'))")
+ }
+
+ test("misc functions") {
+ checkSqlGeneration("SELECT crc32('Spark')")
+ checkSqlGeneration("SELECT md5('Spark')")
+ checkSqlGeneration("SELECT hash('Spark')")
+ checkSqlGeneration("SELECT sha('Spark')")
+ checkSqlGeneration("SELECT sha1('Spark')")
+ checkSqlGeneration("SELECT sha2('Spark', 0)")
+ checkSqlGeneration("SELECT spark_partition_id()")
+ checkSqlGeneration("SELECT input_file_name()")
+ checkSqlGeneration("SELECT monotonically_increasing_id()")
+ }
+
+ test("subquery") {
+ checkSqlGeneration("SELECT 1 + (SELECT 2)")
+ checkSqlGeneration("SELECT 1 + (SELECT 2 + (SELECT 3 as a))")
+ }
+}