Posted to commits@kyuubi.apache.org by ul...@apache.org on 2021/10/11 01:49:49 UTC

[incubator-kyuubi] branch master updated: [KYUUBI #1085][FOLLOWUP] Fix-Enforce maxOutputRows for aggregate with having statement

This is an automated email from the ASF dual-hosted git repository.

ulyssesyou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-kyuubi.git


The following commit(s) were added to refs/heads/master by this push:
     new fdff2b6  [KYUUBI #1085][FOLLOWUP] Fix-Enforce maxOutputRows for aggregate with having statement
fdff2b6 is described below

commit fdff2b62408c9fbede4c81935422aaea1c5752cc
Author: h <h...@zhihu.com>
AuthorDate: Mon Oct 11 09:49:33 2021 +0800

    [KYUUBI #1085][FOLLOWUP] Fix-Enforce maxOutputRows for aggregate with having statement
    
    ### _Why are the changes needed?_
    Support the `Union` case, as shown below:
    ```
    SELECT * FROM t1
    UNION [ALL]
    SELECT * FROM t2
    ```
    
    Support the `Distinct` case, as shown below:
    ```
    SELECT DISTINCT * FROM t1
    ```
    
    Fix the bug in the watchdog's maxOutputRows enforcement that occurs in the situation below:
    
    HAVING together with an optional ORDER BY
    ```
    SELECT c1, COUNT(c2) AS cnt
    FROM t1
    GROUP BY c1
    HAVING cnt > 0
    [ORDER BY c1, [c2 ...]]
    ```
    
    It throws an exception like:
    ```
    org.apache.spark.sql.catalyst.plans.logical.GlobalLimit cannot be cast to org.apache.spark.sql.catalyst.plans.logical.Aggregate
    java.lang.ClassCastException: org.apache.spark.sql.catalyst.plans.logical.GlobalLimit cannot be cast to org.apache.spark.sql.catalyst.plans.logical.Aggregate
    	at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions$.resolvedAggregateFilter$1(Analyzer.scala:2451)
    	at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions$.resolveFilterCondInAggregate(Analyzer.scala:2460)
    	at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions$.resolveHaving(Analyzer.scala:2496)
    	at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions$$anonfun$apply$21.applyOrElse(Analyzer.scala:2353)
    	at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions$$anonfun$apply$21.applyOrElse(Analyzer.scala:2345)
    	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.$anonfun$resolveOperatorsUp$3(AnalysisHelper.scala:90)
    	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:74)
    	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.$anonfun$resolveOperatorsUp$1(AnalysisHelper.scala:90)
    	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.allowInvokingTransformsInAnalyzer(AnalysisHelper.scala:221)
    	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUp(AnalysisHelper.scala:86)
    	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUp$(AnalysisHelper.scala:84)
    	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveOperatorsUp(LogicalPlan.scala:29)
    	at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions$.apply(Analyzer.scala:2345)
    	at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions$.apply(Analyzer.scala:2344)
    	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$2(RuleExecutor.scala:216)
    	at scala.collection.LinearSeqOptimized.foldLeft(LinearSeqOptimized.scala:126)
    	at scala.collection.LinearSeqOptimized.foldLeft$(LinearSeqOptimized.scala:122)
    	at scala.collection.immutable.List.foldLeft(List.scala:91)
    	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1(RuleExecutor.scala:213)
    	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1$adapted(RuleExecutor.scala:205)
    	at scala.collection.immutable.List.foreach(List.scala:431)
    	at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:205)
    	at org.apache.spark.sql.catalyst.analysis.Analyzer.org$apache$spark$sql$catalyst$analysis$Analyzer$$executeSameContext(Analyzer.scala:196)
    	at org.apache.spark.sql.catalyst.analysis.Analyzer.execute(Analyzer.scala:190)
    	at org.apache.spark.sql.catalyst.analysis.Analyzer.execute(Analyzer.scala:155)
    	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$executeAndTrack$1(RuleExecutor.scala:183)
    	at org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:88)
    	at org.apache.spark.sql.catalyst.rules.RuleExecutor.executeAndTrack(RuleExecutor.scala:183)
    	at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$1(Analyzer.scala:174)
    	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:228)
    	at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:173)
    	at org.apache.spark.sql.execution.QueryExecution.$anonfun$analyzed$1(QueryExecution.scala:73)
    	at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:111)
    	at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:143)
    	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
    	at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:143)
    	at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:73)
    	at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:71)
    	at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:63)
    	at org.apache.spark.sql.Dataset$.$anonfun$ofRows$2(Dataset.scala:98)
    	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
    	at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:96)
    	at org.apache.spark.sql.SparkSession.$anonfun$sql$1(SparkSession.scala:618)
    	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
    	at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:613)
    	at org.apache.spark.sql.test.SQLTestUtilsBase.$anonfun$sql$1(SQLTestUtils.scala:231)
    	at org.apache.spark.sql.KyuubiExtensionSuite.$anonfun$new$55(KyuubiExtensionSuite.scala:1331)
    	at org.apache.spark.sql.catalyst.plans.SQLHelper.withSQLConf(SQLHelper.scala:54)
    	at org.apache.spark.sql.catalyst.plans.SQLHelper.withSQLConf$(SQLHelper.scala:38)
    	at
    ```
    Related issue for reference: https://issues.apache.org/jira/browse/SPARK-31519
    The Spark SQL Analyzer transforms an aggregate with HAVING into
    ```
    Filter
    +- Aggregate
    ```
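    A minimal way to see this plan shape, assuming an existing SparkSession `spark` and a table `t1(c1, c2)` (both hypothetical here), is to print the analyzed plan:
    ```
    // Sketch only: `spark` and `t1(c1, c2)` are assumed to exist.
    val analyzed = spark.sql(
      """SELECT c1, COUNT(c2) AS cnt
        |FROM t1
        |GROUP BY c1
        |HAVING cnt > 0
        |""".stripMargin).queryExecution.analyzed
    
    // The HAVING predicate shows up as a Filter above the Aggregate
    // (possibly under a Project), not inside the Aggregate itself.
    println(analyzed.treeString)
    ```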
    
    Solution:
    
    1. Skip the Aggregate whose only output is the `havingCondition` alias.
    2. Match `Filter` when choosing where to add the limit (see the sketch after this list).
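    Below is a condensed sketch of that matching logic (the complete rule is in the `ForcedMaxOutputRowsRule` diff further down; the `havingCondition` alias name is the one Spark's Analyzer uses internally):
    ```
    import org.apache.spark.sql.catalyst.expressions.Alias
    import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan}
    
    // Sketch only: decide whether a forced limit may be inserted on top of this node.
    def canInsertLimitSketch(p: LogicalPlan): Boolean = p match {
      // 1. Skip the synthetic Aggregate whose only output is the HAVING predicate.
      case Aggregate(_, Alias(_, "havingCondition") :: Nil, _) => false
      // 2. The Filter the Analyzer places above the real Aggregate can carry the limit.
      case _: Filter => true
      case _ => false
    }
    ```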
    
    ### _How was this patch tested?_
    - [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible
    
    - [x] Add screenshots for manual tests if appropriate
    <img width="1440" alt="Screenshot 2021-09-21 8:35:16 PM" src="https://user-images.githubusercontent.com/635169/134171308-2842f0d4-acfa-4817-a03c-a7ef5e38df12.png">
    
    - [x] [Run test](https://kyuubi.readthedocs.io/en/latest/develop_tools/testing.html#running-tests) locally before making a pull request
    
    Closes #1129 from i7xh/fixAggWithHavingInMaxOutput.
    
    Closes #1085
    
    7577f4d3 [h] update
    5955c89e [h] update
    384b2333 [h] fix issue
    5b0af156 [h] update
    46327fc6 [h] Fix issue
    6119a039 [h] fix issue
    a7b87dd7 [h] Fix issue
    2570444e [h] BugFix: Aggregate with having statement
    
    Authored-by: h <h...@zhihu.com>
    Signed-off-by: ulysses-you <ul...@apache.org>
---
 .../kyuubi/sql/KyuubiSparkSQLExtension.scala       |   8 +-
 .../sql/watchdog/ForcedMaxOutputRowsRule.scala     |  75 +++++-
 .../scala/org/apache/spark/sql/WatchDogSuite.scala | 269 ++++++++++++++++-----
 3 files changed, 274 insertions(+), 78 deletions(-)

diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala
index c11d65c..7e4c780 100644
--- a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala
+++ b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala
@@ -20,7 +20,7 @@ package org.apache.kyuubi.sql
 import org.apache.spark.sql.SparkSessionExtensions
 
 import org.apache.kyuubi.sql.sqlclassification.KyuubiSqlClassification
-import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MaxHivePartitionStrategy}
+import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MarkAggregateOrderRule, MaxHivePartitionStrategy}
 import org.apache.kyuubi.sql.zorder.{InsertZorderBeforeWritingDatasource, InsertZorderBeforeWritingHive}
 import org.apache.kyuubi.sql.zorder.ResolveZorder
 import org.apache.kyuubi.sql.zorder.ZorderSparkSqlExtensionsParser
@@ -43,16 +43,18 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) {
     // should be applied before
     // RepartitionBeforeWrite and RepartitionBeforeWriteHive
     // because we can only apply one of them (i.e. Global Sort or Repartition)
+    extensions.injectResolutionRule(MarkAggregateOrderRule)
+
     extensions.injectPostHocResolutionRule(InsertZorderBeforeWritingDatasource)
     extensions.injectPostHocResolutionRule(InsertZorderBeforeWritingHive)
-
     extensions.injectPostHocResolutionRule(KyuubiSqlClassification)
     extensions.injectPostHocResolutionRule(RepartitionBeforeWrite)
     extensions.injectPostHocResolutionRule(RepartitionBeforeWriteHive)
     extensions.injectPostHocResolutionRule(FinalStageConfigIsolationCleanRule)
+    extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule)
+
     extensions.injectQueryStagePrepRule(_ => InsertShuffleNodeBeforeJoin)
     extensions.injectQueryStagePrepRule(FinalStageConfigIsolation(_))
     extensions.injectPlannerStrategy(MaxHivePartitionStrategy)
-    extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule)
   }
 }
diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala
index c2e3ee4..d82eead 100644
--- a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala
+++ b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala
@@ -19,11 +19,18 @@ package org.apache.kyuubi.sql.watchdog
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Limit, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, Filter, Limit, LogicalPlan, Project, Sort, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 
 import org.apache.kyuubi.sql.KyuubiSQLConf
 
+object ForcedMaxOutputRowsConstraint {
+  val CHILD_AGGREGATE: TreeNodeTag[String] = TreeNodeTag[String]("__kyuubi_child_agg__")
+  val CHILD_AGGREGATE_FLAG: String = "__kyuubi_child_agg__"
+}
+
 /*
 * Add ForcedMaxOutputRows rule for output rows limitation
 * to avoid huge output rows of non_limit query unexpectedly
@@ -45,19 +52,31 @@ import org.apache.kyuubi.sql.KyuubiSQLConf
 * */
 case class ForcedMaxOutputRowsRule(session: SparkSession) extends Rule[LogicalPlan] {
 
+  private def isChildAggregate(a: Aggregate): Boolean = a
+    .aggregateExpressions.exists(p => p.getTagValue(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE)
+    .contains(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG))
+
+  private def canInsertLimitInner(p: LogicalPlan): Boolean = p match {
+
+    case Aggregate(_, Alias(_, "havingCondition")::Nil, _) => false
+    case agg: Aggregate => !isChildAggregate(agg)
+    case _: Distinct => true
+    case _: Filter => true
+    case _: Project => true
+    case Limit(_, _) => true
+    case _: Sort => true
+    case _: Union => true
+    case _ => false
+
+  }
+
   private def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = {
 
     maxOutputRowsOpt match {
-      case Some(forcedMaxOutputRows) => val supported = p match {
-          case _: Project => true
-          case _: Aggregate => true
-          case Limit(_, _) => true
-          case _ => false
-        }
-        supported && !p.maxRows.exists(_ <= forcedMaxOutputRows)
+      case Some(forcedMaxOutputRows) => canInsertLimitInner(p) &&
+        !p.maxRows.exists(_ <= forcedMaxOutputRows)
       case None => false
     }
-
   }
 
   override def apply(plan: LogicalPlan): LogicalPlan = {
@@ -70,3 +89,41 @@ case class ForcedMaxOutputRowsRule(session: SparkSession) extends Rule[LogicalPl
   }
 
 }
+
+case class MarkAggregateOrderRule(session: SparkSession) extends Rule[LogicalPlan] {
+
+  private def markChildAggregate(a: Aggregate): Unit = {
+    // mark child aggregate
+    a.aggregateExpressions.filter(_.resolved).foreach(_.setTagValue(
+      ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE,
+      ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)
+    )
+  }
+
+  private def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match {
+    /*
+    * The case mainly process order not aggregate column but grouping column as below
+    * SELECT c1, COUNT(*) as cnt
+    * FROM t1
+    * GROUP BY c1
+    * ORDER BY c1
+    * */
+    case a: Aggregate if a.aggregateExpressions
+      .exists(x => x.resolved && x.name.equals("aggOrder")) => markChildAggregate(a)
+      plan
+
+    case _ => plan.children.foreach(_.foreach {
+        case agg: Aggregate => markChildAggregate(agg)
+        case _ => Unit
+      }
+    )
+      plan
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = conf.getConf(
+    KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS
+  ) match {
+    case Some(_) => findAndMarkChildAggregate(plan)
+    case _ => plan
+  }
+}
diff --git a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala
index 483a2bf..656e5bc 100644
--- a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala
+++ b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala
@@ -23,6 +23,10 @@ import org.apache.kyuubi.sql.KyuubiSQLConf
 import org.apache.kyuubi.sql.watchdog.MaxHivePartitionExceedException
 
 class WatchDogSuite extends KyuubiSparkSQLExtensionTest {
+
+  case class LimitAndExpected(limit: Int, expected: Int)
+  val limitAndExpecteds = List(LimitAndExpected(1, 1), LimitAndExpected(11, 10))
+
   test("test watchdog with scan maxHivePartitions") {
     withTable("test", "temp") {
       sql(
@@ -59,95 +63,228 @@ class WatchDogSuite extends KyuubiSparkSQLExtensionTest {
     }
   }
 
-  test("test watchdog with query forceMaxOutputRows") {
+  test("test watchdog: simple SELECT STATEMENT") {
 
     withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
 
-      assert(sql("SELECT * FROM t1")
-        .queryExecution.analyzed.isInstanceOf[GlobalLimit])
+      List("", "ORDER BY c1", "ORDER BY c2").foreach { sort =>
+        List("", " DISTINCT").foreach{ distinct =>
+        assert(sql(
+          s"""
+             |SELECT $distinct *
+             |FROM t1
+             |$sort
+             |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
+        }
+      }
 
-      assert(sql("SELECT * FROM t1 LIMIT 1")
-        .queryExecution.analyzed.asInstanceOf[GlobalLimit].maxRows.contains(1))
+      limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
+        List("", "ORDER BY c1", "ORDER BY c2").foreach { sort =>
+          List("", "DISTINCT").foreach{ distinct =>
+            assert(sql(
+              s"""
+                 |SELECT $distinct *
+                 |FROM t1
+                 |$sort
+                 |LIMIT $limit
+                 |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)
+            )
+          }
+        }
+      }
+    }
+  }
 
-      assert(sql("SELECT * FROM t1 LIMIT 11")
-        .queryExecution.analyzed.asInstanceOf[GlobalLimit].maxRows.contains(10))
+  test("test watchdog: SELECT ... WITH AGGREGATE STATEMENT ") {
+
+    withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
 
       assert(!sql("SELECT count(*) FROM t1")
         .queryExecution.analyzed.isInstanceOf[GlobalLimit])
 
-      assert(sql(
-        """
-          |SELECT c1, COUNT(*)
-          |FROM t1
-          |GROUP BY c1
-          |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
+      val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
+      val havingConditions = List("", "HAVING cnt > 1")
 
-      assert(sql(
-        """
-          |WITH custom_cte AS (
-          |SELECT * FROM t1
-          |)
-          |
-          |SELECT * FROM custom_cte
-          |""".stripMargin).queryExecution
-        .analyzed.isInstanceOf[GlobalLimit])
+      havingConditions.foreach { having =>
+        sorts.foreach { sort =>
+          assert(sql(
+            s"""
+               |SELECT c1, COUNT(*) as cnt
+               |FROM t1
+               |GROUP BY c1
+               |$having
+               |$sort
+               |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
+        }
+      }
 
-      assert(sql(
-        """
-          |WITH custom_cte AS (
-          |SELECT * FROM t1
-          |)
-          |
-          |SELECT * FROM custom_cte
-          |LIMIT 1
-          |""".stripMargin).queryExecution
-        .analyzed.asInstanceOf[GlobalLimit].maxRows.contains(1))
+      limitAndExpecteds.foreach{ case LimitAndExpected(limit, expected) =>
+        havingConditions.foreach { having =>
+          sorts.foreach { sort =>
+            assert(sql(
+              s"""
+                 |SELECT c1, COUNT(*) as cnt
+                 |FROM t1
+                 |GROUP BY c1
+                 |$having
+                 |$sort
+                 |LIMIT $limit
+                 |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
+          }
+        }
+      }
+    }
+  }
 
-      assert(sql(
-        """
-          |WITH custom_cte AS (
-          |SELECT * FROM t1
-          |)
-          |
-          |SELECT * FROM custom_cte
-          |LIMIT 11
-          |""".stripMargin).queryExecution
-        .analyzed.asInstanceOf[GlobalLimit].maxRows.contains(10))
+  test("test watchdog: SELECT with CTE forceMaxOutputRows") {
 
-      assert(!sql(
-        """
-          |WITH custom_cte AS (
-          |SELECT * FROM t1
-          |)
-          |
-          |SELECT COUNT(*) FROM custom_cte
-          |""".stripMargin).queryExecution
-        .analyzed.isInstanceOf[GlobalLimit])
+    withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
+
+      val sorts = List("", "ORDER BY c1", "ORDER BY c2")
+
+      sorts.foreach { sort =>
+        assert(sql(
+          s"""
+             |WITH custom_cte AS (
+             |SELECT * FROM t1
+             |)
+             |SELECT *
+             |FROM custom_cte
+             |$sort
+             |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
+      }
+
+      limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
+        sorts.foreach { sort =>
+          assert(sql(
+            s"""
+               |WITH custom_cte AS (
+               |SELECT * FROM t1
+               |)
+               |SELECT *
+               |FROM custom_cte
+               |$sort
+               |LIMIT $limit
+               |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
+        }
+      }
+    }
+  }
+
+  test("test watchdog: SELECT AGGREGATE WITH CTE forceMaxOutputRows") {
 
-      assert(sql(
+    withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
+
+      assert(!sql(
         """
           |WITH custom_cte AS (
           |SELECT * FROM t1
           |)
           |
-          |SELECT c1, COUNT(*)
+          |SELECT COUNT(*)
           |FROM custom_cte
-          |GROUP BY c1
           |""".stripMargin).queryExecution
         .analyzed.isInstanceOf[GlobalLimit])
 
-      assert(sql(
-        """
-          |WITH custom_cte AS (
-          |SELECT * FROM t1
-          |)
-          |
-          |SELECT c1, COUNT(*)
-          |FROM custom_cte
-          |GROUP BY c1
-          |LIMIT 11
-          |""".stripMargin).queryExecution
-        .analyzed.asInstanceOf[GlobalLimit].maxRows.contains(10))
+      val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
+      val havingConditions = List("", "HAVING cnt > 1")
+
+      havingConditions.foreach { having =>
+        sorts.foreach { sort =>
+          assert(sql(
+            s"""
+               |WITH custom_cte AS (
+               |SELECT * FROM t1
+               |)
+               |
+               |SELECT c1, COUNT(*) as cnt
+               |FROM custom_cte
+               |GROUP BY c1
+               |$having
+               |$sort
+               |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
+        }
+      }
+
+      limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
+        havingConditions.foreach { having =>
+          sorts.foreach { sort =>
+            assert(sql(
+              s"""
+                 |WITH custom_cte AS (
+                 |SELECT * FROM t1
+                 |)
+                 |
+                 |SELECT c1, COUNT(*) as cnt
+                 |FROM custom_cte
+                 |GROUP BY c1
+                 |$having
+                 |$sort
+                 |LIMIT $limit
+                 |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
+          }
+        }
+      }
+    }
+  }
+
+  test("test watchdog: UNION Statement for forceMaxOutputRows") {
+
+    withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
+
+      List("", "ALL").foreach { x =>
+        assert(sql(
+          s"""
+             |SELECT c1, c2 FROM t1
+             |UNION $x
+             |SELECT c1, c2 FROM t2
+             |UNION $x
+             |SELECT c1, c2 FROM t3
+             |""".stripMargin)
+          .queryExecution.analyzed.isInstanceOf[GlobalLimit])
+      }
+
+      val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
+      val havingConditions = List("", "HAVING cnt > 1")
+
+      List("", "ALL").foreach { x =>
+        havingConditions.foreach{ having =>
+          sorts.foreach { sort =>
+            assert(sql(
+              s"""
+                 |SELECT c1, count(c2) as cnt
+                 |FROM t1
+                 |GROUP BY c1
+                 |$having
+                 |UNION $x
+                 |SELECT c1, COUNT(c2) as cnt
+                 |FROM t2
+                 |GROUP BY c1
+                 |$having
+                 |UNION $x
+                 |SELECT c1, COUNT(c2) as cnt
+                 |FROM t3
+                 |GROUP BY c1
+                 |$having
+                 |$sort
+                 |""".stripMargin)
+              .queryExecution.analyzed.isInstanceOf[GlobalLimit])
+          }
+        }
+      }
+
+      limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
+        assert(sql(
+          s"""
+             |SELECT c1, c2 FROM t1
+             |UNION
+             |SELECT c1, c2 FROM t2
+             |UNION
+             |SELECT c1, c2 FROM t3
+             |LIMIT $limit
+             |""".stripMargin)
+          .queryExecution.analyzed.maxRows.contains(expected))
+      }
     }
   }
 }