You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kyuubi.apache.org by ul...@apache.org on 2023/02/24 07:37:28 UTC

[kyuubi] branch branch-1.7 updated: [KYUUBI #4393] [Kyuubi #4332] Fix some bugs with `Groupby` and `CacheTable`

This is an automated email from the ASF dual-hosted git repository.

ulyssesyou pushed a commit to branch branch-1.7
in repository https://gitbox.apache.org/repos/asf/kyuubi.git


The following commit(s) were added to refs/heads/branch-1.7 by this push:
     new d023620ad [KYUUBI #4393] [Kyuubi #4332] Fix some bugs with `Groupby` and `CacheTable`
d023620ad is described below

commit d023620ad6ba87213b78918b0f9841afad83603f
Author: odone <od...@gmail.com>
AuthorDate: Fri Feb 24 15:37:08 2023 +0800

    [KYUUBI #4393] [Kyuubi #4332] Fix some bugs with `Groupby` and `CacheTable`
    
    close #4332
    ### _Why are the changes needed?_
    
    For the case where the table name has been resolved and an `Expand` logical plan exists
    ```
    InsertIntoHiveTable `default`.`t1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, false, false, [a, b]
    +- Aggregate [a#0], [a#0, ansi_cast((count(if ((gid#9 = 1)) spark_catalog.default.t2.`b`#10 else null) * count(if ((gid#9 = 2)) spark_catalog.default.t2.`c`#11 else null)) as string) AS b#8]
       +- Aggregate [a#0, spark_catalog.default.t2.`b`#10, spark_catalog.default.t2.`c`#11, gid#9], [a#0, spark_catalog.default.t2.`b`#10, spark_catalog.default.t2.`c`#11, gid#9]
          +- Expand [ArrayBuffer(a#0, b#1, null, 1), ArrayBuffer(a#0, null, c#2, 2)], [a#0, spark_catalog.default.t2.`b`#10, spark_catalog.default.t2.`c`#11, gid#9]
             +- HiveTableRelation [`default`.`t2`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, Data Cols: [a#0, b#1, c#2], Partition Cols: []]
    ```
    For the case of `CacheTable` with a `window` function
    ```
    InsertIntoHiveTable `default`.`t1`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, true, false, [a, b]
    +- Project [a#98, b#99]
       +- InMemoryRelation [a#98, b#99, rank#100], StorageLevel(disk, memory, deserialized, 1 replicas)
             +- *(2) Filter (isnotnull(rank#4) AND (rank#4 = 1))
                +- Window [row_number() windowspecdefinition(a#9, b#10 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rank#4], [a#9], [b#10 ASC NULLS FIRST]
                   +- *(1) Sort [a#9 ASC NULLS FIRST, b#10 ASC NULLS FIRST], false, 0
                      +- Exchange hashpartitioning(a#9, 200), ENSURE_REQUIREMENTS, [id=#38]
                         +- Scan hive default.t2 [a#9, b#10], HiveTableRelation [`default`.`t2`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, Data Cols: [a#9, b#10], Partition Cols: []]
    
    ```
    
    ### _How was this patch tested?_
    - [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible
    
    - [ ] Add screenshots for manual tests if appropriate
    
    - [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request
    
    Closes #4393 from iodone/kyuubi-4332.
    
    Closes #4393
    
    d2afdabd [odone] fix cache table bug
    443af798 [odone] fix some bugs with groupby
    
    Authored-by: odone <od...@gmail.com>
    Signed-off-by: ulyssesyou <ul...@apache.org>
    (cherry picked from commit 427771017462a2e32f87296d3215d90b774736d0)
    Signed-off-by: ulyssesyou <ul...@apache.org>
---
 .../helper/SparkSQLLineageParseHelper.scala        |  39 ++++++-
 .../helper/SparkSQLLineageParserHelperSuite.scala  | 119 +++++++++++++++++++++
 2 files changed, 155 insertions(+), 3 deletions(-)

diff --git a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala
index f70e09126..a58653113 100644
--- a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala
+++ b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala
@@ -25,8 +25,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{NamedRelation, PersistedView, ViewType}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, HiveTableRelation}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression}
-import org.apache.spark.sql.catalyst.expressions.ScalarSubquery
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression, ScalarSubquery}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -128,7 +127,7 @@ trait LineageParser {
           exp.toAttribute,
           if (!containsCountAll(exp.child)) references
           else references + exp.toAttribute.withName(AGGREGATE_COUNT_COLUMN_IDENTIFIER))
-      case a: Attribute => a -> a.references
+      case a: Attribute => a -> AttributeSet(a)
     }
     ListMap(exps: _*)
   }
@@ -149,6 +148,9 @@ trait LineageParser {
             attr.withQualifier(attr.qualifier.init)
           case attr if attr.name.equalsIgnoreCase(AGGREGATE_COUNT_COLUMN_IDENTIFIER) =>
             attr.withQualifier(qualifier)
+          case attr if isNameWithQualifier(attr, qualifier) =>
+            val newName = attr.name.split('.').last.stripPrefix("`").stripSuffix("`")
+            attr.withName(newName).withQualifier(qualifier)
         })
       }
     } else {
@@ -160,6 +162,12 @@ trait LineageParser {
     }
   }
 
+  private def isNameWithQualifier(attr: Attribute, qualifier: Seq[String]): Boolean = {
+    val nameTokens = attr.name.split('.')
+    val namespace = nameTokens.init.mkString(".")
+    nameTokens.length > 1 && namespace.endsWith(qualifier.mkString("."))
+  }
+
   private def mergeRelationColumnLineage(
       parentColumnsLineage: AttributeMap[AttributeSet],
       relationOutput: Seq[Attribute],
@@ -327,6 +335,31 @@ trait LineageParser {
           joinColumnsLineage(parentColumnsLineage, getSelectColumnLineage(p.aggregateExpressions))
         p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
 
+      case p: Expand =>
+        val references =
+          p.projections.transpose.map(_.flatMap(x => x.references)).map(AttributeSet(_))
+
+        val childColumnsLineage = ListMap(p.output.zip(references): _*)
+        val nextColumnsLineage =
+          joinColumnsLineage(parentColumnsLineage, childColumnsLineage)
+        p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
+
+      case p: Window =>
+        val windowColumnsLineage =
+          ListMap(p.windowExpressions.map(exp => (exp.toAttribute, exp.references)): _*)
+
+        val nextColumnsLineage = if (parentColumnsLineage.isEmpty) {
+          ListMap(p.child.output.map(attr => (attr, attr.references)): _*) ++ windowColumnsLineage
+        } else {
+          parentColumnsLineage.map {
+            case (k, _) if windowColumnsLineage.contains(k) =>
+              k -> windowColumnsLineage(k)
+            case (k, attrs) =>
+              k -> AttributeSet(attrs.flatten(attr =>
+                windowColumnsLineage.getOrElse(attr, AttributeSet(attr))))
+          }
+        }
+        p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
       case p: Join =>
         p.joinType match {
           case LeftSemi | LeftAnti =>
diff --git a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala
index 6652be9ea..050f3ddc9 100644
--- a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala
+++ b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala
@@ -1094,6 +1094,125 @@ class SparkSQLLineageParserHelperSuite extends KyuubiFunSuite
     }
   }
 
+  test("test group by") {
+    withTable("t1", "t2", "v2_catalog.db.t1", "v2_catalog.db.t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE v2_catalog.db.t1 (a string, b string, c string)")
+      spark.sql("CREATE TABLE v2_catalog.db.t2 (a string, b string, c string)")
+      val ret0 =
+        exectractLineage(
+          s"insert into table t1 select a," +
+            s"concat_ws('/', collect_set(b))," +
+            s"count(distinct(b)) * count(distinct(c))" +
+            s"from t2 group by a")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")),
+          ("default.t1.c", Set("default.t2.b", "default.t2.c")))))
+
+      val ret1 =
+        exectractLineage(
+          s"insert into table v2_catalog.db.t1 select a," +
+            s"concat_ws('/', collect_set(b))," +
+            s"count(distinct(b)) * count(distinct(c))" +
+            s"from v2_catalog.db.t2 group by a")
+      assert(ret1 == Lineage(
+        List("v2_catalog.db.t2"),
+        List("v2_catalog.db.t1"),
+        List(
+          ("v2_catalog.db.t1.a", Set("v2_catalog.db.t2.a")),
+          ("v2_catalog.db.t1.b", Set("v2_catalog.db.t2.b")),
+          ("v2_catalog.db.t1.c", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")))))
+
+      val ret2 =
+        exectractLineage(
+          s"insert into table v2_catalog.db.t1 select a," +
+            s"count(distinct(b+c))," +
+            s"count(distinct(b)) * count(distinct(c))" +
+            s"from v2_catalog.db.t2 group by a")
+      assert(ret2 == Lineage(
+        List("v2_catalog.db.t2"),
+        List("v2_catalog.db.t1"),
+        List(
+          ("v2_catalog.db.t1.a", Set("v2_catalog.db.t2.a")),
+          ("v2_catalog.db.t1.b", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")),
+          ("v2_catalog.db.t1.c", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")))))
+    }
+  }
+
+  test("test grouping sets") {
+    withTable("t1", "t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string, c string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string, c string, d string) USING hive")
+      val ret0 =
+        exectractLineage(
+          s"insert into table t1 select a,b,GROUPING__ID " +
+            s"from t2 group by a,b,c,d grouping sets ((a,b,c), (a,b,d))")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")),
+          ("default.t1.c", Set()))))
+    }
+  }
+
+  test("test catch table with window function") {
+    withTable("t1", "t2") { _ =>
+      spark.sql("CREATE TABLE t1 (a string, b string) USING hive")
+      spark.sql("CREATE TABLE t2 (a string, b string) USING hive")
+
+      spark.sql(
+        s"cache table c1 select * from (" +
+          s"select a, b, row_number() over (partition by a order by b asc ) rank from t2)" +
+          s" where rank=1")
+      val ret0 = exectractLineage("insert overwrite table t1 select a, b from c1")
+      assert(ret0 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+
+      val ret1 = exectractLineage("insert overwrite table t1 select a, rank from c1")
+      assert(ret1 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.a", "default.t2.b")))))
+
+      spark.sql(
+        s"cache table c2 select * from (" +
+          s"select b, a, row_number() over (partition by a order by b asc ) rank from t2)" +
+          s" where rank=1")
+      val ret2 = exectractLineage("insert overwrite table t1 select a, b from c2")
+      assert(ret2 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+
+      spark.sql(
+        s"cache table c3 select * from (" +
+          s"select a as aa, b as bb, row_number() over (partition by a order by b asc ) rank" +
+          s" from t2) where rank=1")
+      val ret3 = exectractLineage("insert overwrite table t1 select aa, bb from c3")
+      assert(ret3 == Lineage(
+        List("default.t2"),
+        List("default.t1"),
+        List(
+          ("default.t1.a", Set("default.t2.a")),
+          ("default.t1.b", Set("default.t2.b")))))
+    }
+  }
+
   private def exectractLineageWithoutExecuting(sql: String): Lineage = {
     val parsed = spark.sessionState.sqlParser.parsePlan(sql)
     val analyzed = spark.sessionState.analyzer.execute(parsed)