Posted to commits@spark.apache.org by rx...@apache.org on 2016/03/31 18:25:19 UTC
[4/8] spark git commit: [SPARK-14211][SQL] Remove ANTLR3 based parser
http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
deleted file mode 100644
index c188c5b..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
+++ /dev/null
@@ -1,933 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.parser
-
-import java.sql.Date
-
-import scala.collection.mutable.ArrayBuffer
-import scala.util.matching.Regex
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.Count
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
-import org.apache.spark.util.random.RandomSampler
-
-
-/**
- * This class translates SQL to Catalyst [[LogicalPlan]]s or [[Expression]]s.
- */
-private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends ParserInterface {
- import ParserUtils._
-
- /**
- * The safeParse method allows a user to focus on the parsing/AST transformation logic. This
- * method will take care of possible errors during the parsing process.
- */
- protected def safeParse[T](sql: String, ast: ASTNode)(toResult: ASTNode => T): T = {
- try {
- toResult(ast)
- } catch {
- case e: MatchError => throw e
- case e: AnalysisException => throw e
- case e: Exception =>
- throw new AnalysisException(e.getMessage)
- case e: NotImplementedError =>
- throw new AnalysisException(
- s"""Unsupported language features in query
- |== SQL ==
- |$sql
- |== AST ==
- |${ast.treeString}
- |== Error ==
- |$e
- |== Stacktrace ==
- |${e.getStackTrace.head}
- """.stripMargin)
- }
- }
-
- /** Creates LogicalPlan for a given SQL string. */
- def parsePlan(sql: String): LogicalPlan =
- safeParse(sql, ParseDriver.parsePlan(sql, conf))(nodeToPlan)
-
- /** Creates Expression for a given SQL string. */
- def parseExpression(sql: String): Expression =
- safeParse(sql, ParseDriver.parseExpression(sql, conf))(selExprNodeToExpr(_).get)
-
- /** Creates TableIdentifier for a given SQL string. */
- def parseTableIdentifier(sql: String): TableIdentifier =
- safeParse(sql, ParseDriver.parseTableName(sql, conf))(extractTableIdent)
-
- /**
- * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2))
- * is equivalent to
- * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2
- * Check the following link for details.
- *
- * https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup
- *
- * The bitmask denotes the validity of each grouping expression for a grouping set; it is
- * also called the grouping id (`GROUPING__ID`, the virtual column in Hive).
- * e.g. In the superset (k1, k2, k3), with bit 2: k1, bit 1: k2, and bit 0: k3, the grouping
- * ids of GROUPING SETS (k1, k2) and (k2) should be 1 and 5 respectively.
- */
- protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = {
- val (keyASTs, setASTs) = children.partition {
- case Token("TOK_GROUPING_SETS_EXPRESSION", _) => false // grouping sets
- case _ => true // grouping keys
- }
-
- val keys = keyASTs.map(nodeToExpr)
- val keyMap = keyASTs.zipWithIndex.toMap
-
- val mask = (1 << keys.length) - 1
- val bitmasks: Seq[Int] = setASTs.map {
- case Token("TOK_GROUPING_SETS_EXPRESSION", columns) =>
- columns.foldLeft(mask)((bitmap, col) => {
- val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2).getOrElse(
- throw new AnalysisException(s"${col.treeString} doesn't show up in the GROUP BY list"))
- // 0 means that the column at the given index is a grouping column, 1 means it is not,
- // so we unset the bit in bitmap.
- bitmap & ~(1 << (keys.length - 1 - keyIndex))
- })
- case _ => sys.error("Expect GROUPING SETS clause")
- }
-
- (keys, bitmasks)
- }
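
To make the bitmask arithmetic in extractGroupingSet concrete, here is a minimal
standalone sketch of the same computation (plain Scala, no Spark dependencies; the
method name is illustrative):

    // For keys Seq("k1", "k2", "k3"), mask = 0b111 = 7. The set (k1, k2) clears
    // bit 2 (k1) and bit 1 (k2), leaving 0b001 = 1; the set (k2) clears bit 1,
    // leaving 0b101 = 5, matching the grouping ids quoted in the comment above.
    def bitmasksFor(keys: Seq[String], sets: Seq[Seq[String]]): Seq[Int] = {
      val keyIndex = keys.zipWithIndex.toMap
      val mask = (1 << keys.length) - 1
      sets.map { set =>
        set.foldLeft(mask) { (bitmap, col) =>
          bitmap & ~(1 << (keys.length - 1 - keyIndex(col)))
        }
      }
    }

    // bitmasksFor(Seq("k1", "k2", "k3"), Seq(Seq("k1", "k2"), Seq("k2"))) == Seq(1, 5)
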
-
- protected def nodeToPlan(node: ASTNode): LogicalPlan = node match {
- case Token("TOK_SHOWFUNCTIONS", args) =>
- // Skip LIKE.
- val pattern = args match {
- case like :: nodes if like.text.toUpperCase == "LIKE" => nodes
- case nodes => nodes
- }
-
- // Extract Database and Function name
- pattern match {
- case Nil =>
- ShowFunctions(None, None)
- case Token(name, Nil) :: Nil =>
- ShowFunctions(None, Some(unquoteString(cleanIdentifier(name))))
- case Token(db, Nil) :: Token(name, Nil) :: Nil =>
- ShowFunctions(Some(unquoteString(cleanIdentifier(db))),
- Some(unquoteString(cleanIdentifier(name))))
- case _ =>
- noParseRule("SHOW FUNCTIONS", node)
- }
-
- case Token("TOK_DESCFUNCTION", Token(functionName, Nil) :: isExtended) =>
- DescribeFunction(cleanIdentifier(functionName), isExtended.nonEmpty)
-
- case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) =>
- val (fromClause: Option[ASTNode], insertClauses, cteRelations) =
- queryArgs match {
- case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts =>
- val cteRelations = ctes.map { node =>
- val relation = nodeToRelation(node).asInstanceOf[SubqueryAlias]
- relation.alias -> relation
- }
- (Some(from.head), inserts, Some(cteRelations.toMap))
- case Token("TOK_FROM", from) :: inserts =>
- (Some(from.head), inserts, None)
- case Token("TOK_INSERT", _) :: Nil =>
- (None, queryArgs, None)
- }
-
- // Return one query for each insert clause.
- val queries = insertClauses.map {
- case Token("TOK_INSERT", singleInsert) =>
- val (
- intoClause ::
- destClause ::
- selectClause ::
- selectDistinctClause ::
- whereClause ::
- groupByClause ::
- rollupGroupByClause ::
- cubeGroupByClause ::
- groupingSetsClause ::
- orderByClause ::
- havingClause ::
- sortByClause ::
- clusterByClause ::
- distributeByClause ::
- limitClause ::
- lateralViewClause ::
- windowClause :: Nil) = {
- getClauses(
- Seq(
- "TOK_INSERT_INTO",
- "TOK_DESTINATION",
- "TOK_SELECT",
- "TOK_SELECTDI",
- "TOK_WHERE",
- "TOK_GROUPBY",
- "TOK_ROLLUP_GROUPBY",
- "TOK_CUBE_GROUPBY",
- "TOK_GROUPING_SETS",
- "TOK_ORDERBY",
- "TOK_HAVING",
- "TOK_SORTBY",
- "TOK_CLUSTERBY",
- "TOK_DISTRIBUTEBY",
- "TOK_LIMIT",
- "TOK_LATERAL_VIEW",
- "WINDOW"),
- singleInsert)
- }
-
- val relations = fromClause match {
- case Some(f) => nodeToRelation(f)
- case None => OneRowRelation
- }
-
- val withLateralView = lateralViewClause.map { lv =>
- nodeToGenerate(lv.children.head, outer = false, relations)
- }.getOrElse(relations)
-
- val withWhere = whereClause.map { whereNode =>
- val Seq(whereExpr) = whereNode.children
- Filter(nodeToExpr(whereExpr), withLateralView)
- }.getOrElse(withLateralView)
-
- val select = (selectClause orElse selectDistinctClause)
- .getOrElse(sys.error("No select clause."))
-
- val transformation = nodeToTransformation(select.children.head, withWhere)
-
- // The projection of the query can either be a normal projection, an aggregation
- // (if there is a group by) or a script transformation.
- val withProject: LogicalPlan = transformation.getOrElse {
- val selectExpressions =
- select.children.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_))
- Seq(
- groupByClause.map(e => e match {
- case Token("TOK_GROUPBY", children) =>
- // Not a transformation so must be either project or aggregation.
- Aggregate(children.map(nodeToExpr), selectExpressions, withWhere)
- case _ => sys.error("Expect GROUP BY")
- }),
- groupingSetsClause.map(e => e match {
- case Token("TOK_GROUPING_SETS", children) =>
- val(groupByExprs, masks) = extractGroupingSet(children)
- GroupingSets(masks, groupByExprs, withWhere, selectExpressions)
- case _ => sys.error("Expect GROUPING SETS")
- }),
- rollupGroupByClause.map(e => e match {
- case Token("TOK_ROLLUP_GROUPBY", children) =>
- Aggregate(
- Seq(Rollup(children.map(nodeToExpr))),
- selectExpressions,
- withWhere)
- case _ => sys.error("Expect WITH ROLLUP")
- }),
- cubeGroupByClause.map(e => e match {
- case Token("TOK_CUBE_GROUPBY", children) =>
- Aggregate(
- Seq(Cube(children.map(nodeToExpr))),
- selectExpressions,
- withWhere)
- case _ => sys.error("Expect WITH CUBE")
- }),
- Some(Project(selectExpressions, withWhere))).flatten.head
- }
-
- // Handle HAVING clause.
- val withHaving = havingClause.map { h =>
- val havingExpr = h.children match { case Seq(hexpr) => nodeToExpr(hexpr) }
- // Note that we added a cast to boolean. If the expression itself is already boolean,
- // the optimizer will get rid of the unnecessary cast.
- Filter(Cast(havingExpr, BooleanType), withProject)
- }.getOrElse(withProject)
-
- // Handle SELECT DISTINCT
- val withDistinct =
- if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving
-
- // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause.
- val withSort =
- (orderByClause, sortByClause, distributeByClause, clusterByClause) match {
- case (Some(totalOrdering), None, None, None) =>
- Sort(totalOrdering.children.map(nodeToSortOrder), global = true, withDistinct)
- case (None, Some(perPartitionOrdering), None, None) =>
- Sort(
- perPartitionOrdering.children.map(nodeToSortOrder),
- global = false, withDistinct)
- case (None, None, Some(partitionExprs), None) =>
- RepartitionByExpression(
- partitionExprs.children.map(nodeToExpr), withDistinct)
- case (None, Some(perPartitionOrdering), Some(partitionExprs), None) =>
- Sort(
- perPartitionOrdering.children.map(nodeToSortOrder), global = false,
- RepartitionByExpression(
- partitionExprs.children.map(nodeToExpr),
- withDistinct))
- case (None, None, None, Some(clusterExprs)) =>
- Sort(
- clusterExprs.children.map(nodeToExpr).map(SortOrder(_, Ascending)),
- global = false,
- RepartitionByExpression(
- clusterExprs.children.map(nodeToExpr),
- withDistinct))
- case (None, None, None, None) => withDistinct
- case _ => sys.error("Unsupported set of ordering / distribution clauses.")
- }
-
- val withLimit =
- limitClause.map(l => nodeToExpr(l.children.head))
- .map(Limit(_, withSort))
- .getOrElse(withSort)
-
- // Collect all window specifications defined in the WINDOW clause.
- val windowDefinitions = windowClause.map(_.children.collect {
- case Token("TOK_WINDOWDEF",
- Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) =>
- windowName -> nodesToWindowSpecification(spec)
- }.toMap)
- // Handle cases like
- // window w1 as (partition by p_mfgr order by p_name
- // range between 2 preceding and 2 following),
- // w2 as w1
- val resolvedCrossReference = windowDefinitions.map {
- windowDefMap => windowDefMap.map {
- case (windowName, WindowSpecReference(other)) =>
- (windowName, windowDefMap(other).asInstanceOf[WindowSpecDefinition])
- case o => o.asInstanceOf[(String, WindowSpecDefinition)]
- }
- }
-
- val withWindowDefinitions =
- resolvedCrossReference.map(WithWindowDefinition(_, withLimit)).getOrElse(withLimit)
-
- // TOK_INSERT_INTO means to add files to the table.
- // TOK_DESTINATION means to overwrite the table.
- val resultDestination =
- (intoClause orElse destClause).getOrElse(sys.error("No destination found."))
- val overwrite = intoClause.isEmpty
- nodeToDest(
- resultDestination,
- withWindowDefinitions,
- overwrite)
- }
-
- // If there are multiple INSERTS just UNION them together into one query.
- val query = if (queries.length == 1) queries.head else Union(queries)
-
- // return With plan if there is CTE
- cteRelations.map(With(query, _)).getOrElse(query)
-
- case Token("TOK_UNIONALL", left :: right :: Nil) =>
- Union(nodeToPlan(left), nodeToPlan(right))
- case Token("TOK_UNIONDISTINCT", left :: right :: Nil) =>
- Distinct(Union(nodeToPlan(left), nodeToPlan(right)))
- case Token("TOK_EXCEPT", left :: right :: Nil) =>
- Except(nodeToPlan(left), nodeToPlan(right))
- case Token("TOK_INTERSECT", left :: right :: Nil) =>
- Intersect(nodeToPlan(left), nodeToPlan(right))
-
- case _ =>
- noParseRule("Plan", node)
- }
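
As a rough illustration of the clause stacking above: a query such as

    SELECT key, max(value) FROM src WHERE key > 0
    GROUP BY key HAVING max(value) > 1 ORDER BY key LIMIT 10

is assembled inside-out, yielding approximately (eliding the UnresolvedAlias wrappers):

    Limit(10,
      Sort(Seq(SortOrder(key, Ascending)), global = true,
        Filter(Cast(max(value) > 1, BooleanType),
          Aggregate(Seq(key), Seq(key, max(value)),
            Filter(key > 0,
              UnresolvedRelation(src))))))
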
-
- val allJoinTokens = "(TOK_.*JOIN)".r
- val laterViewToken = "TOK_LATERAL_VIEW(.*)".r
- protected def nodeToRelation(node: ASTNode): LogicalPlan = {
- node match {
- case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) =>
- SubqueryAlias(cleanIdentifier(alias), nodeToPlan(query))
-
- case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) =>
- nodeToGenerate(
- selectClause,
- outer = isOuter.nonEmpty,
- nodeToRelation(relationClause))
-
- /* All relations, possibly with aliases or sampling clauses. */
- case Token("TOK_TABREF", clauses) =>
- // If the last clause is not a token then it's the alias of the table.
- val (nonAliasClauses, aliasClause) =
- if (clauses.last.text.startsWith("TOK")) {
- (clauses, None)
- } else {
- (clauses.dropRight(1), Some(clauses.last))
- }
-
- val (Some(tableNameParts) ::
- splitSampleClause ::
- bucketSampleClause :: Nil) = {
- getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"),
- nonAliasClauses)
- }
-
- val tableIdent = extractTableIdent(tableNameParts)
- val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) }
- val relation = UnresolvedRelation(tableIdent, alias)
-
- // Apply sampling if requested.
- (bucketSampleClause orElse splitSampleClause).map {
- case Token("TOK_TABLESPLITSAMPLE",
- Token("TOK_ROWCOUNT", Nil) :: Token(count, Nil) :: Nil) =>
- Limit(Literal(count.toInt), relation)
- case Token("TOK_TABLESPLITSAMPLE",
- Token("TOK_PERCENT", Nil) :: Token(fraction, Nil) :: Nil) =>
- // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
- // function takes X PERCENT as the input and the range of X is [0, 100], we need to
- // adjust the fraction.
- require(
- fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon)
- && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon),
- s"Sampling fraction ($fraction) must be on interval [0, 100]")
- Sample(0.0, fraction.toDouble / 100, withReplacement = false,
- (math.random * 1000).toInt,
- relation)(
- isTableSample = true)
- case Token("TOK_TABLEBUCKETSAMPLE",
- Token(numerator, Nil) ::
- Token(denominator, Nil) :: Nil) =>
- val fraction = numerator.toDouble / denominator.toDouble
- Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)(
- isTableSample = true)
- case a =>
- noParseRule("Sampling", a)
- }.getOrElse(relation)
-
- case Token(allJoinTokens(joinToken), relation1 :: relation2 :: other) =>
- if (other.size > 1) {
- sys.error(s"Unsupported join operation: $other")
- }
-
- val (joinType, joinCondition) = getJoinInfo(joinToken, other, node)
-
- Join(nodeToRelation(relation1),
- nodeToRelation(relation2),
- joinType,
- joinCondition)
- case _ =>
- noParseRule("Relation", node)
- }
- }
-
- protected def getJoinInfo(
- joinToken: String,
- joinConditionToken: Seq[ASTNode],
- node: ASTNode): (JoinType, Option[Expression]) = {
- val joinType = joinToken match {
- case "TOK_JOIN" => Inner
- case "TOK_CROSSJOIN" => Inner
- case "TOK_RIGHTOUTERJOIN" => RightOuter
- case "TOK_LEFTOUTERJOIN" => LeftOuter
- case "TOK_FULLOUTERJOIN" => FullOuter
- case "TOK_LEFTSEMIJOIN" => LeftSemi
- case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
- case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
- case "TOK_NATURALJOIN" => NaturalJoin(Inner)
- case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
- case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
- case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
- }
-
- joinConditionToken match {
- case Token("TOK_USING", columnList :: Nil) :: Nil =>
- val colNames = columnList.children.collect {
- case Token(name, Nil) => UnresolvedAttribute(name)
- }
- (UsingJoin(joinType, colNames), None)
- /* Join expression specified using ON clause */
- case _ => (joinType, joinConditionToken.headOption.map(nodeToExpr))
- }
- }
-
- protected def nodeToSortOrder(node: ASTNode): SortOrder = node match {
- case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) =>
- SortOrder(nodeToExpr(sortExpr), Ascending)
- case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) =>
- SortOrder(nodeToExpr(sortExpr), Descending)
- case _ =>
- noParseRule("SortOrder", node)
- }
-
- val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r
- protected def nodeToDest(
- node: ASTNode,
- query: LogicalPlan,
- overwrite: Boolean): LogicalPlan = node match {
- case Token(destinationToken(),
- Token("TOK_DIR",
- Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) =>
- query
-
- case Token(destinationToken(),
- Token("TOK_TAB",
- tableArgs) :: Nil) =>
- val Some(tableNameParts) :: partitionClause :: Nil =
- getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs)
-
- val tableIdent = extractTableIdent(tableNameParts)
-
- val partitionKeys = partitionClause.map(_.children.map {
- // Parse partitions. We also make keys case insensitive.
- case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) =>
- cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value))
- case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) =>
- cleanIdentifier(key.toLowerCase) -> None
- }.toMap).getOrElse(Map.empty)
-
- InsertIntoTable(
- UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = false)
-
- case Token(destinationToken(),
- Token("TOK_TAB",
- tableArgs) ::
- Token("TOK_IFNOTEXISTS",
- ifNotExists) :: Nil) =>
- val Some(tableNameParts) :: partitionClause :: Nil =
- getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs)
-
- val tableIdent = extractTableIdent(tableNameParts)
-
- val partitionKeys = partitionClause.map(_.children.map {
- // Parse partitions. We also make keys case insensitive.
- case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) =>
- cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value))
- case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) =>
- cleanIdentifier(key.toLowerCase) -> None
- }.toMap).getOrElse(Map.empty)
-
- InsertIntoTable(
- UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = true)
-
- case _ =>
- noParseRule("Destination", node)
- }
-
- protected def selExprNodeToExpr(node: ASTNode): Option[Expression] = node match {
- case Token("TOK_SELEXPR", e :: Nil) =>
- Some(nodeToExpr(e))
-
- case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) =>
- Some(Alias(nodeToExpr(e), cleanIdentifier(alias))())
-
- case Token("TOK_SELEXPR", e :: aliasChildren) =>
- val aliasNames = aliasChildren.collect {
- case Token(name, Nil) => cleanIdentifier(name)
- }
- Some(MultiAlias(nodeToExpr(e), aliasNames))
-
- /* Hints are ignored */
- case Token("TOK_HINTLIST", _) => None
-
- case _ =>
- noParseRule("Select", node)
- }
-
- /**
- * Flattens the left deep tree with the specified pattern into a list.
- */
- private def flattenLeftDeepTree(node: ASTNode, pattern: Regex): Seq[ASTNode] = {
- val collected = ArrayBuffer[ASTNode]()
- var rest = node
- while (rest match {
- case Token(pattern(), l :: r :: Nil) =>
- collected += r
- rest = l
- true
- case _ => false
- }) {
- // do nothing
- }
- collected += rest
- // keep them in the same order as in SQL
- collected.reverse
- }
-
- /**
- * Creates a balanced tree that has similar number of nodes on left and right.
- *
- * This help to reduce the depth of the tree to prevent StackOverflow in analyzer/optimizer.
- */
- private def balancedTree(
- expr: Seq[Expression],
- f: (Expression, Expression) => Expression): Expression = expr.length match {
- case 1 => expr.head
- case 2 => f(expr.head, expr(1))
- case l => f(balancedTree(expr.slice(0, l / 2), f), balancedTree(expr.slice(l / 2, l), f))
- }
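
The effect of flattenLeftDeepTree plus balancedTree: a chain like a AND b AND c AND d,
which the parser produces as the left-deep tree And(And(And(a, b), c), d), is rebuilt
with roughly half the depth. A standalone sketch of the rebalancing (illustrative
names, no Spark dependencies):

    // Rebuilds Seq(a, b, c, d) as f(f(a, b), f(c, d)) instead of f(f(f(a, b), c), d).
    def balanced[T](xs: Seq[T], f: (T, T) => T): T = xs.length match {
      case 1 => xs.head
      case 2 => f(xs.head, xs(1))
      case l => f(balanced(xs.slice(0, l / 2), f), balanced(xs.slice(l / 2, l), f))
    }

    // balanced(Seq("a", "b", "c", "d"), (l: String, r: String) => s"($l AND $r)")
    // returns "((a AND b) AND (c AND d))"
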
-
- protected def nodeToExpr(node: ASTNode): Expression = node match {
- /* Attribute References */
- case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) =>
- UnresolvedAttribute.quoted(cleanIdentifier(name))
- case Token(".", qualifier :: Token(attr, Nil) :: Nil) =>
- nodeToExpr(qualifier) match {
- case UnresolvedAttribute(nameParts) =>
- UnresolvedAttribute(nameParts :+ cleanIdentifier(attr))
- case other => UnresolvedExtractValue(other, Literal(cleanIdentifier(attr)))
- }
- case Token("TOK_SUBQUERY_EXPR", Token("TOK_SUBQUERY_OP", Nil) :: subquery :: Nil) =>
- ScalarSubquery(nodeToPlan(subquery))
-
- /* Stars (*) */
- case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None)
- // The dbName.tableName.* format cannot be parsed by HiveParser. TOK_TABNAME will only
- // have a single child, which is the table name.
- case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", target) :: Nil) if target.nonEmpty =>
- UnresolvedStar(Some(target.map(x => cleanIdentifier(x.text))))
-
- /* Aggregate Functions */
- case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) =>
- Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true)
- case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) =>
- Count(Literal(1)).toAggregateExpression()
-
- /* Casts */
- case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), StringType)
- case Token("TOK_FUNCTION", Token("TOK_VARCHAR", _) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), StringType)
- case Token("TOK_FUNCTION", Token("TOK_CHAR", _) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), StringType)
- case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), IntegerType)
- case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), LongType)
- case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), FloatType)
- case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DoubleType)
- case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), ShortType)
- case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), ByteType)
- case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), BinaryType)
- case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), BooleanType)
- case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, scale.text.toInt))
- case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, 0))
- case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT)
- case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), TimestampType)
- case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DateType)
-
- /* Arithmetic */
- case Token("+", child :: Nil) => nodeToExpr(child)
- case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child))
- case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child))
- case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right))
- case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right))
- case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right))
- case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
- case Token(DIV(), left :: right:: Nil) =>
- Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType)
- case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
- case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right))
- case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right))
- case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right))
-
- /* Comparisons */
- case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
- case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
- case Token("<=>", left :: right:: Nil) => EqualNullSafe(nodeToExpr(left), nodeToExpr(right))
- case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right)))
- case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right)))
- case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right))
- case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right))
- case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right))
- case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right))
- case Token(LIKE(), left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right))
- case Token(RLIKE(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right))
- case Token(REGEXP(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right))
- case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) =>
- IsNotNull(nodeToExpr(child))
- case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) =>
- IsNull(nodeToExpr(child))
- case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) =>
- In(nodeToExpr(value), list.map(nodeToExpr))
- case Token("TOK_FUNCTION",
- Token(BETWEEN(), Nil) ::
- kw ::
- target ::
- minValue ::
- maxValue :: Nil) =>
-
- val targetExpression = nodeToExpr(target)
- val betweenExpr =
- And(
- GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)),
- LessThanOrEqual(targetExpression, nodeToExpr(maxValue)))
- kw match {
- case Token("KW_FALSE", Nil) => betweenExpr
- case Token("KW_TRUE", Nil) => Not(betweenExpr)
- }
-
- /* Boolean Logic */
- case Token(AND(), left :: right:: Nil) =>
- balancedTree(flattenLeftDeepTree(node, AND).map(nodeToExpr), And)
- case Token(OR(), left :: right:: Nil) =>
- balancedTree(flattenLeftDeepTree(node, OR).map(nodeToExpr), Or)
- case Token(NOT(), child :: Nil) => Not(nodeToExpr(child))
- case Token("!", child :: Nil) => Not(nodeToExpr(child))
-
- /* Case statements */
- case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
- CaseWhen.createFromParser(branches.map(nodeToExpr))
- case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
- val keyExpr = nodeToExpr(branches.head)
- CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr))
-
- /* Complex datatype manipulation */
- case Token("[", child :: ordinal :: Nil) =>
- UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal))
-
- /* Window Functions */
- case Token(text, args :+ Token("TOK_WINDOWSPEC", spec)) =>
- val function = nodeToExpr(node.copy(children = node.children.init))
- nodesToWindowSpecification(spec) match {
- case reference: WindowSpecReference =>
- UnresolvedWindowExpression(function, reference)
- case definition: WindowSpecDefinition =>
- WindowExpression(function, definition)
- }
-
- /* UDFs - Must be last otherwise will preempt built in functions */
- case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
- UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false)
- // Aggregate function with DISTINCT keyword.
- case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) =>
- UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true)
- case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) =>
- UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false)
-
- /* Literals */
- case Token("TOK_NULL", Nil) => Literal.create(null, NullType)
- case Token(TRUE(), Nil) => Literal.create(true, BooleanType)
- case Token(FALSE(), Nil) => Literal.create(false, BooleanType)
- case Token("TOK_STRINGLITERALSEQUENCE", strings) =>
- Literal(strings.map(s => ParseUtils.unescapeSQLString(s.text)).mkString)
-
- case ast if ast.tokenType == SparkSqlParser.TinyintLiteral =>
- Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType)
-
- case ast if ast.tokenType == SparkSqlParser.SmallintLiteral =>
- Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType)
-
- case ast if ast.tokenType == SparkSqlParser.BigintLiteral =>
- Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType)
-
- case ast if ast.tokenType == SparkSqlParser.DoubleLiteral =>
- Literal(ast.text.toDouble)
-
- case ast if ast.tokenType == SparkSqlParser.Number =>
- val text = ast.text
- text match {
- case INTEGRAL() =>
- BigDecimal(text) match {
- case v if v.isValidInt =>
- Literal(v.intValue())
- case v if v.isValidLong =>
- Literal(v.longValue())
- case v => Literal(v.underlying())
- }
- case DECIMAL(_*) =>
- Literal(BigDecimal(text).underlying())
- case _ =>
- // Convert a scientifically notated decimal into a double.
- Literal(text.toDouble)
- }
- case ast if ast.tokenType == SparkSqlParser.StringLiteral =>
- Literal(ParseUtils.unescapeSQLString(ast.text))
-
- case ast if ast.tokenType == SparkSqlParser.TOK_DATELITERAL =>
- Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1)))
-
- case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL =>
- Literal(CalendarInterval.fromYearMonthString(ast.children.head.text))
-
- case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL =>
- Literal(CalendarInterval.fromDayTimeString(ast.children.head.text))
-
- case Token("TOK_INTERVAL", elements) =>
- var interval = new CalendarInterval(0, 0)
- var updated = false
- elements.foreach {
- // The interval node will always contain children for all possible time units. A child node
- // is only useful when it contains exactly one (numeric) child.
- case e @ Token(name, Token(value, Nil) :: Nil) =>
- val unit = name match {
- case "TOK_INTERVAL_YEAR_LITERAL" => "year"
- case "TOK_INTERVAL_MONTH_LITERAL" => "month"
- case "TOK_INTERVAL_WEEK_LITERAL" => "week"
- case "TOK_INTERVAL_DAY_LITERAL" => "day"
- case "TOK_INTERVAL_HOUR_LITERAL" => "hour"
- case "TOK_INTERVAL_MINUTE_LITERAL" => "minute"
- case "TOK_INTERVAL_SECOND_LITERAL" => "second"
- case "TOK_INTERVAL_MILLISECOND_LITERAL" => "millisecond"
- case "TOK_INTERVAL_MICROSECOND_LITERAL" => "microsecond"
- case _ => noParseRule(s"Interval($name)", e)
- }
- interval = interval.add(CalendarInterval.fromSingleUnitString(unit, value))
- updated = true
- case _ =>
- }
- if (!updated) {
- throw new AnalysisException("at least one time unit should be given for interval literal")
- }
- Literal(interval)
-
- case _ =>
- noParseRule("Expression", node)
- }
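
A few worked examples of the Number rule above: 42 fully matches INTEGRAL and is a
valid Int, so it becomes Literal(42) of IntegerType; 9999999999 only fits a Long, so it
becomes a LongType literal; 1.5 matches DECIMAL and becomes an unlimited-precision
decimal literal; 1e10 matches neither pattern, so it is treated as scientific notation
and becomes Literal(1.0e10) of DoubleType.
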
-
- /* Case insensitive matches for Window Specification */
- val PRECEDING = "(?i)preceding".r
- val FOLLOWING = "(?i)following".r
- val CURRENT = "(?i)current".r
- protected def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match {
- case Token(windowName, Nil) :: Nil =>
- // Refer to a window spec defined in the window clause.
- WindowSpecReference(windowName)
- case Nil =>
- // OVER()
- WindowSpecDefinition(
- partitionSpec = Nil,
- orderSpec = Nil,
- frameSpecification = UnspecifiedFrame)
- case spec =>
- val (partitionClause :: rowFrame :: rangeFrame :: Nil) =
- getClauses(
- Seq(
- "TOK_PARTITIONINGSPEC",
- "TOK_WINDOWRANGE",
- "TOK_WINDOWVALUES"),
- spec)
-
- // Handle Partition By and Order By.
- val (partitionSpec, orderSpec) = partitionClause.map { partitionAndOrdering =>
- val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) =
- getClauses(
- Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"),
- partitionAndOrdering.children)
-
- (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match {
- case (Some(partitionByExpr), Some(orderByExpr), None) =>
- (partitionByExpr.children.map(nodeToExpr),
- orderByExpr.children.map(nodeToSortOrder))
- case (Some(partitionByExpr), None, None) =>
- (partitionByExpr.children.map(nodeToExpr), Nil)
- case (None, Some(orderByExpr), None) =>
- (Nil, orderByExpr.children.map(nodeToSortOrder))
- case (None, None, Some(clusterByExpr)) =>
- val expressions = clusterByExpr.children.map(nodeToExpr)
- (expressions, expressions.map(SortOrder(_, Ascending)))
- case _ =>
- noParseRule("Partition & Ordering", partitionAndOrdering)
- }
- }.getOrElse {
- (Nil, Nil)
- }
-
- // Handle Window Frame
- val windowFrame =
- if (rowFrame.isEmpty && rangeFrame.isEmpty) {
- UnspecifiedFrame
- } else {
- val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame)
- def nodeToBoundary(node: ASTNode): FrameBoundary = node match {
- case Token(PRECEDING(), Token(count, Nil) :: Nil) =>
- if (count.toLowerCase() == "unbounded") {
- UnboundedPreceding
- } else {
- ValuePreceding(count.toInt)
- }
- case Token(FOLLOWING(), Token(count, Nil) :: Nil) =>
- if (count.toLowerCase() == "unbounded") {
- UnboundedFollowing
- } else {
- ValueFollowing(count.toInt)
- }
- case Token(CURRENT(), Nil) => CurrentRow
- case _ =>
- noParseRule("Window Frame Boundary", node)
- }
-
- rowFrame.orElse(rangeFrame).map { frame =>
- frame.children match {
- case precedingNode :: followingNode :: Nil =>
- SpecifiedWindowFrame(
- frameType,
- nodeToBoundary(precedingNode),
- nodeToBoundary(followingNode))
- case precedingNode :: Nil =>
- SpecifiedWindowFrame(frameType, nodeToBoundary(precedingNode), CurrentRow)
- case _ =>
- noParseRule("Window Frame", frame)
- }
- }.getOrElse(sys.error(s"If you see this, please file a bug report with your query."))
- }
-
- WindowSpecDefinition(partitionSpec, orderSpec, windowFrame)
- }
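
For example, an OVER clause such as

    OVER (PARTITION BY a ORDER BY b ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)

flows through this method to approximately:

    WindowSpecDefinition(
      partitionSpec = Seq(a),
      orderSpec = Seq(SortOrder(b, Ascending)),
      frameSpecification = SpecifiedWindowFrame(RowFrame, ValuePreceding(1), CurrentRow))
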
-
- protected def nodeToTransformation(
- node: ASTNode,
- child: LogicalPlan): Option[ScriptTransformation] = None
-
- val explode = "(?i)explode".r
- val jsonTuple = "(?i)json_tuple".r
- protected def nodeToGenerate(node: ASTNode, outer: Boolean, child: LogicalPlan): Generate = {
- val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = node
-
- val alias = cleanIdentifier(getClause("TOK_TABALIAS", clauses).children.head.text)
-
- val generator = clauses.head match {
- case Token("TOK_FUNCTION", Token(explode(), Nil) :: childNode :: Nil) =>
- Explode(nodeToExpr(childNode))
- case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) =>
- JsonTuple(children.map(nodeToExpr))
- case other =>
- nodeToGenerator(other)
- }
-
- val attributes = clauses.collect {
- case Token(a, Nil) => UnresolvedAttribute(cleanIdentifier(a.toLowerCase))
- }
-
- Generate(
- generator,
- join = true,
- outer = outer,
- Some(cleanIdentifier(alias.toLowerCase)),
- attributes,
- child)
- }
-
- protected def nodeToGenerator(node: ASTNode): Generator = noParseRule("Generator", node)
-
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala
index 21deb82..0b570c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.parser
import scala.language.implicitConversions
import scala.util.matching.Regex
import scala.util.parsing.combinator.syntactical.StandardTokenParsers
+import scala.util.parsing.input.CharArrayReader._
import org.apache.spark.sql.types._
@@ -117,3 +118,69 @@ private[sql] object DataTypeParser {
/** The exception thrown from the [[DataTypeParser]]. */
private[sql] class DataTypeException(message: String) extends Exception(message)
+
+class SqlLexical extends scala.util.parsing.combinator.lexical.StdLexical {
+ case class DecimalLit(chars: String) extends Token {
+ override def toString: String = chars
+ }
+
+ /* This is a workaround to support setting the keywords lazily. */
+ def initialize(keywords: Seq[String]): Unit = {
+ reserved.clear()
+ reserved ++= keywords
+ }
+
+ /* Normalize the keyword string. */
+ def normalizeKeyword(str: String): String = str.toLowerCase
+
+ delimiters += (
+ "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
+ ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>"
+ )
+
+ protected override def processIdent(name: String) = {
+ val token = normalizeKeyword(name)
+ if (reserved contains token) Keyword(token) else Identifier(name)
+ }
+
+ override lazy val token: Parser[Token] =
+ ( rep1(digit) ~ scientificNotation ^^ { case i ~ s => DecimalLit(i.mkString + s) }
+ | '.' ~> (rep1(digit) ~ scientificNotation) ^^
+ { case i ~ s => DecimalLit("0." + i.mkString + s) }
+ | rep1(digit) ~ ('.' ~> digit.*) ~ scientificNotation ^^
+ { case i1 ~ i2 ~ s => DecimalLit(i1.mkString + "." + i2.mkString + s) }
+ | digit.* ~ identChar ~ (identChar | digit).* ^^
+ { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) }
+ | rep1(digit) ~ ('.' ~> digit.*).? ^^ {
+ case i ~ None => NumericLit(i.mkString)
+ case i ~ Some(d) => DecimalLit(i.mkString + "." + d.mkString)
+ }
+ | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^
+ { case chars => StringLit(chars mkString "") }
+ | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^
+ { case chars => StringLit(chars mkString "") }
+ | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^
+ { case chars => Identifier(chars mkString "") }
+ | EofCh ^^^ EOF
+ | '\'' ~> failure("unclosed string literal")
+ | '"' ~> failure("unclosed string literal")
+ | delim
+ | failure("illegal character")
+ )
+
+ override def identChar: Parser[Elem] = letter | elem('_')
+
+ private lazy val scientificNotation: Parser[String] =
+ (elem('e') | elem('E')) ~> (elem('+') | elem('-')).? ~ rep1(digit) ^^ {
+ case s ~ rest => "e" + s.mkString + rest.mkString
+ }
+
+ override def whitespace: Parser[Any] =
+ ( whitespaceChar
+ | '/' ~ '*' ~ comment
+ | '/' ~ '/' ~ chrExcept(EofCh, '\n').*
+ | '#' ~ chrExcept(EofCh, '\n').*
+ | '-' ~ '-' ~ chrExcept(EofCh, '\n').*
+ | '/' ~ '*' ~ failure("unclosed comment")
+ ).*
+}
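
A usage sketch for the lexer above, assuming the standard Scanner API from
scala-parser-combinators (the object name, keyword list, and input are illustrative):

    object SqlLexicalDemo extends App {
      val lexical = new SqlLexical
      lexical.initialize(Seq("select", "from"))

      // Emits the tokens in order: the keywords SELECT and FROM, the decimal
      // literal 1.5e3, and the back-quoted identifier `my table`.
      var s = new lexical.Scanner("SELECT 1.5e3 FROM `my table`")
      while (!s.atEnd) {
        println(s.first)
        s = s.rest
      }
    }
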
http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index 51cfc50..d013252 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -16,91 +16,106 @@
*/
package org.apache.spark.sql.catalyst.parser
-import scala.annotation.tailrec
-
-import org.antlr.runtime._
-import org.antlr.runtime.tree.CommonTree
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.ParseCancellationException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.types.DataType
/**
- * The ParseDriver takes a SQL command and turns this into an AST.
- *
- * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver
+ * Base SQL parsing infrastructure.
*/
-object ParseDriver extends Logging {
- /** Create an LogicalPlan ASTNode from a SQL command. */
- def parsePlan(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
- parser.statement().getTree
- }
+abstract class AbstractSqlParser extends ParserInterface with Logging {
- /** Create an Expression ASTNode from a SQL command. */
- def parseExpression(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
- parser.singleNamedExpression().getTree
+ /** Creates/Resolves DataType for a given SQL string. */
+ def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
+ // TODO add this to the parser interface.
+ astBuilder.visitSingleDataType(parser.singleDataType())
}
- /** Create an TableIdentifier ASTNode from a SQL command. */
- def parseTableName(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
- parser.singleTableName().getTree
+ /** Creates Expression for a given SQL string. */
+ override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
+ astBuilder.visitSingleExpression(parser.singleExpression())
}
- private def parse(
- command: String,
- conf: ParserConf)(
- toTree: SparkSqlParser => CommonTree): ASTNode = {
- logInfo(s"Parsing command: $command")
+ /** Creates TableIdentifier for a given SQL string. */
+ override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser =>
+ astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier())
+ }
- // Setup error collection.
- val reporter = new ParseErrorReporter()
+ /** Creates LogicalPlan for a given SQL string. */
+ override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
+ astBuilder.visitSingleStatement(parser.singleStatement()) match {
+ case plan: LogicalPlan => plan
+ case _ => nativeCommand(sqlText)
+ }
+ }
- // Create lexer.
- val lexer = new SparkSqlLexer(new ANTLRNoCaseStringStream(command))
- val tokens = new TokenRewriteStream(lexer)
- lexer.configure(conf, reporter)
+ /** Get the builder (visitor) which converts a ParseTree into an AST. */
+ protected def astBuilder: AstBuilder
- // Create the parser.
- val parser = new SparkSqlParser(tokens)
- parser.configure(conf, reporter)
+ /** Create a native command, or fail when this is not supported. */
+ protected def nativeCommand(sqlText: String): LogicalPlan = {
+ val position = Origin(None, None)
+ throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position)
+ }
- try {
- val result = toTree(parser)
+ protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
+ logInfo(s"Parsing command: $command")
- // Check errors.
- reporter.checkForErrors()
+ val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command))
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(ParseErrorListener)
- // Return the AST node from the result.
- logInfo(s"Parse completed.")
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new SqlBaseParser(tokenStream)
+ parser.addParseListener(PostProcessor)
+ parser.removeErrorListeners()
+ parser.addErrorListener(ParseErrorListener)
- // Find the non null token tree in the result.
- @tailrec
- def nonNullToken(tree: CommonTree): CommonTree = {
- if (tree.token != null || tree.getChildCount == 0) tree
- else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree])
+ try {
+ try {
+ // first, try parsing with potentially faster SLL mode
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ toResult(parser)
}
- val tree = nonNullToken(result)
-
- // Make sure all boundaries are set.
- tree.setUnknownTokenBoundaries()
-
- // Construct the immutable AST.
- def createASTNode(tree: CommonTree): ASTNode = {
- val children = (0 until tree.getChildCount).map { i =>
- createASTNode(tree.getChild(i).asInstanceOf[CommonTree])
- }.toList
- ASTNode(tree.token, tree.getTokenStartIndex, tree.getTokenStopIndex, children, tokens)
+ catch {
+ case e: ParseCancellationException =>
+ // if we fail, parse with LL mode
+ tokenStream.reset() // rewind input stream
+ parser.reset()
+
+ // Try again.
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ toResult(parser)
}
- createASTNode(tree)
}
catch {
- case e: RecognitionException =>
- logInfo(s"Parse failed.")
- reporter.throwError(e)
+ case e: ParseException if e.command.isDefined =>
+ throw e
+ case e: ParseException =>
+ throw e.withCommand(command)
+ case e: AnalysisException =>
+ val position = Origin(e.line, e.startPosition)
+ throw new ParseException(Option(command), e.message, position, position)
}
}
}
/**
+ * Concrete SQL parser for Catalyst-only SQL statements.
+ */
+object CatalystSqlParser extends AbstractSqlParser {
+ val astBuilder = new AstBuilder
+}
+
+/**
* This string stream provides the lexer with upper case characters only. This greatly simplifies
* lexing the stream, while we can maintain the original command.
*
@@ -120,58 +135,104 @@ object ParseDriver extends Logging {
* have the ANTLRNoCaseStringStream implementation.
*/
-private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRStringStream(input) {
+private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) {
override def LA(i: Int): Int = {
val la = super.LA(i)
- if (la == 0 || la == CharStream.EOF) la
+ if (la == 0 || la == IntStream.EOF) la
else Character.toUpperCase(la)
}
}
/**
- * Utility used by the Parser and the Lexer for error collection and reporting.
+ * The ParseErrorListener converts parse errors into AnalysisExceptions.
*/
-private[parser] class ParseErrorReporter {
- val errors = scala.collection.mutable.Buffer.empty[ParseError]
-
- def report(br: BaseRecognizer, re: RecognitionException, tokenNames: Array[String]): Unit = {
- errors += ParseError(br, re, tokenNames)
+case object ParseErrorListener extends BaseErrorListener {
+ override def syntaxError(
+ recognizer: Recognizer[_, _],
+ offendingSymbol: scala.Any,
+ line: Int,
+ charPositionInLine: Int,
+ msg: String,
+ e: RecognitionException): Unit = {
+ val position = Origin(Some(line), Some(charPositionInLine))
+ throw new ParseException(None, msg, position, position)
}
+}
- def checkForErrors(): Unit = {
- if (errors.nonEmpty) {
- val first = errors.head
- val e = first.re
- throwError(e.line, e.charPositionInLine, first.buildMessage().toString, errors.tail)
- }
+/**
+ * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It
+ * contains fields and an extended error message that make reporting and diagnosing errors easier.
+ */
+class ParseException(
+ val command: Option[String],
+ message: String,
+ val start: Origin,
+ val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) {
+
+ def this(message: String, ctx: ParserRuleContext) = {
+ this(Option(ParserUtils.command(ctx)),
+ message,
+ ParserUtils.position(ctx.getStart),
+ ParserUtils.position(ctx.getStop))
}
- def throwError(e: RecognitionException): Nothing = {
- throwError(e.line, e.charPositionInLine, e.toString, errors)
+ override def getMessage: String = {
+ val builder = new StringBuilder
+ builder ++= "\n" ++= message
+ start match {
+ case Origin(Some(l), Some(p)) =>
+ builder ++= s"(line $l, pos $p)\n"
+ command.foreach { cmd =>
+ val (above, below) = cmd.split("\n").splitAt(l)
+ builder ++= "\n== SQL ==\n"
+ above.foreach(builder ++= _ += '\n')
+ builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n"
+ below.foreach(builder ++= _ += '\n')
+ }
+ case _ =>
+ command.foreach { cmd =>
+ builder ++= "\n== SQL ==\n" ++= cmd
+ }
+ }
+ builder.toString
}
- private def throwError(
- line: Int,
- startPosition: Int,
- msg: String,
- errors: Seq[ParseError]): Nothing = {
- val b = new StringBuilder
- b.append(msg).append("\n")
- errors.foreach(error => error.buildMessage(b).append("\n"))
- throw new AnalysisException(b.toString, Option(line), Option(startPosition))
+ def withCommand(cmd: String): ParseException = {
+ new ParseException(Option(cmd), message, start, stop)
}
}
/**
- * Error collected during the parsing process.
- *
- * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseError
+ * The post-processor validates and cleans up the parse tree during parsing.
*/
-private[parser] case class ParseError(
- br: BaseRecognizer,
- re: RecognitionException,
- tokenNames: Array[String]) {
- def buildMessage(s: StringBuilder = new StringBuilder): StringBuilder = {
- s.append(br.getErrorHeader(re)).append(" ").append(br.getErrorMessage(re, tokenNames))
+case object PostProcessor extends SqlBaseBaseListener {
+
+ /** Remove the back ticks from an Identifier. */
+ override def exitQuotedIdentifier(ctx: SqlBaseParser.QuotedIdentifierContext): Unit = {
+ replaceTokenByIdentifier(ctx, 1) { token =>
+ // Remove the double back ticks in the string.
+ token.setText(token.getText.replace("``", "`"))
+ token
+ }
+ }
+
+ /** Treat non-reserved keywords as Identifiers. */
+ override def exitNonReserved(ctx: SqlBaseParser.NonReservedContext): Unit = {
+ replaceTokenByIdentifier(ctx, 0)(identity)
+ }
+
+ private def replaceTokenByIdentifier(
+ ctx: ParserRuleContext,
+ stripMargins: Int)(
+ f: CommonToken => CommonToken = identity): Unit = {
+ val parent = ctx.getParent
+ parent.removeLastChild()
+ val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
+ parent.addChild(f(new CommonToken(
+ new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
+ SqlBaseParser.IDENTIFIER,
+ token.getChannel,
+ token.getStartIndex + stripMargins,
+ token.getStopIndex - stripMargins)))
}
}
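
A usage sketch for the new entry points, all exercised through the CatalystSqlParser
object defined above (each call runs the two-stage pipeline: SLL prediction mode first,
falling back to LL on a ParseCancellationException):

    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

    val plan  = CatalystSqlParser.parsePlan("SELECT a, COUNT(*) FROM t GROUP BY a")
    val expr  = CatalystSqlParser.parseExpression("a + 1")
    val table = CatalystSqlParser.parseTableIdentifier("db.tbl")
    val dt    = CatalystSqlParser.parseDataType("array<int>")
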
http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala
deleted file mode 100644
index ce449b1..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala
+++ /dev/null
@@ -1,26 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.catalyst.parser
-
-trait ParserConf {
- def supportQuotedId: Boolean
- def supportSQL11ReservedKeywords: Boolean
-}
-
-case class SimpleParserConf(
- supportQuotedId: Boolean = true,
- supportSQL11ReservedKeywords: Boolean = false) extends ParserConf
http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index 0c2e481..90b76dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -14,166 +14,105 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.sql.catalyst.parser
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.trees.CurrentOrigin
-import org.apache.spark.sql.types._
+import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token}
+import org.antlr.v4.runtime.misc.Interval
+import org.antlr.v4.runtime.tree.TerminalNode
+import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
/**
- * A collection of utility methods and patterns for parsing query texts.
+ * A collection of utility methods for use during the parsing process.
*/
-// TODO: merge with ParseUtils
object ParserUtils {
-
- object Token {
- // Match on (text, children)
- def unapply(node: ASTNode): Some[(String, List[ASTNode])] = {
- CurrentOrigin.setPosition(node.line, node.positionInLine)
- node.pattern
- }
+ /** Get the command which created the token. */
+ def command(ctx: ParserRuleContext): String = {
+ command(ctx.getStart.getInputStream)
}
- private val escapedIdentifier = "`(.+)`".r
- private val doubleQuotedString = "\"([^\"]+)\"".r
- private val singleQuotedString = "'([^']+)'".r
-
- // Token patterns
- val COUNT = "(?i)COUNT".r
- val SUM = "(?i)SUM".r
- val AND = "(?i)AND".r
- val OR = "(?i)OR".r
- val NOT = "(?i)NOT".r
- val TRUE = "(?i)TRUE".r
- val FALSE = "(?i)FALSE".r
- val LIKE = "(?i)LIKE".r
- val RLIKE = "(?i)RLIKE".r
- val REGEXP = "(?i)REGEXP".r
- val IN = "(?i)IN".r
- val DIV = "(?i)DIV".r
- val BETWEEN = "(?i)BETWEEN".r
- val WHEN = "(?i)WHEN".r
- val CASE = "(?i)CASE".r
- val INTEGRAL = "[+-]?\\d+".r
- val DECIMAL = "[+-]?((\\d+(\\.\\d*)?)|(\\.\\d+))".r
-
- /**
- * Strip quotes, if any, from the string.
- */
- def unquoteString(str: String): String = {
- str match {
- case singleQuotedString(s) => s
- case doubleQuotedString(s) => s
- case other => other
- }
+ /** Get the command which created the token. */
+ def command(stream: CharStream): String = {
+ stream.getText(Interval.of(0, stream.size()))
}
- /**
- * Strip backticks, if any, from the string.
- */
- def cleanIdentifier(ident: String): String = {
- ident match {
- case escapedIdentifier(i) => i
- case plainIdent => plainIdent
- }
+ /** Get the code that creates the given node. */
+ def source(ctx: ParserRuleContext): String = {
+ val stream = ctx.getStart.getInputStream
+ stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex))
}
- def getClauses(
- clauseNames: Seq[String],
- nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = {
- var remainingNodes = nodeList
- val clauses = clauseNames.map { clauseName =>
- val (matches, nonMatches) = remainingNodes.partition(_.text.toUpperCase == clauseName)
- remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil)
- matches.headOption
- }
+ /** Get all the text which comes after the given rule. */
+ def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop)
- if (remainingNodes.nonEmpty) {
- sys.error(
- s"""Unhandled clauses: ${remainingNodes.map(_.treeString).mkString("\n")}.
- |You are likely trying to use an unsupported Hive feature."""".stripMargin)
- }
- clauses
+ /** Get all the text which comes after the given token. */
+ def remainder(token: Token): String = {
+ val stream = token.getInputStream
+ val interval = Interval.of(token.getStopIndex + 1, stream.size())
+ stream.getText(interval)
}
- def getClause(clauseName: String, nodeList: Seq[ASTNode]): ASTNode = {
- getClauseOption(clauseName, nodeList).getOrElse(sys.error(
- s"Expected clause $clauseName missing from ${nodeList.map(_.treeString).mkString("\n")}"))
- }
+ /** Convert a string token into a string. */
+ def string(token: Token): String = unescapeSQLString(token.getText)
- def getClauseOption(clauseName: String, nodeList: Seq[ASTNode]): Option[ASTNode] = {
- nodeList.filter { case ast: ASTNode => ast.text == clauseName } match {
- case Seq(oneMatch) => Some(oneMatch)
- case Seq() => None
- case _ => sys.error(s"Found multiple instances of clause $clauseName")
- }
- }
+ /** Convert a string node into a string. */
+ def string(node: TerminalNode): String = unescapeSQLString(node.getText)
- def extractTableIdent(tableNameParts: ASTNode): TableIdentifier = {
- tableNameParts.children.map {
- case Token(part, Nil) => cleanIdentifier(part)
- } match {
- case Seq(tableOnly) => TableIdentifier(tableOnly)
- case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName))
- case other => sys.error("Hive only supports tables names like 'tableName' " +
- s"or 'databaseName.tableName', found '$other'")
- }
+ /** Get the origin (line and position) of the token. */
+ def position(token: Token): Origin = {
+ Origin(Option(token.getLine), Option(token.getCharPositionInLine))
}
- def nodeToDataType(node: ASTNode): DataType = node match {
- case Token("TOK_DECIMAL", precision :: scale :: Nil) =>
- DecimalType(precision.text.toInt, scale.text.toInt)
- case Token("TOK_DECIMAL", precision :: Nil) =>
- DecimalType(precision.text.toInt, 0)
- case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT
- case Token("TOK_BIGINT", Nil) => LongType
- case Token("TOK_INT", Nil) => IntegerType
- case Token("TOK_TINYINT", Nil) => ByteType
- case Token("TOK_SMALLINT", Nil) => ShortType
- case Token("TOK_BOOLEAN", Nil) => BooleanType
- case Token("TOK_STRING", Nil) => StringType
- case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType
- case Token("TOK_CHAR", Token(_, Nil) :: Nil) => StringType
- case Token("TOK_FLOAT", Nil) => FloatType
- case Token("TOK_DOUBLE", Nil) => DoubleType
- case Token("TOK_DATE", Nil) => DateType
- case Token("TOK_TIMESTAMP", Nil) => TimestampType
- case Token("TOK_BINARY", Nil) => BinaryType
- case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType))
- case Token("TOK_STRUCT", Token("TOK_TABCOLLIST", fields) :: Nil) =>
- StructType(fields.map(nodeToStructField))
- case Token("TOK_MAP", keyType :: valueType :: Nil) =>
- MapType(nodeToDataType(keyType), nodeToDataType(valueType))
- case _ =>
- noParseRule("DataType", node)
- }
-
- def nodeToStructField(node: ASTNode): StructField = node match {
- case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) =>
- StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true)
- case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: comment :: Nil) =>
- val meta = new MetadataBuilder().putString("comment", unquoteString(comment.text)).build()
- StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true, meta)
- case _ =>
- noParseRule("StructField", node)
+ /** Assert that a condition holds; if it does not, throw a parse exception. */
+ def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = {
+ if (!f) {
+ throw new ParseException(message, ctx)
+ }
}
/**
- * Throw an exception because we cannot parse the given node for some unexpected reason.
+ * Register the origin of the context. Any TreeNode created in the closure will be assigned the
+ * registered origin. This method restores the previously set origin after completion of the
+ * closure.
*/
- def parseFailed(msg: String, node: ASTNode): Nothing = {
- throw new AnalysisException(s"$msg: '${node.source}")
+ def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = {
+ val current = CurrentOrigin.get
+ CurrentOrigin.set(position(ctx.getStart))
+ try {
+ f
+ } finally {
+ CurrentOrigin.set(current)
+ }
}
- /**
- * Throw an exception because there are no rules to parse the node.
- */
- def noParseRule(msg: String, node: ASTNode): Nothing = {
- throw new NotImplementedError(
- s"[$msg]: No parse rules for ASTNode type: ${node.tokenType}, tree:\n${node.treeString}")
- }
+ /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */
+ implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal {
+ /**
+ * Create a plan using the block of code when the given context exists. Otherwise return the
+ * original plan.
+ */
+ def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = {
+ if (ctx != null) {
+ f
+ } else {
+ plan
+ }
+ }
+ /**
+ * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the
+ * passed function. The original plan is returned when the context does not exist.
+ */
+ def optionalMap[C <: ParserRuleContext](
+ ctx: C)(
+ f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
+ if (ctx != null) {
+ f(ctx, plan)
+ } else {
+ plan
+ }
+ }
+ }
}
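
A minimal illustration of the EnhancedLogicalPlan helpers above (the values are
illustrative; ParserRuleContext's no-argument constructor stands in for a real
parser context):

    import org.antlr.v4.runtime.ParserRuleContext
    import org.apache.spark.sql.catalyst.TableIdentifier
    import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.parser.ParserUtils._
    import org.apache.spark.sql.catalyst.plans.logical.{Limit, LogicalPlan}

    val base: LogicalPlan = UnresolvedRelation(TableIdentifier("t"), None)

    // A null context leaves the plan untouched, so this returns base itself.
    base.optionalMap(null: ParserRuleContext)((_, p) => Limit(Literal(10), p))

    // A present context applies the mapping function, wrapping base in a limit.
    base.optionalMap(new ParserRuleContext())((_, p) => Limit(Literal(10), p))
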