Posted to commits@spark.apache.org by rx...@apache.org on 2014/03/21 02:04:46 UTC
[3/9] SPARK-1251 Support for optimizing and executing structured queries
http://git-wip-us.apache.org/repos/asf/spark/blob/9aadcffa/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
new file mode 100644
index 0000000..6402925
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.test._
+
+/* Implicits */
+import TestSQLContext._
+
+object TestData {
+ case class TestData(key: Int, value: String)
+ val testData: SchemaRDD = TestSQLContext.sparkContext.parallelize(
+ (1 to 100).map(i => TestData(i, i.toString)))
+ testData.registerAsTable("testData")
+
+ case class TestData2(a: Int, b: Int)
+ val testData2: SchemaRDD =
+ TestSQLContext.sparkContext.parallelize(
+ TestData2(1, 1) ::
+ TestData2(1, 2) ::
+ TestData2(2, 1) ::
+ TestData2(2, 2) ::
+ TestData2(3, 1) ::
+ TestData2(3, 2) :: Nil
+ )
+ testData2.registerAsTable("testData2")
+
+ // TODO: There is no way to express null primitives as case classes currently...
+ val testData3 =
+ logical.LocalRelation('a.int, 'b.int).loadData(
+ (1, null) ::
+ (2, 2) :: Nil
+ )
+
+ case class UpperCaseData(N: Int, L: String)
+ val upperCaseData =
+ TestSQLContext.sparkContext.parallelize(
+ UpperCaseData(1, "A") ::
+ UpperCaseData(2, "B") ::
+ UpperCaseData(3, "C") ::
+ UpperCaseData(4, "D") ::
+ UpperCaseData(5, "E") ::
+ UpperCaseData(6, "F") :: Nil
+ )
+ upperCaseData.registerAsTable("upperCaseData")
+
+ case class LowerCaseData(n: Int, l: String)
+ val lowerCaseData =
+ TestSQLContext.sparkContext.parallelize(
+ LowerCaseData(1, "a") ::
+ LowerCaseData(2, "b") ::
+ LowerCaseData(3, "c") ::
+ LowerCaseData(4, "d") :: Nil
+ )
+ lowerCaseData.registerAsTable("lowerCaseData")
+}
\ No newline at end of file
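A quick sketch, not part of this patch, of how the tables registered above are intended to be used: with the implicits from TestSQLContext imported, each registered name can be queried through SQLContext's sql entry point (the particular query below is illustrative only).

import org.apache.spark.sql.test.TestSQLContext._

// "testData" was registered above with keys 1..100, so this should return exactly ten rows.
val firstTen = sql("SELECT key, value FROM testData WHERE key <= 10").collect()
assert(firstTen.size == 10)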
http://git-wip-us.apache.org/repos/asf/spark/blob/9aadcffa/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala
new file mode 100644
index 0000000..08265b7
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package execution
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.test._
+
+
+import TestSQLContext._
+
+/**
+ * This is an example TGF that uses UnresolvedAttributes 'name and 'age to access specific columns
+ * from the input data. These will be replaced during analysis with specific AttributeReferences
+ * and then bound to specific ordinals during query planning. While TGFs could also access specific
+ * columns using hand-coded ordinals, doing so violates data independence.
+ *
+ * Note: this is only a rough example of how TGFs can be expressed; the final version will likely
+ * involve a lot more sugar for cleaner use in Scala/Java/etc.
+ */
+case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generator {
+ def children = input
+ protected def makeOutput() = 'nameAndAge.string :: Nil
+
+ val Seq(nameAttr, ageAttr) = input
+
+ override def apply(input: Row): TraversableOnce[Row] = {
+ val name = nameAttr.apply(input)
+ val age = ageAttr.apply(input).asInstanceOf[Int]
+
+ Iterator(
+ new GenericRow(Array[Any](s"$name is $age years old")),
+ new GenericRow(Array[Any](s"Next year, $name will be ${age + 1} years old")))
+ }
+}
+
+class TgfSuite extends QueryTest {
+ val inputData =
+ logical.LocalRelation('name.string, 'age.int).loadData(
+ ("michael", 29) :: Nil
+ )
+
+ test("simple tgf example") {
+ checkAnswer(
+ inputData.generate(ExampleTGF()),
+ Seq(
+ "michael is 29 years old" :: Nil,
+ "Next year, michael will be 30 years old" :: Nil))
+ }
+}
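To make the Generator contract described above ExampleTGF more concrete, here is a second, purely hypothetical TGF sketched along the same lines (it assumes the same imports as TgfSuite.scala; 'sentence and 'word are invented attribute names that the analyzer would resolve and bind just as described in the comment above):

// Splits a string-typed 'sentence column into one output row per word.
case class WordsTGF(input: Seq[Attribute] = Seq('sentence)) extends Generator {
  def children = input
  protected def makeOutput() = 'word.string :: Nil

  override def apply(row: Row): TraversableOnce[Row] = {
    val Seq(sentenceAttr) = input
    sentenceAttr.apply(row).toString.split(" ").iterator
      .map(word => new GenericRow(Array[Any](word)))
  }
}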
http://git-wip-us.apache.org/repos/asf/spark/blob/9aadcffa/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
new file mode 100644
index 0000000..8b2ccb5
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.parquet
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.test.TestSQLContext
+
+import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import parquet.schema.MessageTypeParser
+import parquet.hadoop.ParquetFileWriter
+import parquet.hadoop.util.ContextUtil
+
+class ParquetQuerySuite extends FunSuite with BeforeAndAfterAll {
+ override def beforeAll() {
+ ParquetTestData.writeFile
+ }
+
+ override def afterAll() {
+ ParquetTestData.testFile.delete()
+ }
+
+ test("Import of simple Parquet file") {
+ val result = getRDD(ParquetTestData.testData).collect()
+ assert(result.size === 15)
+ result.zipWithIndex.foreach {
+ case (row, index) => {
+ val checkBoolean =
+ if (index % 3 == 0)
+ row(0) == true
+ else
+ row(0) == false
+ assert(checkBoolean === true, s"boolean field value in line $index did not match")
+ if (index % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match")
+ assert(row(2) === "abc", s"string field value in line $index did not match")
+ assert(row(3) === (index.toLong << 33), s"long value in line $index did not match")
+ assert(row(4) === 2.5F, s"float field value in line $index did not match")
+ assert(row(5) === 4.5D, s"double field value in line $index did not match")
+ }
+ }
+ }
+
+ test("Projection of simple Parquet file") {
+ val scanner = new ParquetTableScan(
+ ParquetTestData.testData.output,
+ ParquetTestData.testData,
+ None)(TestSQLContext.sparkContext)
+ val projected = scanner.pruneColumns(ParquetTypesConverter
+ .convertToAttributes(MessageTypeParser
+ .parseMessageType(ParquetTestData.subTestSchema)))
+ assert(projected.output.size === 2)
+ val result = projected
+ .execute()
+ .map(_.copy())
+ .collect()
+ result.zipWithIndex.foreach {
+ case (row, index) => {
+ if (index % 3 == 0)
+ assert(row(0) === true, s"boolean field value in line $index did not match (every third row)")
+ else
+ assert(row(0) === false, s"boolean field value in line $index did not match")
+ assert(row(1) === (index.toLong << 33), s"long field value in line $index did not match")
+ assert(row.size === 2, s"number of columns in projection in line $index is incorrect")
+ }
+ }
+ }
+
+ test("Writing metadata from scratch for table CREATE") {
+ val job = new Job()
+ val path = new Path(getTempFilePath("testtable").getCanonicalFile.toURI.toString)
+ val fs: FileSystem = FileSystem.getLocal(ContextUtil.getConfiguration(job))
+ ParquetTypesConverter.writeMetaData(
+ ParquetTestData.testData.output,
+ path,
+ TestSQLContext.sparkContext.hadoopConfiguration)
+ assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE)))
+ val metaData = ParquetTypesConverter.readMetaData(path)
+ assert(metaData != null)
+ ParquetTestData
+ .testData
+ .parquetSchema
+ .checkContains(metaData.getFileMetaData.getSchema) // throws exception if incompatible
+ metaData
+ .getFileMetaData
+ .getSchema
+ .checkContains(ParquetTestData.testData.parquetSchema) // throws exception if incompatible
+ fs.delete(path, true)
+ }
+
+ /**
+ * Computes the given [[ParquetRelation]] and returns its RDD.
+ *
+ * @param parquetRelation The Parquet relation.
+ * @return An RDD of Rows.
+ */
+ private def getRDD(parquetRelation: ParquetRelation): RDD[Row] = {
+ val scanner = new ParquetTableScan(
+ parquetRelation.output,
+ parquetRelation,
+ None)(TestSQLContext.sparkContext)
+ scanner
+ .execute
+ .map(_.copy())
+ }
+}
+
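One pattern worth noting from the tests above, sketched here with a hypothetical relation: ParquetTableScan, like other operators, may reuse mutable Row objects across its iterator, so rows are copied before being collected.

// Hedged sketch; `relation` stands in for any ParquetRelation such as ParquetTestData.testData.
val scan = new ParquetTableScan(relation.output, relation, None)(TestSQLContext.sparkContext)
val rows = scan.execute().map(_.copy()).collect() // copy each Row before materializing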
http://git-wip-us.apache.org/repos/asf/spark/blob/9aadcffa/sql/hive/pom.xml
----------------------------------------------------------------------
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
new file mode 100644
index 0000000..7b5ea98
--- /dev/null
+++ b/sql/hive/pom.xml
@@ -0,0 +1,81 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent</artifactId>
+ <version>1.0.0-SNAPSHOT</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-hive_2.10</artifactId>
+ <packaging>jar</packaging>
+ <name>Spark Project Hive</name>
+ <url>http://spark.apache.org/</url>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hive</groupId>
+ <artifactId>hive-metastore</artifactId>
+ <version>${hive.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hive</groupId>
+ <artifactId>hive-exec</artifactId>
+ <version>${hive.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hive</groupId>
+ <artifactId>hive-serde</artifactId>
+ <version>${hive.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.scalacheck</groupId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+ <plugins>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+</project>
http://git-wip-us.apache.org/repos/asf/spark/blob/9aadcffa/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala b/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala
new file mode 100644
index 0000000..08d390e
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapred
+
+import java.io.IOException
+import java.text.NumberFormat
+import java.util.Date
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.Writable
+
+import org.apache.spark.Logging
+import org.apache.spark.SerializableWritable
+
+import org.apache.hadoop.hive.ql.exec.{Utilities, FileSinkOperator}
+import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat}
+import org.apache.hadoop.hive.ql.plan.FileSinkDesc
+
+/**
+ * Internal helper class that saves an RDD using a Hive OutputFormat.
+ * It is based on [[SparkHadoopWriter]].
+ */
+protected[apache]
+class SparkHiveHadoopWriter(
+ @transient jobConf: JobConf,
+ fileSinkConf: FileSinkDesc)
+ extends Logging
+ with SparkHadoopMapRedUtil
+ with Serializable {
+
+ private val now = new Date()
+ private val conf = new SerializableWritable(jobConf)
+
+ private var jobID = 0
+ private var splitID = 0
+ private var attemptID = 0
+ private var jID: SerializableWritable[JobID] = null
+ private var taID: SerializableWritable[TaskAttemptID] = null
+
+ @transient private var writer: FileSinkOperator.RecordWriter = null
+ @transient private var format: HiveOutputFormat[AnyRef, Writable] = null
+ @transient private var committer: OutputCommitter = null
+ @transient private var jobContext: JobContext = null
+ @transient private var taskContext: TaskAttemptContext = null
+
+ def preSetup() {
+ setIDs(0, 0, 0)
+ setConfParams()
+
+ val jCtxt = getJobContext()
+ getOutputCommitter().setupJob(jCtxt)
+ }
+
+
+ def setup(jobid: Int, splitid: Int, attemptid: Int) {
+ setIDs(jobid, splitid, attemptid)
+ setConfParams()
+ }
+
+ def open() {
+ val numfmt = NumberFormat.getInstance()
+ numfmt.setMinimumIntegerDigits(5)
+ numfmt.setGroupingUsed(false)
+
+ val extension = Utilities.getFileExtension(
+ conf.value,
+ fileSinkConf.getCompressed,
+ getOutputFormat())
+
+ val outputName = "part-" + numfmt.format(splitID) + extension
+ val path = FileOutputFormat.getTaskOutputPath(conf.value, outputName)
+
+ getOutputCommitter().setupTask(getTaskContext())
+ writer = HiveFileFormatUtils.getHiveRecordWriter(
+ conf.value,
+ fileSinkConf.getTableInfo,
+ conf.value.getOutputValueClass.asInstanceOf[Class[Writable]],
+ fileSinkConf,
+ path,
+ null)
+ }
+
+ def write(value: Writable) {
+ if (writer != null) {
+ writer.write(value)
+ } else {
+ throw new IOException("Writer is null, open() has not been called")
+ }
+ }
+
+ def close() {
+ // The boolean value passed into close does not seem to matter.
+ writer.close(false)
+ }
+
+ def commit() {
+ val taCtxt = getTaskContext()
+ val cmtr = getOutputCommitter()
+ if (cmtr.needsTaskCommit(taCtxt)) {
+ try {
+ cmtr.commitTask(taCtxt)
+ logInfo (taID + ": Committed")
+ } catch {
+ case e: IOException => {
+ logError("Error committing the output of task: " + taID.value, e)
+ cmtr.abortTask(taCtxt)
+ throw e
+ }
+ }
+ } else {
+ logWarning ("No need to commit output of task: " + taID.value)
+ }
+ }
+
+ def commitJob() {
+ // always ? Or if cmtr.needsTaskCommit ?
+ val cmtr = getOutputCommitter()
+ cmtr.commitJob(getJobContext())
+ }
+
+ // ********* Private Functions *********
+
+ private def getOutputFormat(): HiveOutputFormat[AnyRef,Writable] = {
+ if (format == null) {
+ format = conf.value.getOutputFormat()
+ .asInstanceOf[HiveOutputFormat[AnyRef,Writable]]
+ }
+ format
+ }
+
+ private def getOutputCommitter(): OutputCommitter = {
+ if (committer == null) {
+ committer = conf.value.getOutputCommitter
+ }
+ committer
+ }
+
+ private def getJobContext(): JobContext = {
+ if (jobContext == null) {
+ jobContext = newJobContext(conf.value, jID.value)
+ }
+ jobContext
+ }
+
+ private def getTaskContext(): TaskAttemptContext = {
+ if (taskContext == null) {
+ taskContext = newTaskAttemptContext(conf.value, taID.value)
+ }
+ taskContext
+ }
+
+ private def setIDs(jobid: Int, splitid: Int, attemptid: Int) {
+ jobID = jobid
+ splitID = splitid
+ attemptID = attemptid
+
+ jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid))
+ taID = new SerializableWritable[TaskAttemptID](
+ new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID))
+ }
+
+ private def setConfParams() {
+ conf.value.set("mapred.job.id", jID.value.toString)
+ conf.value.set("mapred.tip.id", taID.value.getTaskID.toString)
+ conf.value.set("mapred.task.id", taID.value.toString)
+ conf.value.setBoolean("mapred.task.is.map", true)
+ conf.value.setInt("mapred.task.partition", splitID)
+ }
+}
+
+object SparkHiveHadoopWriter {
+ def createPathFromString(path: String, conf: JobConf): Path = {
+ if (path == null) {
+ throw new IllegalArgumentException("Output path is null")
+ }
+ val outputPath = new Path(path)
+ val fs = outputPath.getFileSystem(conf)
+ if (outputPath == null || fs == null) {
+ throw new IllegalArgumentException("Incorrectly formatted output path")
+ }
+ outputPath.makeQualified(fs)
+ }
+}
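The writer above is driven through an explicit lifecycle rather than a single save call. A rough sketch of the intended call order follows; all names in it (jobConf, fileSinkConf, jobId, splitId, attemptId, rows, serialize) are placeholders, and error handling plus the surrounding RDD plumbing are omitted.

val writer = new SparkHiveHadoopWriter(jobConf, fileSinkConf)
writer.preSetup()                                  // driver side, once: OutputCommitter.setupJob
writer.setup(jobId, splitId, attemptId)            // per task: record IDs and set mapred.* conf params
writer.open()                                      // per task: create the Hive RecordWriter for part-NNNNN
rows.foreach(row => writer.write(serialize(row)))  // serialize: hypothetical Row => Writable conversion
writer.close()
writer.commit()                                    // per task: commit through the OutputCommitter
writer.commitJob()                                 // driver side, once all tasks have finished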
http://git-wip-us.apache.org/repos/asf/spark/blob/9aadcffa/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
new file mode 100644
index 0000000..4aad876
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import java.io.{PrintStream, InputStreamReader, BufferedReader, File}
+import java.util.{ArrayList => JArrayList}
+import scala.language.implicitConversions
+
+import org.apache.spark.SparkContext
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.ql.processors.{CommandProcessorResponse, CommandProcessorFactory}
+import org.apache.hadoop.hive.ql.processors.CommandProcessor
+import org.apache.hadoop.hive.ql.Driver
+import org.apache.spark.rdd.RDD
+
+import catalyst.analysis.{Analyzer, OverrideCatalog}
+import catalyst.expressions.GenericRow
+import catalyst.plans.logical.{BaseRelation, LogicalPlan, NativeCommand, ExplainCommand}
+import catalyst.types._
+
+import org.apache.spark.sql.execution._
+
+import scala.collection.JavaConversions._
+
+/**
+ * Starts up an instance of Hive where metadata is stored locally. An in-process metastore is
+ * created with data stored in ./metastore. Warehouse data is stored in ./warehouse.
+ */
+class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) {
+
+ lazy val metastorePath = new File("metastore").getCanonicalPath
+ lazy val warehousePath: String = new File("warehouse").getCanonicalPath
+
+ /** Sets up the system initially or after a RESET command */
+ protected def configure() {
+ // TODO: refactor this so we can work with other databases.
+ runSqlHive(
+ s"set javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$metastorePath;create=true")
+ runSqlHive("set hive.metastore.warehouse.dir=" + warehousePath)
+ }
+
+ configure() // Must be called before initializing the catalog below.
+}
+
+/**
+ * An instance of the Spark SQL execution engine that integrates with data stored in Hive.
+ * Configuration for Hive is read from hive-site.xml on the classpath.
+ */
+class HiveContext(sc: SparkContext) extends SQLContext(sc) {
+ self =>
+
+ override def parseSql(sql: String): LogicalPlan = HiveQl.parseSql(sql)
+ override def executePlan(plan: LogicalPlan): this.QueryExecution =
+ new this.QueryExecution { val logical = plan }
+
+ // Circular buffer to hold what Hive prints to stdout and stderr. Only printed when failures occur.
+ @transient
+ protected val outputBuffer = new java.io.OutputStream {
+ var pos: Int = 0
+ var buffer = new Array[Int](10240)
+ def write(i: Int): Unit = {
+ buffer(pos) = i
+ pos = (pos + 1) % buffer.size
+ }
+
+ override def toString = {
+ val (end, start) = buffer.splitAt(pos)
+ val input = new java.io.InputStream {
+ val iterator = (start ++ end).iterator
+
+ def read(): Int = if (iterator.hasNext) iterator.next else -1
+ }
+ val reader = new BufferedReader(new InputStreamReader(input))
+ val stringBuilder = new StringBuilder
+ var line = reader.readLine()
+ while(line != null) {
+ stringBuilder.append(line)
+ stringBuilder.append("\n")
+ line = reader.readLine()
+ }
+ stringBuilder.toString()
+ }
+ }
+
+ @transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState])
+ @transient protected[hive] lazy val sessionState = new SessionState(hiveconf)
+
+ sessionState.err = new PrintStream(outputBuffer, true, "UTF-8")
+ sessionState.out = new PrintStream(outputBuffer, true, "UTF-8")
+
+ /* A catalyst metadata catalog that points to the Hive Metastore. */
+ @transient
+ override lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog
+
+ /* An analyzer that uses the Hive metastore. */
+ @transient
+ override lazy val analyzer = new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false)
+
+ def tables: Seq[BaseRelation] = {
+ // TODO: Move this functionality to Catalog. Make client protected.
+ val allTables = catalog.client.getAllTables("default")
+ allTables.map(catalog.lookupRelation(None, _, None)).collect { case b: BaseRelation => b }
+ }
+
+ /**
+ * Runs the specified SQL query using Hive.
+ */
+ protected def runSqlHive(sql: String): Seq[String] = {
+ val maxResults = 100000
+ val results = runHive(sql, maxResults)
+ // It is very confusing when you only get back some of the results...
+ if (results.size == maxResults) sys.error("RESULTS POSSIBLY TRUNCATED")
+ results
+ }
+
+ // TODO: Move this.
+
+ SessionState.start(sessionState)
+
+ /**
+ * Execute the command using Hive and return the results as a sequence. Each element
+ * in the sequence is one row.
+ */
+ protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = {
+ try {
+ val cmd_trimmed: String = cmd.trim()
+ val tokens: Array[String] = cmd_trimmed.split("\\s+")
+ val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim()
+ val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hiveconf)
+
+ SessionState.start(sessionState)
+
+ if (proc.isInstanceOf[Driver]) {
+ val driver: Driver = proc.asInstanceOf[Driver]
+ driver.init()
+
+ val results = new JArrayList[String]
+ val response: CommandProcessorResponse = driver.run(cmd)
+ // Throw an exception if there is an error in query processing.
+ if (response.getResponseCode != 0) {
+ driver.destroy()
+ throw new QueryExecutionException(response.getErrorMessage)
+ }
+ driver.setMaxRows(maxRows)
+ driver.getResults(results)
+ driver.destroy()
+ results
+ } else {
+ sessionState.out.println(tokens(0) + " " + cmd_1)
+ Seq(proc.run(cmd_1).getResponseCode.toString)
+ }
+ } catch {
+ case e: Exception =>
+ logger.error(
+ s"""
+ |======================
+ |HIVE FAILURE OUTPUT
+ |======================
+ |${outputBuffer.toString}
+ |======================
+ |END HIVE FAILURE OUTPUT
+ |======================
+ """.stripMargin)
+ throw e
+ }
+ }
+
+ @transient
+ val hivePlanner = new SparkPlanner with HiveStrategies {
+ val hiveContext = self
+
+ override val strategies: Seq[Strategy] = Seq(
+ TopK,
+ ColumnPrunings,
+ PartitionPrunings,
+ HiveTableScans,
+ DataSinks,
+ Scripts,
+ PartialAggregation,
+ SparkEquiInnerJoin,
+ BasicOperators,
+ CartesianProduct,
+ BroadcastNestedLoopJoin
+ )
+ }
+
+ @transient
+ override val planner = hivePlanner
+
+ @transient
+ protected lazy val emptyResult =
+ sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)
+
+ /** Extends QueryExecution with hive specific features. */
+ abstract class QueryExecution extends super.QueryExecution {
+ // TODO: Create mixin for the analyzer instead of overriding things here.
+ override lazy val optimizedPlan =
+ optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))
+
+ // TODO: We are losing the schema here.
+ override lazy val toRdd: RDD[Row] =
+ analyzed match {
+ case NativeCommand(cmd) =>
+ val output = runSqlHive(cmd)
+
+ if (output.size == 0) {
+ emptyResult
+ } else {
+ val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]]))
+ sparkContext.parallelize(asRows, 1)
+ }
+ case _ =>
+ executedPlan.execute.map(_.copy())
+ }
+
+ protected val primitiveTypes =
+ Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
+ ShortType, DecimalType)
+
+ protected def toHiveString(a: (Any, DataType)): String = a match {
+ case (struct: Row, StructType(fields)) =>
+ struct.zip(fields).map {
+ case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
+ }.mkString("{", ",", "}")
+ case (seq: Seq[_], ArrayType(typ))=>
+ seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
+ case (map: Map[_,_], MapType(kType, vType)) =>
+ map.map {
+ case (key, value) =>
+ toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
+ }.toSeq.sorted.mkString("{", ",", "}")
+ case (null, _) => "NULL"
+ case (other, tpe) if primitiveTypes contains tpe => other.toString
+ }
+
+ /** Hive outputs fields of structs slightly differently than top level attributes. */
+ protected def toHiveStructString(a: (Any, DataType)): String = a match {
+ case (struct: Row, StructType(fields)) =>
+ struct.zip(fields).map {
+ case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}"""
+ }.mkString("{", ",", "}")
+ case (seq: Seq[_], ArrayType(typ))=>
+ seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]")
+ case (map: Map[_,_], MapType(kType, vType)) =>
+ map.map {
+ case (key, value) =>
+ toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
+ }.toSeq.sorted.mkString("{", ",", "}")
+ case (null, _) => "null"
+ case (s: String, StringType) => "\"" + s + "\""
+ case (other, tpe) if primitiveTypes contains tpe => other.toString
+ }
+
+ /**
+ * Returns the result as a hive compatible sequence of strings. For native commands, the
+ * execution is simply passed back to Hive.
+ */
+ def stringResult(): Seq[String] = analyzed match {
+ case NativeCommand(cmd) => runSqlHive(cmd)
+ case ExplainCommand(plan) => new QueryExecution { val logical = plan }.toString.split("\n")
+ case query =>
+ val result: Seq[Seq[Any]] = toRdd.collect().toSeq
+ // We need the types so we can output struct field names
+ val types = analyzed.output.map(_.dataType)
+ // Reformat to match hive tab delimited output.
+ val asString = result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")).toSeq
+ asString
+ }
+ }
+}
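A hedged usage sketch, not part of the change: LocalHiveContext keeps its Derby metastore and warehouse under the current working directory, which makes it convenient for tests. Queries go through SQLContext's sql entry point, which this class routes to HiveQl.parseSql; the table and query below are illustrative only.

import org.apache.spark.SparkContext
import org.apache.spark.sql.hive.LocalHiveContext

val sc = new SparkContext("local", "HiveContextExample")
val hiveContext = new LocalHiveContext(sc)
import hiveContext._

sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") // runs natively via runSqlHive
sql("SELECT key, value FROM src LIMIT 10").collect().foreach(println)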
http://git-wip-us.apache.org/repos/asf/spark/blob/9aadcffa/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
new file mode 100644
index 0000000..e4d5072
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import scala.util.parsing.combinator.RegexParsers
+
+import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo}
+import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition}
+import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
+import org.apache.hadoop.hive.ql.plan.TableDesc
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.serde2.Deserializer
+
+import catalyst.analysis.Catalog
+import catalyst.expressions._
+import catalyst.plans.logical
+import catalyst.plans.logical._
+import catalyst.rules._
+import catalyst.types._
+
+import scala.collection.JavaConversions._
+
+class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging {
+ import HiveMetastoreTypes._
+
+ val client = Hive.get(hive.hiveconf)
+
+ def lookupRelation(
+ db: Option[String],
+ tableName: String,
+ alias: Option[String]): LogicalPlan = {
+ val databaseName = db.getOrElse(hive.sessionState.getCurrentDatabase())
+ val table = client.getTable(databaseName, tableName)
+ val partitions: Seq[Partition] =
+ if (table.isPartitioned) {
+ client.getPartitions(table)
+ } else {
+ Nil
+ }
+
+ // Since HiveQL is case insensitive for table names we make them all lowercase.
+ MetastoreRelation(
+ databaseName.toLowerCase,
+ tableName.toLowerCase,
+ alias)(table.getTTable, partitions.map(part => part.getTPartition))
+ }
+
+ def createTable(databaseName: String, tableName: String, schema: Seq[Attribute]) {
+ val table = new Table(databaseName, tableName)
+ val hiveSchema =
+ schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), ""))
+ table.setFields(hiveSchema)
+
+ val sd = new StorageDescriptor()
+ table.getTTable.setSd(sd)
+ sd.setCols(hiveSchema)
+
+ // TODO: THESE ARE ALL DEFAULTS, WE NEED TO PARSE / UNDERSTAND the output specs.
+ sd.setCompressed(false)
+ sd.setParameters(Map[String, String]())
+ sd.setInputFormat("org.apache.hadoop.mapred.TextInputFormat")
+ sd.setOutputFormat("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")
+ val serDeInfo = new SerDeInfo()
+ serDeInfo.setName(tableName)
+ serDeInfo.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
+ serDeInfo.setParameters(Map[String, String]())
+ sd.setSerdeInfo(serDeInfo)
+ client.createTable(table)
+ }
+
+ /**
+ * Creates any tables required for query execution.
+ * For example, because of a CREATE TABLE X AS statement.
+ */
+ object CreateTables extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case InsertIntoCreatedTable(db, tableName, child) =>
+ val databaseName = db.getOrElse(SessionState.get.getCurrentDatabase())
+
+ createTable(databaseName, tableName, child.output)
+
+ InsertIntoTable(
+ lookupRelation(Some(databaseName), tableName, None).asInstanceOf[BaseRelation],
+ Map.empty,
+ child,
+ overwrite = false)
+ }
+ }
+
+ /**
+ * Casts input data to correct data types according to table definition before inserting into
+ * that table.
+ */
+ object PreInsertionCasts extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
+ // Wait until children are resolved
+ case p: LogicalPlan if !p.childrenResolved => p
+
+ case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
+ val childOutputDataTypes = child.output.map(_.dataType)
+ // Only check attributes, not partitionKeys since they are always strings.
+ // TODO: Fully support inserting into partitioned tables.
+ val tableOutputDataTypes = table.attributes.map(_.dataType)
+
+ if (childOutputDataTypes == tableOutputDataTypes) {
+ p
+ } else {
+ // Only do the casting when child output data types differ from table output data types.
+ val castedChildOutput = child.output.zip(table.output).map {
+ case (input, table) if input.dataType != table.dataType =>
+ Alias(Cast(input, table.dataType), input.name)()
+ case (input, _) => input
+ }
+
+ p.copy(child = logical.Project(castedChildOutput, child))
+ }
+ }
+ }
+
+ /**
+ * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore.
+ * For now, if this functionality is desired, mix in the in-memory [[OverrideCatalog]].
+ */
+ override def registerTable(
+ databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = ???
+}
+
+object HiveMetastoreTypes extends RegexParsers {
+ protected lazy val primitiveType: Parser[DataType] =
+ "string" ^^^ StringType |
+ "float" ^^^ FloatType |
+ "int" ^^^ IntegerType |
+ "tinyint" ^^^ ShortType |
+ "double" ^^^ DoubleType |
+ "bigint" ^^^ LongType |
+ "binary" ^^^ BinaryType |
+ "boolean" ^^^ BooleanType |
+ "decimal" ^^^ DecimalType |
+ "varchar\\((\\d+)\\)".r ^^^ StringType
+
+ protected lazy val arrayType: Parser[DataType] =
+ "array" ~> "<" ~> dataType <~ ">" ^^ ArrayType
+
+ protected lazy val mapType: Parser[DataType] =
+ "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ {
+ case t1 ~ _ ~ t2 => MapType(t1, t2)
+ }
+
+ protected lazy val structField: Parser[StructField] =
+ "[a-zA-Z0-9]*".r ~ ":" ~ dataType ^^ {
+ case name ~ _ ~ tpe => StructField(name, tpe, nullable = true)
+ }
+
+ protected lazy val structType: Parser[DataType] =
+ "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ StructType
+
+ protected lazy val dataType: Parser[DataType] =
+ arrayType |
+ mapType |
+ structType |
+ primitiveType
+
+ def toDataType(metastoreType: String): DataType = parseAll(dataType, metastoreType) match {
+ case Success(result, _) => result
+ case failure: NoSuccess => sys.error(s"Unsupported dataType: $metastoreType")
+ }
+
+ def toMetastoreType(dt: DataType): String = dt match {
+ case ArrayType(elementType) => s"array<${toMetastoreType(elementType)}>"
+ case StructType(fields) =>
+ s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>"
+ case MapType(keyType, valueType) =>
+ s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>"
+ case StringType => "string"
+ case FloatType => "float"
+ case IntegerType => "int"
+ case ShortType =>"tinyint"
+ case DoubleType => "double"
+ case LongType => "bigint"
+ case BinaryType => "binary"
+ case BooleanType => "boolean"
+ case DecimalType => "decimal"
+ }
+}
+
+case class MetastoreRelation(databaseName: String, tableName: String, alias: Option[String])
+ (val table: TTable, val partitions: Seq[TPartition])
+ extends BaseRelation {
+ // TODO: Can we use org.apache.hadoop.hive.ql.metadata.Table as the type of table and
+ // use org.apache.hadoop.hive.ql.metadata.Partition as the type of elements of partitions.
+ // Right now, using org.apache.hadoop.hive.ql.metadata.Table and
+ // org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException
+ // which indicates the SerDe we used is not Serializable.
+
+ def hiveQlTable = new Table(table)
+
+ def hiveQlPartitions = partitions.map { p =>
+ new Partition(hiveQlTable, p)
+ }
+
+ override def isPartitioned = hiveQlTable.isPartitioned
+
+ val tableDesc = new TableDesc(
+ Class.forName(hiveQlTable.getSerializationLib).asInstanceOf[Class[Deserializer]],
+ hiveQlTable.getInputFormatClass,
+ // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because
+ // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to
+ // substitute some output formats, e.g. substituting SequenceFileOutputFormat to
+ // HiveSequenceFileOutputFormat.
+ hiveQlTable.getOutputFormatClass,
+ hiveQlTable.getMetadata
+ )
+
+ implicit class SchemaAttribute(f: FieldSchema) {
+ def toAttribute = AttributeReference(
+ f.getName,
+ HiveMetastoreTypes.toDataType(f.getType),
+ // Since data can be dumped in randomly with no validation, everything is nullable.
+ nullable = true
+ )(qualifiers = tableName +: alias.toSeq)
+ }
+
+ // Must be a stable value since new attributes are born here.
+ val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute)
+
+ /** Non-partitionKey attributes */
+ val attributes = table.getSd.getCols.map(_.toAttribute)
+
+ val output = attributes ++ partitionKeys
+}
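A small sketch of the metastore type-string round trip implemented by HiveMetastoreTypes above (the type string is just an example):

import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.hive.HiveMetastoreTypes

val dt = HiveMetastoreTypes.toDataType("map<string,array<int>>") // parsed by the RegexParsers grammar
assert(dt == MapType(StringType, ArrayType(IntegerType)))
assert(HiveMetastoreTypes.toMetastoreType(dt) == "map<string,array<int>>") // printed back out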
http://git-wip-us.apache.org/repos/asf/spark/blob/9aadcffa/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
new file mode 100644
index 0000000..4f33a29
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -0,0 +1,966 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+package hive
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.hive.ql.lib.Node
+import org.apache.hadoop.hive.ql.parse._
+import org.apache.hadoop.hive.ql.plan.PlanUtils
+
+import catalyst.analysis._
+import catalyst.expressions._
+import catalyst.plans._
+import catalyst.plans.logical
+import catalyst.plans.logical._
+import catalyst.types._
+
+/**
+ * Used when we need to start parsing the AST before deciding that we are going to pass the command
+ * back for Hive to execute natively. Will be replaced with a native command that contains the
+ * cmd string.
+ */
+case object NativePlaceholder extends Command
+
+case class DfsCommand(cmd: String) extends Command
+
+case class ShellCommand(cmd: String) extends Command
+
+case class SourceCommand(filePath: String) extends Command
+
+case class AddJar(jarPath: String) extends Command
+
+case class AddFile(filePath: String) extends Command
+
+/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
+object HiveQl {
+ protected val nativeCommands = Seq(
+ "TOK_DESCFUNCTION",
+ "TOK_DESCTABLE",
+ "TOK_DESCDATABASE",
+ "TOK_SHOW_TABLESTATUS",
+ "TOK_SHOWDATABASES",
+ "TOK_SHOWFUNCTIONS",
+ "TOK_SHOWINDEXES",
+ "TOK_SHOWINDEXES",
+ "TOK_SHOWPARTITIONS",
+ "TOK_SHOWTABLES",
+
+ "TOK_LOCKTABLE",
+ "TOK_SHOWLOCKS",
+ "TOK_UNLOCKTABLE",
+
+ "TOK_CREATEROLE",
+ "TOK_DROPROLE",
+ "TOK_GRANT",
+ "TOK_GRANT_ROLE",
+ "TOK_REVOKE",
+ "TOK_SHOW_GRANT",
+ "TOK_SHOW_ROLE_GRANT",
+
+ "TOK_CREATEFUNCTION",
+ "TOK_DROPFUNCTION",
+
+ "TOK_ANALYZE",
+ "TOK_ALTERDATABASE_PROPERTIES",
+ "TOK_ALTERINDEX_PROPERTIES",
+ "TOK_ALTERINDEX_REBUILD",
+ "TOK_ALTERTABLE_ADDCOLS",
+ "TOK_ALTERTABLE_ADDPARTS",
+ "TOK_ALTERTABLE_ALTERPARTS",
+ "TOK_ALTERTABLE_ARCHIVE",
+ "TOK_ALTERTABLE_CLUSTER_SORT",
+ "TOK_ALTERTABLE_DROPPARTS",
+ "TOK_ALTERTABLE_PARTITION",
+ "TOK_ALTERTABLE_PROPERTIES",
+ "TOK_ALTERTABLE_RENAME",
+ "TOK_ALTERTABLE_RENAMECOL",
+ "TOK_ALTERTABLE_REPLACECOLS",
+ "TOK_ALTERTABLE_SKEWED",
+ "TOK_ALTERTABLE_TOUCH",
+ "TOK_ALTERTABLE_UNARCHIVE",
+ "TOK_ANALYZE",
+ "TOK_CREATEDATABASE",
+ "TOK_CREATEFUNCTION",
+ "TOK_CREATEINDEX",
+ "TOK_DROPDATABASE",
+ "TOK_DROPINDEX",
+ "TOK_DROPTABLE",
+ "TOK_MSCK",
+
+ // TODO(marmbrus): Figure out how views are expanded by Hive, as we might need to handle this.
+ "TOK_ALTERVIEW_ADDPARTS",
+ "TOK_ALTERVIEW_AS",
+ "TOK_ALTERVIEW_DROPPARTS",
+ "TOK_ALTERVIEW_PROPERTIES",
+ "TOK_ALTERVIEW_RENAME",
+ "TOK_CREATEVIEW",
+ "TOK_DROPVIEW",
+
+ "TOK_EXPORT",
+ "TOK_IMPORT",
+ "TOK_LOAD",
+
+ "TOK_SWITCHDATABASE"
+ )
+
+ /**
+ * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations
+ * similar to [[catalyst.trees.TreeNode]].
+ *
+ * Note that this should be considered very experimental and is not intended as a replacement
+ * for TreeNode. Primarily it should be noted ASTNodes are not immutable and do not appear to
+ * have clean copy semantics. Therefore, users of this class should take care when
+ * copying/modifying trees that might be used elsewhere.
+ */
+ implicit class TransformableNode(n: ASTNode) {
+ /**
+ * Returns a copy of this node where `rule` has been recursively applied to it and all of its
+ * children. When `rule` does not apply to a given node it is left unchanged.
+ * @param rule the function used to transform this node's children
+ */
+ def transform(rule: PartialFunction[ASTNode, ASTNode]): ASTNode = {
+ try {
+ val afterRule = rule.applyOrElse(n, identity[ASTNode])
+ afterRule.withChildren(
+ nilIfEmpty(afterRule.getChildren)
+ .asInstanceOf[Seq[ASTNode]]
+ .map(ast => Option(ast).map(_.transform(rule)).orNull))
+ } catch {
+ case e: Exception =>
+ println(dumpTree(n))
+ throw e
+ }
+ }
+
+ /**
+ * Returns a scala.Seq equivalent to [s], or Nil if [s] is null.
+ */
+ private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] =
+ Option(s).map(_.toSeq).getOrElse(Nil)
+
+ /**
+ * Returns this ASTNode with the text changed to `newText`.
+ */
+ def withText(newText: String): ASTNode = {
+ n.token.asInstanceOf[org.antlr.runtime.CommonToken].setText(newText)
+ n
+ }
+
+ /**
+ * Returns this ASTNode with the children changed to `newChildren`.
+ */
+ def withChildren(newChildren: Seq[ASTNode]): ASTNode = {
+ (1 to n.getChildCount).foreach(_ => n.deleteChild(0))
+ n.addChildren(newChildren)
+ n
+ }
+
+ /**
+ * Throws an error if this is not equal to other.
+ *
+ * Right now this function only checks the name, type, text and children of the node
+ * for equality.
+ */
+ def checkEquals(other: ASTNode) {
+ def check(field: String, f: ASTNode => Any) = if (f(n) != f(other)) {
+ sys.error(s"$field does not match for trees. " +
+ s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}")
+ }
+ check("name", _.getName)
+ check("type", _.getType)
+ check("text", _.getText)
+ check("numChildren", n => nilIfEmpty(n.getChildren).size)
+
+ val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]]
+ val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]]
+ leftChildren zip rightChildren foreach {
+ case (l,r) => l checkEquals r
+ }
+ }
+ }
+
+ class ParseException(sql: String, cause: Throwable)
+ extends Exception(s"Failed to parse: $sql", cause)
+
+ /**
+ * Returns the AST for the given SQL string.
+ */
+ def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql))
+
+ /** Returns a LogicalPlan for a given HiveQL string. */
+ def parseSql(sql: String): LogicalPlan = {
+ try {
+ if (sql.toLowerCase.startsWith("set")) {
+ NativeCommand(sql)
+ } else if (sql.toLowerCase.startsWith("add jar")) {
+ AddJar(sql.drop(8))
+ } else if (sql.toLowerCase.startsWith("add file")) {
+ AddFile(sql.drop(9))
+ } else if (sql.startsWith("dfs")) {
+ DfsCommand(sql)
+ } else if (sql.startsWith("source")) {
+ SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath })
+ } else if (sql.startsWith("!")) {
+ ShellCommand(sql.drop(1))
+ } else {
+ val tree = getAst(sql)
+
+ if (nativeCommands contains tree.getText) {
+ NativeCommand(sql)
+ } else {
+ nodeToPlan(tree) match {
+ case NativePlaceholder => NativeCommand(sql)
+ case other => other
+ }
+ }
+ }
+ } catch {
+ case e: Exception => throw new ParseException(sql, e)
+ }
+ }
+
+ def parseDdl(ddl: String): Seq[Attribute] = {
+ val tree =
+ try {
+ ParseUtils.findRootNonNullToken(
+ (new ParseDriver).parse(ddl, null /* no context required for parsing alone */))
+ } catch {
+ case pe: org.apache.hadoop.hive.ql.parse.ParseException =>
+ throw new RuntimeException(s"Failed to parse ddl: '$ddl'", pe)
+ }
+ assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.")
+ val tableOps = tree.getChildren
+ val colList =
+ tableOps
+ .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST")
+ .getOrElse(sys.error("No columnList!")).getChildren
+
+ colList.map(nodeToAttribute)
+ }
+
+ /** Extractor for matching Hive's AST Tokens. */
+ object Token {
+ /** @return matches of the form (tokenName, children). */
+ def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match {
+ case t: ASTNode =>
+ Some((t.getText,
+ Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]]))
+ case _ => None
+ }
+ }
+
+ protected def getClauses(clauseNames: Seq[String], nodeList: Seq[ASTNode]): Seq[Option[Node]] = {
+ var remainingNodes = nodeList
+ val clauses = clauseNames.map { clauseName =>
+ val (matches, nonMatches) = remainingNodes.partition(_.getText.toUpperCase == clauseName)
+ remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil)
+ matches.headOption
+ }
+
+ assert(remainingNodes.isEmpty,
+ s"Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}")
+ clauses
+ }
+
+ def getClause(clauseName: String, nodeList: Seq[Node]) =
+ getClauseOption(clauseName, nodeList).getOrElse(sys.error(
+ s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}"))
+
+ def getClauseOption(clauseName: String, nodeList: Seq[Node]): Option[Node] = {
+ nodeList.filter { case ast: ASTNode => ast.getText == clauseName } match {
+ case Seq(oneMatch) => Some(oneMatch)
+ case Seq() => None
+ case _ => sys.error(s"Found multiple instances of clause $clauseName")
+ }
+ }
+
+ protected def nodeToAttribute(node: Node): Attribute = node match {
+ case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) =>
+ AttributeReference(colName, nodeToDataType(dataType), true)()
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ protected def nodeToDataType(node: Node): DataType = node match {
+ case Token("TOK_BIGINT", Nil) => IntegerType
+ case Token("TOK_INT", Nil) => IntegerType
+ case Token("TOK_TINYINT", Nil) => IntegerType
+ case Token("TOK_SMALLINT", Nil) => IntegerType
+ case Token("TOK_BOOLEAN", Nil) => BooleanType
+ case Token("TOK_STRING", Nil) => StringType
+ case Token("TOK_FLOAT", Nil) => FloatType
+ case Token("TOK_DOUBLE", Nil) => FloatType
+ case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType))
+ case Token("TOK_STRUCT",
+ Token("TOK_TABCOLLIST", fields) :: Nil) =>
+ StructType(fields.map(nodeToStructField))
+ case Token("TOK_MAP",
+ keyType ::
+ valueType :: Nil) =>
+ MapType(nodeToDataType(keyType), nodeToDataType(valueType))
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for DataType:\n ${dumpTree(a).toString} ")
+ }
+
+ protected def nodeToStructField(node: Node): StructField = node match {
+ case Token("TOK_TABCOL",
+ Token(fieldName, Nil) ::
+ dataType :: Nil) =>
+ StructField(fieldName, nodeToDataType(dataType), nullable = true)
+ case Token("TOK_TABCOL",
+ Token(fieldName, Nil) ::
+ dataType ::
+ _ /* comment */:: Nil) =>
+ StructField(fieldName, nodeToDataType(dataType), nullable = true)
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ")
+ }
+
+ protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = {
+ exprs.zipWithIndex.map {
+ case (ne: NamedExpression, _) => ne
+ case (e, i) => Alias(e, s"c_$i")()
+ }
+ }
+
+ protected def nodeToPlan(node: Node): LogicalPlan = node match {
+ // Just fake explain for any of the native commands.
+ case Token("TOK_EXPLAIN", explainArgs) if nativeCommands contains explainArgs.head.getText =>
+ NoRelation
+ case Token("TOK_EXPLAIN", explainArgs) =>
+ // Ignore FORMATTED if present.
+ val Some(query) :: _ :: _ :: Nil =
+ getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs)
+ // TODO: support EXTENDED?
+ ExplainCommand(nodeToPlan(query))
+
+ case Token("TOK_CREATETABLE", children)
+ if children.collect { case t@Token("TOK_QUERY", _) => t }.nonEmpty =>
+ // TODO: Parse other clauses.
+ // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
+ val (
+ Some(tableNameParts) ::
+ _ /* likeTable */ ::
+ Some(query) +:
+ notImplemented) =
+ getClauses(
+ Seq(
+ "TOK_TABNAME",
+ "TOK_LIKETABLE",
+ "TOK_QUERY",
+ "TOK_IFNOTEXISTS",
+ "TOK_TABLECOMMENT",
+ "TOK_TABCOLLIST",
+ "TOK_TABLEPARTCOLS", // Partitioned by
+ "TOK_TABLEBUCKETS", // Clustered by
+ "TOK_TABLESKEWED", // Skewed by
+ "TOK_TABLEROWFORMAT",
+ "TOK_TABLESERIALIZER",
+ "TOK_FILEFORMAT_GENERIC", // For file formats not natively supported by Hive.
+ "TOK_TBLSEQUENCEFILE", // Stored as SequenceFile
+ "TOK_TBLTEXTFILE", // Stored as TextFile
+ "TOK_TBLRCFILE", // Stored as RCFile
+ "TOK_TBLORCFILE", // Stored as ORC File
+ "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat
+ "TOK_STORAGEHANDLER", // Storage handler
+ "TOK_TABLELOCATION",
+ "TOK_TABLEPROPERTIES"),
+ children)
+ if (notImplemented.exists(token => !token.isEmpty)) {
+ throw new NotImplementedError(
+ s"Unhandled clauses: ${notImplemented.flatten.map(dumpTree(_)).mkString("\n")}")
+ }
+
+ val (db, tableName) =
+ tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match {
+ case Seq(tableOnly) => (None, tableOnly)
+ case Seq(databaseName, table) => (Some(databaseName), table)
+ }
+ InsertIntoCreatedTable(db, tableName, nodeToPlan(query))
+
+ // If it's not a "CREATE TABLE AS" like above, just pass it back to Hive as a native command.
+ case Token("TOK_CREATETABLE", _) => NativePlaceholder
+
+ case Token("TOK_QUERY",
+ Token("TOK_FROM", fromClause :: Nil) ::
+ insertClauses) =>
+
+ // Return one query for each insert clause.
+ val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) =>
+ val (
+ intoClause ::
+ destClause ::
+ selectClause ::
+ selectDistinctClause ::
+ whereClause ::
+ groupByClause ::
+ orderByClause ::
+ sortByClause ::
+ clusterByClause ::
+ distributeByClause ::
+ limitClause ::
+ lateralViewClause :: Nil) = {
+ getClauses(
+ Seq(
+ "TOK_INSERT_INTO",
+ "TOK_DESTINATION",
+ "TOK_SELECT",
+ "TOK_SELECTDI",
+ "TOK_WHERE",
+ "TOK_GROUPBY",
+ "TOK_ORDERBY",
+ "TOK_SORTBY",
+ "TOK_CLUSTERBY",
+ "TOK_DISTRIBUTEBY",
+ "TOK_LIMIT",
+ "TOK_LATERAL_VIEW"),
+ singleInsert)
+ }
+
+ val relations = nodeToRelation(fromClause)
+ val withWhere = whereClause.map { whereNode =>
+ val Seq(whereExpr) = whereNode.getChildren.toSeq
+ Filter(nodeToExpr(whereExpr), relations)
+ }.getOrElse(relations)
+
+ val select =
+ (selectClause orElse selectDistinctClause).getOrElse(sys.error("No select clause."))
+
+ // Script transformations are expressed as a select clause with a single expression of type
+ // TOK_TRANSFORM
+ val transformation = select.getChildren.head match {
+ case Token("TOK_SELEXPR",
+ Token("TOK_TRANSFORM",
+ Token("TOK_EXPLIST", inputExprs) ::
+ Token("TOK_SERDE", Nil) ::
+ Token("TOK_RECORDWRITER", writerClause) ::
+ // TODO: Need to support other types of (in/out)put
+ Token(script, Nil) ::
+ Token("TOK_SERDE", serdeClause) ::
+ Token("TOK_RECORDREADER", readerClause) ::
+ outputClause :: Nil) :: Nil) =>
+
+ val output = outputClause match {
+ case Token("TOK_ALIASLIST", aliases) =>
+ aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() }
+ case Token("TOK_TABCOLLIST", attributes) =>
+ attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) =>
+ AttributeReference(name, nodeToDataType(dataType))() }
+ }
+ val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script)
+
+ Some(
+ logical.ScriptTransformation(
+ inputExprs.map(nodeToExpr),
+ unescapedScript,
+ output,
+ withWhere))
+ case _ => None
+ }
+
+ val withLateralView = lateralViewClause.map { lv =>
+ val Token("TOK_SELECT",
+ Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.head
+
+ val alias =
+ getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText
+
+ Generate(
+ nodesToGenerator(clauses),
+ join = true,
+ outer = false,
+ Some(alias.toLowerCase),
+ withWhere)
+ }.getOrElse(withWhere)
+
+
+ // The projection of the query can either be a normal projection, an aggregation
+ // (if there is a group by) or a script transformation.
+ val withProject = transformation.getOrElse {
+ // Not a transformation so must be either project or aggregation.
+ val selectExpressions = nameExpressions(select.getChildren.flatMap(selExprNodeToExpr))
+
+ groupByClause match {
+ case Some(groupBy) =>
+ Aggregate(groupBy.getChildren.map(nodeToExpr), selectExpressions, withLateralView)
+ case None =>
+ Project(selectExpressions, withLateralView)
+ }
+ }
+
+ val withDistinct =
+ if (selectDistinctClause.isDefined) Distinct(withProject) else withProject
+
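+ // e.g. ORDER BY gives a total Sort, SORT BY gives SortPartitions, DISTRIBUTE BY gives
+ // Repartition, DISTRIBUTE BY + SORT BY sorts within each repartition, and CLUSTER BY
+ // repartitions and sorts ascending on the same expressions.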
+ val withSort =
+ (orderByClause, sortByClause, distributeByClause, clusterByClause) match {
+ case (Some(totalOrdering), None, None, None) =>
+ Sort(totalOrdering.getChildren.map(nodeToSortOrder), withDistinct)
+ case (None, Some(perPartitionOrdering), None, None) =>
+ SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withDistinct)
+ case (None, None, Some(partitionExprs), None) =>
+ Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct)
+ case (None, Some(perPartitionOrdering), Some(partitionExprs), None) =>
+ SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder),
+ Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct))
+ case (None, None, None, Some(clusterExprs)) =>
+ SortPartitions(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)),
+ Repartition(clusterExprs.getChildren.map(nodeToExpr), withDistinct))
+ case (None, None, None, None) => withDistinct
+ case _ => sys.error("Unsupported set of ordering / distribution clauses.")
+ }
+
+ val withLimit =
+ limitClause.map(l => nodeToExpr(l.getChildren.head))
+ .map(StopAfter(_, withSort))
+ .getOrElse(withSort)
+
+ // TOK_INSERT_INTO means to add files to the table.
+ // TOK_DESTINATION means to overwrite the table.
+ val resultDestination =
+ (intoClause orElse destClause).getOrElse(sys.error("No destination found."))
+ val overwrite = intoClause.isEmpty
+ nodeToDest(
+ resultDestination,
+ withLimit,
+ overwrite)
+ }
+
+ // If there are multiple INSERTs, just UNION them together into one query.
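+ // e.g. an illustrative multi-insert query
+ // FROM src
+ // INSERT OVERWRITE TABLE dest1 SELECT key WHERE key < 100
+ // INSERT OVERWRITE TABLE dest2 SELECT key WHERE key >= 100
+ // produces one plan per INSERT clause, combined by the Union below.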
+ queries.reduceLeft(Union)
+
+ case Token("TOK_UNION", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right))
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ val allJoinTokens = "(TOK_.*JOIN)".r
+ val laterViewToken = "TOK_LATERAL_VIEW(.*)".r
+ def nodeToRelation(node: Node): LogicalPlan = node match {
+ case Token("TOK_SUBQUERY",
+ query :: Token(alias, Nil) :: Nil) =>
+ Subquery(alias, nodeToPlan(query))
+
+ case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) =>
+ val Token("TOK_SELECT",
+ Token("TOK_SELEXPR", clauses) :: Nil) = selectClause
+
+ val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText
+
+ Generate(
+ nodesToGenerator(clauses),
+ join = true,
+ outer = isOuter.nonEmpty,
+ Some(alias.toLowerCase),
+ nodeToRelation(relationClause))
+
+ /* All relations, possibly with aliases or sampling clauses. */
+ case Token("TOK_TABREF", clauses) =>
+ // If the last clause is not a token then it's the alias of the table.
+ val (nonAliasClauses, aliasClause) =
+ if (clauses.last.getText.startsWith("TOK")) {
+ (clauses, None)
+ } else {
+ (clauses.dropRight(1), Some(clauses.last))
+ }
+
+ val (Some(tableNameParts) ::
+ splitSampleClause ::
+ bucketSampleClause :: Nil) = {
+ getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"),
+ nonAliasClauses)
+ }
+
+ val (db, tableName) =
+ tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match {
+ case Seq(tableOnly) => (None, tableOnly)
+ case Seq(databaseName, table) => (Some(databaseName), table)
+ }
+ val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) }
+ val relation = UnresolvedRelation(db, tableName, alias)
+
+ // Apply sampling if requested.
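+ // (Illustrative clauses: TABLESAMPLE(50 ROWS) becomes StopAfter, TABLESAMPLE(0.1 PERCENT) and
+ // TABLESAMPLE(BUCKET 1 OUT OF 10) become Sample with the corresponding fraction; the seed is
+ // picked at random here rather than taken from the query.)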
+ (bucketSampleClause orElse splitSampleClause).map {
+ case Token("TOK_TABLESPLITSAMPLE",
+ Token("TOK_ROWCOUNT", Nil) ::
+ Token(count, Nil) :: Nil) =>
+ StopAfter(Literal(count.toInt), relation)
+ case Token("TOK_TABLESPLITSAMPLE",
+ Token("TOK_PERCENT", Nil) ::
+ Token(fraction, Nil) :: Nil) =>
+ Sample(fraction.toDouble, withReplacement = false, (math.random * 1000).toInt, relation)
+ case Token("TOK_TABLEBUCKETSAMPLE",
+ Token(numerator, Nil) ::
+ Token(denominator, Nil) :: Nil) =>
+ val fraction = numerator.toDouble / denominator.toDouble
+ Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation)
+ case a: ASTNode =>
+ throw new NotImplementedError(
+ s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} :
+ |${dumpTree(a).toString}
+ """.stripMargin)
+ }.getOrElse(relation)
+
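+ // Hive's UNIQUEJOIN, e.g. (illustrative)
+ // SELECT * FROM UNIQUEJOIN PRESERVE t1 a (a.key), t2 b (b.key)
+ // lists each table with the expressions to join on; PRESERVE marks sides whose rows are kept
+ // even without a match, which determines the join types chosen below.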
+ case Token("TOK_UNIQUEJOIN", joinArgs) =>
+ val tableOrdinals =
+ joinArgs.zipWithIndex.filter {
+ case (arg, i) => arg.getText == "TOK_TABREF"
+ }.map(_._2)
+
+ val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE")
+ val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i)))
+ val joinExpressions = tableOrdinals.map(i => joinArgs(i + 1).getChildren.map(nodeToExpr))
+
+ val joinConditions = joinExpressions.sliding(2).map {
+ case Seq(c1, c2) =>
+ val predicates = (c1, c2).zipped.map { case (e1, e2) => Equals(e1, e2): Expression }
+ predicates.reduceLeft(And)
+ }.toBuffer
+
+ val joinType = isPreserved.sliding(2).map {
+ case Seq(true, true) => FullOuter
+ case Seq(true, false) => LeftOuter
+ case Seq(false, true) => RightOuter
+ case Seq(false, false) => Inner
+ }.toBuffer
+
+ val joinedTables = tables.reduceLeft(Join(_,_, Inner, None))
+
+ // Must be transform down.
+ val joinedResult = joinedTables transform {
+ case j: Join =>
+ j.copy(
+ condition = Some(joinConditions.remove(joinConditions.length - 1)),
+ joinType = joinType.remove(joinType.length - 1))
+ }
+
+ val groups = (0 until joinExpressions.head.size).map(i => Coalesce(joinExpressions.map(_(i))))
+
+ // Unique join is not really the same as an outer join, so we must group together results where
+ // the joinExpressions are the same. Taking the First of each value is only okay because the
+ // user of a unique join is implicitly promising that there is only one result.
+ // TODO: This doesn't actually work since [[Star]] is not a valid aggregate expression.
+ // Instead we should figure out how important supporting this feature is and whether it is
+ // worth the number of hacks that will be required to implement it. Namely, we need to add
+ // some sort of mapped star expansion that would expand all child output rows to be similarly
+ // named output expressions where some aggregate expression has been applied (i.e. First).
+ ??? /// Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult)
+
+ case Token(allJoinTokens(joinToken),
+ relation1 ::
+ relation2 :: other) =>
+ assert(other.size <= 1, s"Unhandled join child ${other}")
+ val joinType = joinToken match {
+ case "TOK_JOIN" => Inner
+ case "TOK_RIGHTOUTERJOIN" => RightOuter
+ case "TOK_LEFTOUTERJOIN" => LeftOuter
+ case "TOK_FULLOUTERJOIN" => FullOuter
+ }
+ Join(nodeToRelation(relation1),
+ nodeToRelation(relation2),
+ joinType,
+ other.headOption.map(nodeToExpr))
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ def nodeToSortOrder(node: Node): SortOrder = node match {
+ case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) =>
+ SortOrder(nodeToExpr(sortExpr), Ascending)
+ case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) =>
+ SortOrder(nodeToExpr(sortExpr), Descending)
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r
+ protected def nodeToDest(
+ node: Node,
+ query: LogicalPlan,
+ overwrite: Boolean): LogicalPlan = node match {
+ case Token(destinationToken(),
+ Token("TOK_DIR",
+ Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) =>
+ query
+
+ case Token(destinationToken(),
+ Token("TOK_TAB",
+ tableArgs) :: Nil) =>
+ val Some(tableNameParts) :: partitionClause :: Nil =
+ getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs)
+
+ val (db, tableName) =
+ tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match {
+ case Seq(tableOnly) => (None, tableOnly)
+ case Seq(databaseName, table) => (Some(databaseName), table)
+ }
+
+ val partitionKeys = partitionClause.map(_.getChildren.map {
+ // Parse partitions. We also make keys case insensitive.
+ case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) =>
+ cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value))
+ case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) =>
+ cleanIdentifier(key.toLowerCase) -> None
+ }.toMap).getOrElse(Map.empty)
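+ // e.g. PARTITION (ds='2008-08-08') yields Map("ds" -> Some("2008-08-08")), while the dynamic
+ // form PARTITION (ds) yields Map("ds" -> None) and is rejected below.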
+
+ if (partitionKeys.values.exists(p => p.isEmpty)) {
+ throw new NotImplementedError("Do not support INSERT INTO/OVERWRITE with " +
+ "dynamic partitioning.")
+ }
+
+ InsertIntoTable(UnresolvedRelation(db, tableName, None), partitionKeys, query, overwrite)
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+ protected def selExprNodeToExpr(node: Node): Option[Expression] = node match {
+ case Token("TOK_SELEXPR",
+ e :: Nil) =>
+ Some(nodeToExpr(e))
+
+ case Token("TOK_SELEXPR",
+ e :: Token(alias, Nil) :: Nil) =>
+ Some(Alias(nodeToExpr(e), alias)())
+
+ /* Hints are ignored */
+ case Token("TOK_HINTLIST", _) => None
+
+ case a: ASTNode =>
+ throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+ }
+
+
+ protected val escapedIdentifier = "`([^`]+)`".r
+ /** Strips backticks from ident if present */
+ protected def cleanIdentifier(ident: String): String = ident match {
+ case escapedIdentifier(i) => i
+ case plainIdent => plainIdent
+ }
+
+ val numericAstTypes = Seq(
+ HiveParser.Number,
+ HiveParser.TinyintLiteral,
+ HiveParser.SmallintLiteral,
+ HiveParser.BigintLiteral)
+
+ /* Case insensitive matches */
+ val COUNT = "(?i)COUNT".r
+ val AVG = "(?i)AVG".r
+ val SUM = "(?i)SUM".r
+ val RAND = "(?i)RAND".r
+ val AND = "(?i)AND".r
+ val OR = "(?i)OR".r
+ val NOT = "(?i)NOT".r
+ val TRUE = "(?i)TRUE".r
+ val FALSE = "(?i)FALSE".r
+
+ protected def nodeToExpr(node: Node): Expression = node match {
+ /* Attribute References */
+ case Token("TOK_TABLE_OR_COL",
+ Token(name, Nil) :: Nil) =>
+ UnresolvedAttribute(cleanIdentifier(name))
+ case Token(".", qualifier :: Token(attr, Nil) :: Nil) =>
+ nodeToExpr(qualifier) match {
+ case UnresolvedAttribute(qualifierName) =>
+ UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr))
+ // The precedence for . seems to be wrong, so [] binds tighter and we need to go inside to
+ // find the underlying attribute references.
+ case GetItem(UnresolvedAttribute(qualifierName), ordinal) =>
+ GetItem(UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)), ordinal)
+ }
+
+ /* Stars (*) */
+ case Token("TOK_ALLCOLREF", Nil) => Star(None)
+ // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only
+ // have a single child, which is the tableName.
+ case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) =>
+ Star(Some(name))
+
+ /* Aggregate Functions */
+ case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg))
+ case Token("TOK_FUNCTION", Token(COUNT(), Nil) :: arg :: Nil) => Count(nodeToExpr(arg))
+ case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1))
+ case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr))
+ case Token("TOK_FUNCTION", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg))
+ case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg))
+
+ /* Casts */
+ case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), StringType)
+ case Token("TOK_FUNCTION", Token("TOK_VARCHAR", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), StringType)
+ case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), IntegerType)
+ case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), LongType)
+ case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), FloatType)
+ case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), DoubleType)
+ case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), ShortType)
+ case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), ByteType)
+ case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), BinaryType)
+ case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), BooleanType)
+ case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), DecimalType)
+
+ /* Arithmetic */
+ case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child))
+ case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right))
+ case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right))
+ case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right))
+ case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
+ case Token("DIV", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
+ case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
+
+ /* Comparisons */
+ case Token("=", left :: right:: Nil) => Equals(nodeToExpr(left), nodeToExpr(right))
+ case Token("!=", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right)))
+ case Token("<>", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right)))
+ case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right))
+ case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right))
+ case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right))
+ case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right))
+ case Token("LIKE", left :: right:: Nil) =>
+ UnresolvedFunction("LIKE", Seq(nodeToExpr(left), nodeToExpr(right)))
+ case Token("RLIKE", left :: right:: Nil) =>
+ UnresolvedFunction("RLIKE", Seq(nodeToExpr(left), nodeToExpr(right)))
+ case Token("REGEXP", left :: right:: Nil) =>
+ UnresolvedFunction("REGEXP", Seq(nodeToExpr(left), nodeToExpr(right)))
+ case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) =>
+ IsNotNull(nodeToExpr(child))
+ case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) =>
+ IsNull(nodeToExpr(child))
+ case Token("TOK_FUNCTION", Token("IN", Nil) :: value :: list) =>
+ In(nodeToExpr(value), list.map(nodeToExpr))
+
+ /* Boolean Logic */
+ case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right))
+ case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right))
+ case Token(NOT(), child :: Nil) => Not(nodeToExpr(child))
+
+ /* Complex datatype manipulation */
+ case Token("[", child :: ordinal :: Nil) =>
+ GetItem(nodeToExpr(child), nodeToExpr(ordinal))
+
+ /* Other functions */
+ case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand
+
+ /* UDFs - Must be last, otherwise they will preempt built-in functions. */
+ case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
+ UnresolvedFunction(name, args.map(nodeToExpr))
+ case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) =>
+ UnresolvedFunction(name, Star(None) :: Nil)
+
+ /* Literals */
+ case Token("TOK_NULL", Nil) => Literal(null, NullType)
+ case Token(TRUE(), Nil) => Literal(true, BooleanType)
+ case Token(FALSE(), Nil) => Literal(false, BooleanType)
+ case Token("TOK_STRINGLITERALSEQUENCE", strings) =>
+ Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString)
+
+ // This code is adapted from
+ // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223
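+ // Hive literal suffixes handled below: L (bigint), S (smallint), Y (tinyint), BD (decimal);
+ // unsuffixed numbers fall through to the narrowest of Double/Long/Int that parses.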
+ case ast: ASTNode if numericAstTypes contains ast.getType =>
+ var v: Literal = null
+ try {
+ if (ast.getText.endsWith("L")) {
+ // Literal bigint.
+ v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType)
+ } else if (ast.getText.endsWith("S")) {
+ // Literal smallint.
+ v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType)
+ } else if (ast.getText.endsWith("Y")) {
+ // Literal tinyint.
+ v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType)
+ } else if (ast.getText.endsWith("BD")) {
+ // Literal decimal; strip the "BD" suffix and keep the value as a decimal literal.
+ val strVal = ast.getText.substring(0, ast.getText.length() - 2)
+ v = Literal(BigDecimal(strVal), DecimalType)
+ } else {
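+ // Try Double, then Long, then Int; later assignments overwrite earlier ones, and the first
+ // NumberFormatException (caught below) leaves v at the last (narrowest) type that parsed.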
+ v = Literal(ast.getText.toDouble, DoubleType)
+ v = Literal(ast.getText.toLong, LongType)
+ v = Literal(ast.getText.toInt, IntegerType)
+ }
+ } catch {
+ case nfe: NumberFormatException => // Do nothing
+ }
+
+ if (v == null) {
+ sys.error(s"Failed to parse number ${ast.getText}")
+ } else {
+ v
+ }
+
+ case ast: ASTNode if ast.getType == HiveParser.StringLiteral =>
+ Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText))
+
+ case a: ASTNode =>
+ throw new NotImplementedError(
+ s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} :
+ |${dumpTree(a).toString}
+ """.stripMargin)
+ }
+
+
+ val explode = "(?i)explode".r
+ def nodesToGenerator(nodes: Seq[Node]): Generator = {
+ val function = nodes.head
+
+ val attributes = nodes.flatMap {
+ case Token(a, Nil) => a.toLowerCase :: Nil
+ case _ => Nil
+ }
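+ // Bare leaf tokens among the nodes (e.g. the column names after AS in a LATERAL VIEW clause)
+ // become the generator's output attribute names, lower-cased.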
+
+ function match {
+ case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) =>
+ Explode(attributes, nodeToExpr(child))
+
+ case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
+ HiveGenericUdtf(functionName, attributes, children.map(nodeToExpr))
+
+ case a: ASTNode =>
+ throw new NotImplementedError(
+ s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText}, tree:
+ |${dumpTree(a).toString}
+ """.stripMargin)
+ }
+ }
+
+ def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0)
+ : StringBuilder = {
+ node match {
+ case a: ASTNode => builder.append((" " * indent) + a.getText + "\n")
+ case other => sys.error(s"Non ASTNode encountered: $other")
+ }
+
+ Option(node.getChildren).map(_.toList).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1))
+ builder
+ }
+}