Posted to commits@spark.apache.org by db...@apache.org on 2020/02/14 21:50:24 UTC

[spark] branch branch-3.0 updated: [SPARK-30289][SQL] Partitioned by Nested Column for `InMemoryTable`

This is an automated email from the ASF dual-hosted git repository.

dbtsai pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 2824fec9 [SPARK-30289][SQL] Partitioned by Nested Column for `InMemoryTable`
2824fec9 is described below

commit 2824fec9fa57444b7c64edb8226cf75bb87a2e5d
Author: DB Tsai <d_...@apple.com>
AuthorDate: Fri Feb 14 21:46:01 2020 +0000

    [SPARK-30289][SQL] Partitioned by Nested Column for `InMemoryTable`
    
    ### What changes were proposed in this pull request?
    `InMemoryTable` was flattening the nested columns, and the flattened column names were then used to look up top-level field indices, which is not correct.
    
    This PR implements partitioning by nested columns for `InMemoryTable`.
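
    Below is a minimal sketch (not part of this patch) of the lookup problem and the
    fix, using only the `StructType` API: the old code flattened `ts.timezone` into
    top-level field names before calling `schema.fieldIndex`, while the new code keeps
    the multi-part name and descends one struct at a time, as the `extractor` helper
    in `InMemoryTable.getKey` now does.

    ```scala
    import scala.util.Try

    import org.apache.spark.sql.types._

    object NestedPartitionLookupSketch extends App {
      val schema = new StructType()
        .add("id", LongType)
        .add("ts", new StructType()
          .add("created", TimestampType)
          .add("timezone", StringType))

      // Old behaviour: the nested reference was flattened to Seq("ts", "timezone")
      // and each name was resolved against the top-level schema.
      assert(schema.fieldIndex("ts") == 1)
      // "timezone" is not a top-level field, so the flattened lookup fails here.
      assert(Try(schema.fieldIndex("timezone")).isFailure)

      // New behaviour: keep the multi-part name and resolve it one struct at a time.
      def resolve(names: Seq[String], st: StructType): StructField = {
        val field = st(st.fieldIndex(names.head))
        if (names.tail.isEmpty) field
        else resolve(names.tail, field.dataType.asInstanceOf[StructType])
      }
      assert(resolve(Seq("ts", "timezone"), schema).dataType == StringType)
    }
    ```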
    
    ### Why are the changes needed?
    
    This PR implements partitioning by nested columns for `InMemoryTable`, so we can test this feature in DSv2.
    
    ### Does this PR introduce any user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing unit tests and new tests.
    
    Closes #26929 from dbtsai/addTests.
    
    Authored-by: DB Tsai <d_...@apple.com>
    Signed-off-by: DB Tsai <d_...@apple.com>
    (cherry picked from commit d0f961476031b62bda0d4d41f7248295d651ea92)
    Signed-off-by: DB Tsai <d_...@apple.com>
---
 .../apache/spark/sql/connector/InMemoryTable.scala | 35 +++++++--
 .../apache/spark/sql/DataFrameWriterV2Suite.scala  | 86 +++++++++++++++++++++-
 2 files changed, 114 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index c9e4e0a..0187ae3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -26,7 +26,7 @@ import org.scalatest.Assertions._
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.catalog._
-import org.apache.spark.sql.connector.expressions.{IdentityTransform, Transform}
+import org.apache.spark.sql.connector.expressions.{IdentityTransform, NamedReference, Transform}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
@@ -59,10 +59,30 @@ class InMemoryTable(
 
   def rows: Seq[InternalRow] = dataMap.values.flatMap(_.rows).toSeq
 
-  private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames)
-  private val partIndexes = partFieldNames.map(schema.fieldIndex)
+  private val partCols: Array[Array[String]] = partitioning.flatMap(_.references).map { ref =>
+    schema.findNestedField(ref.fieldNames(), includeCollections = false) match {
+      case Some(_) => ref.fieldNames()
+      case None => throw new IllegalArgumentException(s"${ref.describe()} does not exist.")
+    }
+  }
 
-  private def getKey(row: InternalRow): Seq[Any] = partIndexes.map(row.toSeq(schema)(_))
+  private def getKey(row: InternalRow): Seq[Any] = {
+    def extractor(fieldNames: Array[String], schema: StructType, row: InternalRow): Any = {
+      val index = schema.fieldIndex(fieldNames(0))
+      val value = row.toSeq(schema).apply(index)
+      if (fieldNames.length > 1) {
+        (value, schema(index).dataType) match {
+          case (row: InternalRow, nestedSchema: StructType) =>
+            extractor(fieldNames.drop(1), nestedSchema, row)
+          case (_, dataType) =>
+            throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}")
+        }
+      } else {
+        value
+      }
+    }
+    partCols.map(fieldNames => extractor(fieldNames, schema, row))
+  }
 
   def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized {
     data.foreach(_.rows.foreach { row =>
@@ -146,8 +166,10 @@ class InMemoryTable(
   }
 
   private class Overwrite(filters: Array[Filter]) extends TestBatchWrite {
+    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
     override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
-      val deleteKeys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
+      val deleteKeys = InMemoryTable.filtersToKeys(
+        dataMap.keys, partCols.map(_.toSeq.quoted), filters)
       dataMap --= deleteKeys
       withData(messages.map(_.asInstanceOf[BufferedRows]))
     }
@@ -161,7 +183,8 @@ class InMemoryTable(
   }
 
   override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized {
-    dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
+    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+    dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters)
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
index d49dc58..cd15708 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
@@ -17,20 +17,24 @@
 
 package org.apache.spark.sql
 
+import java.sql.Timestamp
+
 import scala.collection.JavaConverters._
 
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
-import org.apache.spark.sql.connector.InMemoryTableCatalog
+import org.apache.spark.sql.connector.{InMemoryTable, InMemoryTableCatalog}
 import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
 import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
+import org.apache.spark.sql.types.TimestampType
 import org.apache.spark.sql.util.QueryExecutionListener
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
 class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter {
@@ -550,4 +554,84 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
     assert(replaced.partitioning.isEmpty)
     assert(replaced.properties === defaultOwnership.asJava)
   }
+
+  test("SPARK-30289 Create: partitioned by nested column") {
+    val schema = new StructType().add("ts", new StructType()
+      .add("created", TimestampType)
+      .add("modified", TimestampType)
+      .add("timezone", StringType))
+
+    val data = Seq(
+      Row(Row(Timestamp.valueOf("2019-06-01 10:00:00"), Timestamp.valueOf("2019-09-02 07:00:00"),
+        "America/Los_Angeles")),
+      Row(Row(Timestamp.valueOf("2019-08-26 18:00:00"), Timestamp.valueOf("2019-09-26 18:00:00"),
+        "America/Los_Angeles")),
+      Row(Row(Timestamp.valueOf("2018-11-23 18:00:00"), Timestamp.valueOf("2018-12-22 18:00:00"),
+        "America/New_York")))
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema)
+
+    df.writeTo("testcat.table_name")
+      .partitionedBy($"ts.timezone")
+      .create()
+
+    val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+      .asInstanceOf[InMemoryTable]
+
+    assert(table.name === "testcat.table_name")
+    assert(table.partitioning === Seq(IdentityTransform(FieldReference(Array("ts", "timezone")))))
+    checkAnswer(spark.table(table.name), data)
+    assert(table.dataMap.toArray.length == 2)
+    assert(table.dataMap(Seq(UTF8String.fromString("America/Los_Angeles"))).rows.size == 2)
+    assert(table.dataMap(Seq(UTF8String.fromString("America/New_York"))).rows.size == 1)
+
+    // TODO: `DataSourceV2Strategy` can not translate nested fields into source filter yet
+    // so the following sql will fail.
+    // sql("DELETE FROM testcat.table_name WHERE ts.timezone = \"America/Los_Angeles\"")
+  }
+
+  test("SPARK-30289 Create: partitioned by multiple transforms on nested columns") {
+    spark.table("source")
+      .withColumn("ts", struct(
+        lit("2019-06-01 10:00:00.000000").cast("timestamp") as "created",
+        lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
+        lit("America/Los_Angeles") as "timezone"))
+      .writeTo("testcat.table_name")
+      .tableProperty("allow-unsupported-transforms", "true")
+      .partitionedBy(
+        years($"ts.created"), months($"ts.created"), days($"ts.created"), hours($"ts.created"),
+        years($"ts.modified"), months($"ts.modified"), days($"ts.modified"), hours($"ts.modified")
+      )
+      .create()
+
+    val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+    assert(table.name === "testcat.table_name")
+    assert(table.partitioning === Seq(
+      YearsTransform(FieldReference(Array("ts", "created"))),
+      MonthsTransform(FieldReference(Array("ts", "created"))),
+      DaysTransform(FieldReference(Array("ts", "created"))),
+      HoursTransform(FieldReference(Array("ts", "created"))),
+      YearsTransform(FieldReference(Array("ts", "modified"))),
+      MonthsTransform(FieldReference(Array("ts", "modified"))),
+      DaysTransform(FieldReference(Array("ts", "modified"))),
+      HoursTransform(FieldReference(Array("ts", "modified")))))
+  }
+
+  test("SPARK-30289 Create: partitioned by bucket(4, ts.timezone)") {
+    spark.table("source")
+      .withColumn("ts", struct(
+        lit("2019-06-01 10:00:00.000000").cast("timestamp") as "created",
+        lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified",
+        lit("America/Los_Angeles") as "timezone"))
+      .writeTo("testcat.table_name")
+      .tableProperty("allow-unsupported-transforms", "true")
+      .partitionedBy(bucket(4, $"ts.timezone"))
+      .create()
+
+    val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+    assert(table.name === "testcat.table_name")
+    assert(table.partitioning === Seq(BucketTransform(LiteralValue(4, IntegerType),
+      Seq(FieldReference(Seq("ts", "timezone"))))))
+  }
 }

