You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by ya...@apache.org on 2017/03/09 07:53:14 UTC
incubator-hivemall git commit: Close #61: [HIVEMALL-88][SPARK]
Support a function to flatten nested schemas
Repository: incubator-hivemall
Updated Branches:
refs/heads/master 210b7765b -> 33baa2408
Close #61: [HIVEMALL-88][SPARK] Support a function to flatten nested schemas
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/33baa240
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/33baa240
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/33baa240
Branch: refs/heads/master
Commit: 33baa2408b77895a4feaaa0f60953055657275d0
Parents: 210b776
Author: Takeshi Yamamuro <ya...@apache.org>
Authored: Thu Mar 9 16:53:00 2017 +0900
Committer: Takeshi Yamamuro <ya...@apache.org>
Committed: Thu Mar 9 16:53:00 2017 +0900
----------------------------------------------------------------------
.../datasources/csv/csvExpressions.scala | 153 +++++++++++++++++++
.../org/apache/spark/sql/hive/HivemallOps.scala | 48 ++++++
.../spark/sql/hive/HivemallOpsSuite.scala | 19 +++
3 files changed, 220 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/33baa240/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
new file mode 100644
index 0000000..363d432
--- /dev/null
+++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.io.CharArrayWriter
+
+import jodd.util.CsvUtil
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Converts a CSV input string to a [[StructType]] with the specified schema.
+ *
+ * The string is tokenized with [[CsvReader]], and the tokens are turned into an [[InternalRow]]
+ * by the parser built with `CSVRelation.csvParser`. Both are configured from `options`.
+ * Returns `null` when the input is `null`, when the parser yields no row, or when parsing
+ * throws a `RuntimeException`.
+ *
+ * TODO: Move this class into org.apache.spark.sql.catalyst.expressions in Spark-v2.2+
+ *
+ * @param schema the schema of the resulting struct
+ * @param options CSV options, forwarded to [[CSVOptions]]
+ * @param child a string expression holding one CSV line
+ */
+case class CsvToStruct(schema: StructType, options: Map[String, String], child: Expression)
+  extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
+  override def nullable: Boolean = true
+
+  // The parsing helpers are `@transient`, so they are not shipped with the expression. They
+  // must also be `lazy`: a plain `@transient val` is left `null` after the expression is
+  // deserialized (e.g. on an executor), which would make every `eval` fail with an NPE.
+  @transient private lazy val csvOptions = new CSVOptions(options)
+  @transient private lazy val csvReader = new CsvReader(csvOptions)
+  @transient private lazy val csvParser =
+    CSVRelation.csvParser(schema, schema.fieldNames, csvOptions)
+
+  /** Parses one CSV line into a row, or returns `null` if the parser produced no row. */
+  private def parse(s: String): InternalRow = {
+    csvParser(csvReader.parseLine(s), 0).orNull
+  }
+
+  override def dataType: DataType = schema
+
+  override def nullSafeEval(csv: Any): Any = {
+    // Unparseable input yields `null` instead of failing the whole query.
+    try parse(csv.toString) catch { case _: RuntimeException => null }
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
+}
+
+/**
+ * Converts a [[StructType]] value to a CSV output string.
+ *
+ * Only flat structs are supported. Each field must be a byte, short, int, long, float, double,
+ * boolean, decimal, timestamp, date or string, or a user-defined type whose SQL type is one of
+ * those. Any other field type is reported as a type-check failure by `checkInputDataTypes`.
+ * A `null` field is written as the configured `nullValue`.
+ *
+ * @param options CSV options, forwarded to [[CSVOptions]]
+ * @param child a struct expression to serialize
+ */
+case class StructToCsv(
+    options: Map[String, String],
+    child: Expression)
+  extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
+  override def nullable: Boolean = true
+
+  @transient
+  lazy val params = new CSVOptions(options)
+
+  @transient
+  lazy val dataSchema = child.dataType.asInstanceOf[StructType]
+
+  @transient
+  lazy val writer = new LineCsvWriter(params, dataSchema.fieldNames.toSeq)
+
+  override def dataType: DataType = StringType
+
+  // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`.
+  // When the value is null, this converter should not be called.
+  private type ValueConverter = (InternalRow, Int) => String
+
+  // `ValueConverter`s for all values in the fields of the schema. Like the other derived state
+  // above, this is `@transient lazy`, so the closures are rebuilt where they are used instead
+  // of being serialized along with the expression.
+  @transient private lazy val valueConverters: Array[ValueConverter] =
+    dataSchema.map(_.dataType).map(makeConverter).toArray
+
+  /**
+   * Checks that every field of `schema` has a type this expression can render.
+   * @throws UnsupportedOperationException on the first unsupported field type
+   */
+  private def verifySchema(schema: StructType): Unit = {
+    def verifyType(dataType: DataType): Unit = dataType match {
+      case ByteType | ShortType | IntegerType | LongType | FloatType |
+           DoubleType | BooleanType | _: DecimalType | TimestampType |
+           DateType | StringType =>
+
+      case udt: UserDefinedType[_] => verifyType(udt.sqlType)
+
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"CSV data source does not support ${dataType.simpleString} data type.")
+    }
+
+    schema.foreach(field => verifyType(field.dataType))
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (StructType.acceptsType(child.dataType)) {
+      try {
+        verifySchema(child.dataType.asInstanceOf[StructType])
+        TypeCheckResult.TypeCheckSuccess
+      } catch {
+        // Report unsupported field types as a type-check failure rather than an exception.
+        case e: UnsupportedOperationException =>
+          TypeCheckResult.TypeCheckFailure(e.getMessage)
+      }
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"$prettyName requires that the expression is a struct expression.")
+    }
+  }
+
+  /** Renders each field of `row` as a string, substituting `params.nullValue` for nulls. */
+  private def rowToString(row: InternalRow): Seq[String] = {
+    val values = new Array[String](row.numFields)
+    var i = 0
+    while (i < row.numFields) {
+      values(i) = if (row.isNullAt(i)) params.nullValue else valueConverters(i).apply(row, i)
+      i += 1
+    }
+    values
+  }
+
+  /** Builds the converter for a single (non-null) field of the given type. */
+  private def makeConverter(dataType: DataType): ValueConverter = dataType match {
+    case DateType =>
+      (row: InternalRow, ordinal: Int) =>
+        params.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal)))
+
+    case TimestampType =>
+      (row: InternalRow, ordinal: Int) =>
+        params.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)))
+
+    case udt: UserDefinedType[_] => makeConverter(udt.sqlType)
+
+    case dt: DataType =>
+      (row: InternalRow, ordinal: Int) =>
+        row.get(ordinal, dt).toString
+  }
+
+  override def nullSafeEval(row: Any): Any = {
+    // Write the row through the shared CSV writer, then return the flushed text as a string.
+    writer.writeRow(rowToString(row.asInstanceOf[InternalRow]), false)
+    UTF8String.fromString(writer.flush())
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = StructType :: Nil
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/33baa240/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index 6883ac1..f583cca 100644
--- a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{Generate, JoinTopK, LogicalPlan}
import org.apache.spark.sql.execution.UserProvidedPlanner
+import org.apache.spark.sql.execution.datasources.csv.{CsvToStruct, StructToCsv}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -1423,6 +1424,53 @@ object HivemallOps {
}.as("rowid")
/**
+   * Parses a column containing a CSV string into a [[StructType]] with the specified schema.
+   * Returns `null` in the case of an unparseable string.
+   * @group misc
+   *
+   * @param e a string column containing CSV data.
+   * @param schema the schema to use when parsing the CSV string.
+   * @param options options to control how the CSV string is parsed. Accepts the same options
+   *                as the CSV data source.
+   */
+  def from_csv(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr {
+    CsvToStruct(schema, options, e.expr)
+  }
+
+  /**
+   * Parses a column containing a CSV string into a [[StructType]] with the specified schema,
+   * using the default CSV options.
+   * Returns `null` in the case of an unparseable string.
+   * @group misc
+   *
+   * @param e a string column containing CSV data.
+   * @param schema the schema to use when parsing the CSV string.
+   */
+  def from_csv(e: Column, schema: StructType): Column =
+    from_csv(e, schema, Map.empty[String, String])
+
+  /**
+   * Converts a column containing a `StructType` into a CSV string.
+   * Throws an exception in the case of an unsupported field type.
+   * @group misc
+   *
+   * @param e a struct column.
+   * @param options options to control how the struct column is converted into a CSV string.
+   *                Accepts the same options as the CSV data source.
+   */
+  def to_csv(e: Column, options: Map[String, String]): Column = withExpr {
+    StructToCsv(options, e.expr)
+  }
+
+  /**
+   * Converts a column containing a `StructType` into a CSV string, using the default CSV
+   * options.
+   * Throws an exception in the case of an unsupported field type.
+   * @group misc
+   *
+   * @param e a struct column.
+   */
+  def to_csv(e: Column): Column = to_csv(e, Map.empty[String, String])
+
+ /**
* A convenient function to wrap an expression and produce a Column.
*
* @param expr the Catalyst expression to wrap
* @return a [[Column]] built from `expr`
*/
@inline private def withExpr(expr: Expression): Column = Column(expr)
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/33baa240/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index ed56bc3..d595df2 100644
--- a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -461,6 +461,25 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
}
}
+  test("misc - from_csv") {
+    import hiveContext.implicits._
+    // The second row checks that a `null` input string yields a `null` struct.
+    val df = Seq("1,abc", null).toDF()
+    val schema = new StructType().add("a", IntegerType).add("b", StringType)
+    checkAnswer(
+      df.select(from_csv($"value", schema)),
+      Row(Row(1, "abc")) :: Row(null) :: Nil)
+  }
+
+  test("misc - to_csv") {
+    import hiveContext.implicits._
+    // Only the nested struct column `_3` is serialized; `_1` and `_2` are ignored.
+    val input = Seq(
+      (1, "a", (0, 3.9, "abc")),
+      (8, "c", (2, 0.4, "def")))
+    val expected = Seq(Row("0,3.9,abc"), Row("2,0.4,def"))
+    checkAnswer(input.toDF().select(to_csv($"_3")), expected)
+  }
+
/**
* This test fails because;
*