Posted to notifications@kyuubi.apache.org by GitBox <gi...@apache.org> on 2022/03/30 06:39:44 UTC

[GitHub] [incubator-kyuubi] yanghua commented on a change in pull request #2160: [KYUUBI #1451] Introduce Kyuubi Spark AuthZ Module with column-level fine-grained authorization

yanghua commented on a change in pull request #2160:
URL: https://github.com/apache/incubator-kyuubi/pull/2160#discussion_r838170314



##########
File path: extensions/spark/kyuubi-spark-authz/src/main/scala/org/apache/kyuubi/plugin/spark/authz/PrivilegesBuilder.scala
##########
@@ -0,0 +1,496 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.plugin.spark.authz
+
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success, Try}
+
+import org.apache.spark.SPARK_VERSION
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.catalog.CatalogTable
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{Command, Filter, Join, LogicalPlan, Project}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.types.StructField
+
+import org.apache.kyuubi.plugin.spark.authz.PrivilegeObjectActionType._
+import org.apache.kyuubi.plugin.spark.authz.PrivilegeObjectType._
+
+object PrivilegesBuilder {
+
+  private val versionParts = SPARK_VERSION.split('.')
+  private val majorVersion: Int = versionParts.head.toInt
+  private val minorVersion: Int = versionParts(1).toInt
+
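+  // Backtick-quote an identifier part unless it is a plain non-numeric word;
+  // embedded backticks are escaped by doubling them.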
+  private def quoteIfNeeded(part: String): String = {
+    if (part.matches("[a-zA-Z0-9_]+") && !part.matches("\\d+")) {
+      part
+    } else {
+      s"`${part.replace("`", "``")}`"
+    }
+  }
+
+  private def quote(parts: Seq[String]): String = {
+    parts.map(quoteIfNeeded).mkString(".")
+  }
+
+  /**
+   * Get the value of the named field from an object via reflection
+   *
+   * FIXME: error handling needs improvement here
+   */
+  private def getFieldVal[T](o: Any, name: String): T = {
+    Try {
+      val field = o.getClass.getDeclaredField(name)
+      field.setAccessible(true)
+      field.get(o)
+    } match {
+      case Success(value) => value.asInstanceOf[T]
+      case Failure(e) =>
+        val candidates = o.getClass.getDeclaredFields.map(_.getName).mkString("[", ",", "]")
+        throw new RuntimeException(s"$name not in $candidates", e)
+    }
+  }
+
+  private def databasePrivileges(db: String): PrivilegeObject = {
+    PrivilegeObject(DATABASE, PrivilegeObjectActionType.OTHER, db, db)
+  }
+
+  private def tablePrivileges(
+      table: TableIdentifier,
+      columns: Seq[String] = Nil,
+      actionType: PrivilegeObjectActionType = PrivilegeObjectActionType.OTHER): PrivilegeObject = {
+    PrivilegeObject(TABLE_OR_VIEW, actionType, table.database.orNull, table.table, columns)
+  }
+
+  private def functionPrivileges(
+      db: String,
+      functionName: String): PrivilegeObject = {
+    PrivilegeObject(FUNCTION, PrivilegeObjectActionType.OTHER, db, functionName)
+  }
+
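+  // Collect the leaf NamedExpressions, e.g. attribute references, from an
+  // expression tree.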
+  private def collectLeaves(expr: Expression): Seq[NamedExpression] = {
+    expr.collect { case p: NamedExpression if p.children.isEmpty => p }
+  }
+
+  /**
+   * Build PrivilegeObjects from a Spark LogicalPlan
+   *
+   * @param plan a Spark LogicalPlan used to generate PrivilegeObjects
+   * @param privilegeObjects the input or output privilege object list
+   * @param projectionList the projection list after column pruning
+   */
+  private def buildQuery(
+      plan: LogicalPlan,
+      privilegeObjects: ArrayBuffer[PrivilegeObject],
+      projectionList: Seq[NamedExpression] = Nil): Unit = {
+
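+    // Record table privileges: all columns when no projection has been
+    // collected yet, otherwise only the projected columns that belong to
+    // this relation's output.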
+    def mergeProjection(table: CatalogTable, plan: LogicalPlan): Unit = {
+      if (projectionList.isEmpty) {
+        privilegeObjects += tablePrivileges(
+          table.identifier,
+          table.schema.fieldNames)
+      } else {
+        val cols = projectionList.flatMap(collectLeaves)
+          .filter(plan.outputSet.contains).map(_.name).distinct
+        privilegeObjects += tablePrivileges(table.identifier, cols)
+      }
+    }
+
+    plan match {
+      case p: Project => buildQuery(p.child, privilegeObjects, p.projectList)
+
+      case j: Join =>
+        val cols =
+          projectionList ++ j.condition.map(expr => collectLeaves(expr)).getOrElse(Nil)
+        buildQuery(j.left, privilegeObjects, cols)
+        buildQuery(j.right, privilegeObjects, cols)
+
+      case f: Filter =>
+        val cols = projectionList ++ collectLeaves(f.condition)
+        buildQuery(f.child, privilegeObjects, cols)
+
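+      // The relations below are matched by nodeName and their fields are read
+      // reflectively, presumably to tolerate differences across Spark versions.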
+      case h if h.nodeName == "HiveTableRelation" =>
+        mergeProjection(getFieldVal[CatalogTable](h, "tableMeta"), h)
+
+      case l if l.nodeName == "LogicalRelation" =>
+        getFieldVal[Option[CatalogTable]](l, "catalogTable").foreach { t =>
+          mergeProjection(t, plan)
+        }
+
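+      // The relation is unresolved, so recover its identifier from the raw
+      // multi-part table name.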
+      case u if u.nodeName == "UnresolvedRelation" =>
+        val tableNameM = u.getClass.getMethod("tableName")
+        val parts = tableNameM.invoke(u).asInstanceOf[String].split("\\.")
+        val db = quote(parts.init)
+        privilegeObjects += tablePrivileges(TableIdentifier(parts.last, Some(db)))
+
+      case p =>
+        for (child <- p.children) {
+          buildQuery(child, privilegeObjects, projectionList)
+        }
+    }
+  }
+
+  /**
+   * Build PrivilegeObjects from a Spark command plan
+   *
+   * @param plan a Spark LogicalPlan (command) used to generate PrivilegeObjects
+   * @param inputObjs the input privilege object list
+   * @param outputObjs the output privilege object list
+   */
+  private def buildCommand(
+      plan: LogicalPlan,
+      inputObjs: ArrayBuffer[PrivilegeObject],
+      outputObjs: ArrayBuffer[PrivilegeObject]): Unit = {
+
+    def getPlanField[T](field: String): T = {
+      getFieldVal[T](plan, field)
+    }
+
+    def getTableName: TableIdentifier = {
+      getPlanField[TableIdentifier]("tableName")
+    }
+
+    def getTableIdent: TableIdentifier = {
+      getPlanField[TableIdentifier]("tableIdent")
+    }
+
+    def getMultipartIdentifier: TableIdentifier = {
+      val multipartIdentifier = getPlanField[Seq[String]]("multipartIdentifier")
+      assert(multipartIdentifier.nonEmpty)
+      val table = multipartIdentifier.last
+      val db = Some(quote(multipartIdentifier.init))
+      TableIdentifier(table, db)
+    }
+
+    def getQuery: LogicalPlan = {
+      getPlanField[LogicalPlan]("query")
+    }
+
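+    // Commands are dispatched on nodeName; their fields are read through the
+    // reflective helpers above.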
+    plan.nodeName match {
+      case "AlterDatabasePropertiesCommand" |
+          "AlterDatabaseSetLocationCommand" |
+          "CreateDatabaseCommand" |
+          "DropDatabaseCommand" =>
+        val database = getPlanField[String]("databaseName")
+        outputObjs += databasePrivileges(database)
+
+      case "AlterTableAddColumnsCommand" =>
+        val table = getPlanField[TableIdentifier]("table")
+        val cols = getPlanField[Seq[StructField]]("colsToAdd").map(_.name)
+        outputObjs += tablePrivileges(table, cols)
+
+      case "AlterTableAddPartitionCommand" =>
+        val table = getTableName
+        val cols = getPlanField[Seq[(TablePartitionSpec, Option[String])]]("partitionSpecsAndLocs")
+          .flatMap(_._1.keySet).distinct
+        outputObjs += tablePrivileges(table, cols)
+
+      case "AlterTableChangeColumnCommand" =>
+        val table = getTableName
+        val cols = getPlanField[String]("columnName") :: Nil
+        outputObjs += tablePrivileges(table, cols)
+
+      case "AlterTableDropPartitionCommand" =>
+        val table = getTableName
+        val cols = getPlanField[Seq[TablePartitionSpec]]("specs").flatMap(_.keySet).distinct
+        outputObjs += tablePrivileges(table, cols)
+
+      case "AlterTableRenameCommand" =>
+        val oldTable = getPlanField[TableIdentifier]("oldName")
+        val newTable = getPlanField[TableIdentifier]("newName")
+        outputObjs += tablePrivileges(oldTable, actionType = PrivilegeObjectActionType.DELETE)
+        outputObjs += tablePrivileges(newTable)
+
+      case "AlterTableRenamePartitionCommand" =>
+        val table = getTableName
+        val cols = getPlanField[TablePartitionSpec]("oldPartition").keySet.toSeq
+        outputObjs += tablePrivileges(table, cols)
+
+      case "AlterTableSerDePropertiesCommand" =>
+        val table = getTableName
+        val cols = getPlanField[Option[TablePartitionSpec]]("partSpec")
+          .toSeq.flatMap(_.keySet)
+        outputObjs += tablePrivileges(table, cols)
+
+      case "AlterTableSetLocationCommand" =>
+        val table = getTableName
+        val cols = getPlanField[Option[TablePartitionSpec]]("partitionSpec")
+          .toSeq.flatMap(_.keySet)
+        outputObjs += tablePrivileges(table, cols)
+
+      case "AlterTableSetPropertiesCommand" |
+          "AlterTableUnsetPropertiesCommand" =>
+        val table = getTableName
+        outputObjs += tablePrivileges(table)
+
+      case "AlterViewAsCommand" =>
+        val view = getPlanField[TableIdentifier]("name")
+        outputObjs += tablePrivileges(view)
+        buildQuery(getQuery, inputObjs)
+
+      case "AlterViewAs" =>
+
+      case "AnalyzeColumnCommand" =>
+        val table = getTableIdent
+        val cols = getPlanField[Option[Seq[String]]]("columnNames").getOrElse(Nil)
+        inputObjs += tablePrivileges(table, cols)
+
+      case "AnalyzePartitionCommand" =>
+        val table = getTableIdent
+        val cols = getPlanField[Map[String, Option[String]]]("partitionSpec")
+          .keySet.toSeq
+        inputObjs += tablePrivileges(table, cols)
+
+      case "AnalyzeTableCommand" |
+          "RefreshTableCommand" |
+          "RefreshTable" =>
+        inputObjs += tablePrivileges(getTableIdent)
+
+      case "AnalyzeTablesCommand" |
+          "ShowTablesCommand" =>
+        val db = getPlanField[Option[String]]("databaseName")
+        if (db.nonEmpty) {
+          inputObjs += databasePrivileges(db.get)
+        }
+
+      case "CacheTable" =>
+        // >= 3.2
+        outputObjs += tablePrivileges(getMultipartIdentifier)
+        val query = getPlanField[LogicalPlan]("table") // table to cache
+        buildQuery(query, inputObjs)
+
+      case "CacheTableCommand" =>
+        if (majorVersion == 3 && minorVersion == 1) {

Review comment:
       How about implementing different logic per Spark version? That might read better.
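
       Something like the following, an untested sketch where `handlerFor` and the
       hardcoded version strings are only illustrative, is what I have in mind:

           object VersionDispatchSketch {
             // In the plugin this would parse org.apache.spark.SPARK_VERSION
             // instead of taking the version string as a parameter.
             def handlerFor(sparkVersion: String): String => Unit = {
               val Array(major, minor) = sparkVersion.split('.').take(2).map(_.toInt)
               if (major == 3 && minor >= 2) {
                 cmd => println(s"handle $cmd with the 3.2+ logic")
               } else {
                 cmd => println(s"handle $cmd with the pre-3.2 logic")
               }
             }

             def main(args: Array[String]): Unit = {
               handlerFor("3.1.2")("CacheTableCommand")
             }
           }

       Resolving the version once and dispatching to a per-version handler would
       keep the `majorVersion`/`minorVersion` checks out of the individual cases.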




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: notifications-unsubscribe@kyuubi.apache.org
For additional commands, e-mail: notifications-help@kyuubi.apache.org