Posted to dev@submarine.apache.org by li...@apache.org on 2020/03/24 05:36:43 UTC

[submarine] branch master updated: SUBMARINE-449. Fix escaped correlated subquery cases for spark security

This is an automated email from the ASF dual-hosted git repository.

liuxun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git


The following commit(s) were added to refs/heads/master by this push:
     new 9180e58  SUBMARINE-449. Fix escaped correlated subquery cases for spark security
9180e58 is described below

commit 9180e5839e663d272f9ecc13518dbc2c98245681
Author: Kent Yao <ya...@hotmail.com>
AuthorDate: Mon Mar 23 18:02:08 2020 +0800

    SUBMARINE-449. Fix escaped correlated subquery cases for spark security
    
    ### What is this PR for?
    
    Handle row filtering for correlated subqueries, so that relations reached only through subquery expressions are also row-filtered (see the illustrative sketch after the diffstat below).
    
    ### What type of PR is it?
    
    Bug fix
    
    ### Todos
    * [ ] - Task
    
    ### What is the Jira issue?
    https://issues.apache.org/jira/browse/SUBMARINE-449
    ### How should this be tested?
    Add CTE, uncorrelated, and correlated subquery test cases.
    ### Screenshots (if appropriate)
    
    ### Questions:
    * Do the license files need an update? No
    * Are there breaking changes for older versions? No
    * Does this need documentation? No
    
    Author: Kent Yao <ya...@hotmail.com>
    
    Closes #245 from yaooqinn/SUBMARINE-449 and squashes the following commits:
    
    cbac516 [Kent Yao] fix tests
    5527cbd [Kent Yao] SUBMARINE-449. Fix escaped correlated subquery cases for spark security
---
 .../optimizer/SubmarineRowFilterExtension.scala    | 33 ++++++++-----
 .../spark/security/DataMaskingSQLTest.scala        | 52 ++++++++++++++++++++
 .../spark/security/RowFilterSQLTest.scala          | 57 +++++++++++++++++++++-
 3 files changed, 130 insertions(+), 12 deletions(-)
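
Editor's note on the core of the fix, before the diff: previously `doFiltering` treated a plan as already handled if it contained *any* `SubmarineRowFilter` marker, so when a correlated subquery's relation was rewritten into the same plan, it could escape filtering. The new `isFixed` guard requires the marker count to cover every filterable relation. Below is a toy sketch of that guard change; `Node`, `Relation`, `RowFilterMark`, and `Join` are illustrative stand-ins, not Spark or Submarine types.

```scala
// Toy plan tree (illustrative only, NOT Spark's LogicalPlan).
sealed trait Node {
  def children: Seq[Node]
  def leaves: Seq[Node] =
    if (children.isEmpty) Seq(this) else children.flatMap(_.leaves)
  def collect[A](pf: PartialFunction[Node, A]): Seq[A] =
    (if (pf.isDefinedAt(this)) Seq(pf(this)) else Seq.empty) ++
      children.flatMap(_.collect(pf))
}
case class Relation(name: String) extends Node { val children: Seq[Node] = Nil }
case class RowFilterMark(child: Node) extends Node { val children: Seq[Node] = Seq(child) }
case class Join(left: Node, right: Node) extends Node { val children: Seq[Node] = Seq(left, right) }

// Old guard: any marker anywhere means "already handled".
def isFixedOld(plan: Node): Boolean =
  plan.collect { case m: RowFilterMark => m }.nonEmpty

// New guard: handled only once markers cover every relation in the plan.
def isFixedNew(plan: Node): Boolean = {
  val markNum = plan.collect { case m: RowFilterMark => m }.size
  markNum >= plan.leaves.count(_.isInstanceOf[Relation])
}

// One relation already filtered; a second one, pulled in by a rewritten
// correlated subquery, not yet filtered:
val plan = Join(RowFilterMark(Relation("rangertbl1")), Relation("rangertbl2"))
assert(isFixedOld(plan))   // old check stops early; rangertbl2 escapes filtering
assert(!isFixedNew(plan))  // new check lets the rule keep applying filters
```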

diff --git a/submarine-security/spark-security/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SubmarineRowFilterExtension.scala b/submarine-security/spark-security/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SubmarineRowFilterExtension.scala
index 7641368..be3c031 100644
--- a/submarine-security/spark-security/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SubmarineRowFilterExtension.scala
+++ b/submarine-security/spark-security/src/main/scala/org/apache/spark/sql/catalyst/optimizer/SubmarineRowFilterExtension.scala
@@ -47,7 +47,7 @@ case class SubmarineRowFilterExtension(spark: SparkSession) extends Rule[Logical
    * @return A new Spark [[LogicalPlan]] with specified row filter expressions
    */
   private def applyingRowFilterExpr(plan: LogicalPlan, table: CatalogTable): LogicalPlan = {
-    val auditHandler = new RangerSparkAuditHandler()
+    val auditHandler = RangerSparkAuditHandler()
     try {
       val identifier = table.identifier
       val resource =
@@ -79,19 +79,30 @@ case class SubmarineRowFilterExtension(spark: SparkSession) extends Rule[Logical
     result != null && result.isRowFilterEnabled && StringUtils.isNotEmpty(result.getFilterExpr)
   }
 
+  private def getPlanWithTables(plan: LogicalPlan): Map[LogicalPlan, CatalogTable] = {
+    plan.collectLeaves().map {
+      case h if h.nodeName == "HiveTableRelation" =>
+        h -> getFieldVal(h, "tableMeta").asInstanceOf[CatalogTable]
+      case m if m.nodeName == "MetastoreRelation" =>
+        m -> getFieldVal(m, "catalogTable").asInstanceOf[CatalogTable]
+      case l: LogicalRelation if l.catalogTable.isDefined =>
+        l -> l.catalogTable.get
+      case _ => null
+    }.filter(_ != null).toMap
+  }
+
+  private def isFixed(plan: LogicalPlan): Boolean = {
+    val markNum = plan.collect { case _: SubmarineRowFilter => true }.size
+    markNum >= getPlanWithTables(plan).size
+  }
   private def doFiltering(plan: LogicalPlan): LogicalPlan = plan match {
     case rf: SubmarineRowFilter => rf
-    case fixed if fixed.find(_.isInstanceOf[SubmarineRowFilter]).nonEmpty => fixed
+    case plan if isFixed(plan) => plan
     case _ =>
-      val plansWithTables = plan.collectLeaves().map {
-        case h if h.nodeName == "HiveTableRelation" =>
-          (h, getFieldVal(h, "tableMeta").asInstanceOf[CatalogTable])
-        case m if m.nodeName == "MetastoreRelation" =>
-          (m, getFieldVal(m, "catalogTable").asInstanceOf[CatalogTable])
-        case l: LogicalRelation if l.catalogTable.isDefined =>
-          (l, l.catalogTable.get)
-        case _ => null
-      }.filter(_ != null).map(lt => (lt._1, applyingRowFilterExpr(lt._1, lt._2))).toMap
+      val plansWithTables = getPlanWithTables(plan)
+        .map { case (plan, table) =>
+          (plan, applyingRowFilterExpr(plan, table))
+        }
 
       plan transformUp {
         case p => plansWithTables.getOrElse(p, p)
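
The tail of the hunk above shows the rewrite step the rule ends with: build a map from each leaf relation to its row-filtered replacement, then swap them in bottom-up. As a standalone sketch of that pattern against Spark's Catalyst API (the `rewrite`/`replacements` names are ours, not the project's; the map is what `getPlanWithTables` plus `applyingRowFilterExpr` produce above):

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Replace each matched leaf with its filtered version in one bottom-up
// pass; nodes without a replacement pass through unchanged.
def rewrite(plan: LogicalPlan, replacements: Map[LogicalPlan, LogicalPlan]): LogicalPlan =
  plan.transformUp { case p => replacements.getOrElse(p, p) }
```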
diff --git a/submarine-security/spark-security/src/test/scala/org/apache/submarine/spark/security/DataMaskingSQLTest.scala b/submarine-security/spark-security/src/test/scala/org/apache/submarine/spark/security/DataMaskingSQLTest.scala
index f8b6dff..f228845 100644
--- a/submarine-security/spark-security/src/test/scala/org/apache/submarine/spark/security/DataMaskingSQLTest.scala
+++ b/submarine-security/spark-security/src/test/scala/org/apache/submarine/spark/security/DataMaskingSQLTest.scala
@@ -208,4 +208,56 @@ case class DataMaskingSQLTest() extends FunSuite with BeforeAndAfterAll {
       assert(row.getString(1) === "xxx_277", "value shows last 4 characters")
     }
   }
+
+  test("MASK_SHOW_LAST_4 with uncorrelated subquery") {
+    val statement =
+      s"""
+         |select
+         | *
+         |from default.rangertbl5 outer
+         |where value in (select value from default.rangertbl4 where value = 'val_277')
+         |""".stripMargin
+    withUser("bob") {
+      val df = sql(statement)
+      println(df.queryExecution.optimizedPlan)
+      val row = df.take(1)(0)
+      assert(row.getString(1) === "xxx_277", "value shows last 4 characters")
+    }
+  }
+
+  test("MASK_SHOW_LAST_4 with correlated subquery") {
+    val statement =
+      s"""
+         |select
+         | *
+         |from default.rangertbl5 outer
+         |where key =
+         | (select max(key) from default.rangertbl4 where value = 'val_277' and value = outer.value)
+         |""".stripMargin
+    withUser("bob") {
+      val df = sql(statement)
+      println(df.queryExecution.optimizedPlan)
+      val row = df.take(1)(0)
+      assert(row.getString(1) === "xxx_277", "value shows last 4 characters")
+    }
+  }
+
+  test("CTE") {
+    val statement =
+      s"""
+         |with myCTE as
+         |(select
+         | *
+         |from default.rangertbl5 where value = 'val_277')
+         |select t1.value, t2.value from myCTE t1 join myCTE t2 on t1.key = t2.key
+         |
+         |""".stripMargin
+    withUser("bob") {
+      val df = sql(statement)
+      println(df.queryExecution.optimizedPlan)
+      val row = df.take(1)(0)
+      assert(row.getString(0) === "xxx_277", "value shows last 4 characters")
+      assert(row.getString(1) === "xxx_277", "value shows last 4 characters")
+    }
+  }
 }
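
The assertions above pin down the masking semantics the new subquery and CTE paths must preserve: under a MASK_SHOW_LAST_4 policy, only the last four characters of `value` survive. A self-contained sketch of that transform, assuming exactly the behavior the tests assert (this helper is illustrative, not Ranger's masker):

```scala
// Mask all but the last 4 characters, matching the "val_277" -> "xxx_277"
// expectation asserted in the tests above.
def maskShowLast4(s: String): String =
  if (s.length <= 4) s else "x" * (s.length - 4) + s.takeRight(4)

assert(maskShowLast4("val_277") == "xxx_277")
```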
diff --git a/submarine-security/spark-security/src/test/scala/org/apache/submarine/spark/security/RowFilterSQLTest.scala b/submarine-security/spark-security/src/test/scala/org/apache/submarine/spark/security/RowFilterSQLTest.scala
index bd3a2ad..c6eb47e 100644
--- a/submarine-security/spark-security/src/test/scala/org/apache/submarine/spark/security/RowFilterSQLTest.scala
+++ b/submarine-security/spark-security/src/test/scala/org/apache/submarine/spark/security/RowFilterSQLTest.scala
@@ -20,7 +20,7 @@
 package org.apache.submarine.spark.security
 
 import org.apache.spark.sql.SubmarineSparkUtils._
-import org.apache.spark.sql.catalyst.plans.logical.SubmarineRowFilter
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, SubmarineRowFilter}
 import org.apache.spark.sql.hive.test.TestHive
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
@@ -216,4 +216,59 @@ class RowFilterSQLTest extends FunSuite with BeforeAndAfterAll {
       assert(rows.forall { r => val x = r.getInt(0); x > 1 && x < 10 || x == 500 })
     }
   }
+
+  test("applying filters with uncorrelated subquery") {
+    val statement =
+      s"""
+         |select
+         | *
+         |from default.rangertbl1 outer
+         |where value in (select value from default.rangertbl2)
+         |""".stripMargin
+    withUser("bob") {
+      val df = sql(statement)
+      val plan = df.queryExecution.optimizedPlan
+      println(plan)
+      assert(plan.collect { case _: Filter => true }.size === 2, "tbl 1 and 2 have 2 filters")
+      val row = df.take(1)(0)
+      assert(row.getInt(0) === 0, "tbl 1 and 2 have 2 filters")
+    }
+  }
+
+  test("applying filters with correlated subquery") {
+    val statement =
+      s"""
+         |select
+         | *
+         |from default.rangertbl1 outer
+         |where key =
+         | (select max(key) from default.rangertbl2 where value = outer.value)
+         |""".stripMargin
+    withUser("bob") {
+      val df = sql(statement)
+      val plan = df.queryExecution.optimizedPlan
+      println(plan)
+      assert(plan.collectLeaves().size <= plan.collect { case _: SubmarineRowFilter => true}.size)
+      val row = df.take(1)(0)
+      assert(row.getString(1) === "val_0", "tbl 1 and 2 have 2 filters")
+    }
+  }
+
+  test("CTE") {
+    val statement =
+      s"""
+         |with myCTE as
+         |(select
+         | *
+         |from default.rangertbl1)
+         |select t1.value, t2.value from myCTE t1 join myCTE t2 on t1.key = t2.key
+         |
+         |""".stripMargin
+    withUser("bob") {
+      val df = sql(statement)
+      println(df.queryExecution.optimizedPlan)
+      val row = df.take(1)(0)
+      assert(row.getString(0) === "val_0", "rangertbl1 has an internal expression key=0")
+    }
+  }
 }
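
The correlated-subquery test above encodes the fix's invariant directly: after optimization, there should be at least as many `SubmarineRowFilter` markers as leaf relations, i.e. no relation escaped the rule. Extracted as a reusable helper (a sketch; the project's test inlines this check):

```scala
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubmarineRowFilter}

// True when every leaf relation is covered by a row-filter marker.
def fullyFiltered(plan: LogicalPlan): Boolean =
  plan.collectLeaves().size <= plan.collect { case f: SubmarineRowFilter => f }.size
```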

