You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@trafficcontrol.apache.org by ro...@apache.org on 2018/06/01 16:07:04 UTC

[incubator-trafficcontrol] 04/05: refactor AddTenancyCheck into dbhelpers where it belongs

This is an automated email from the ASF dual-hosted git repository.

rob pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-trafficcontrol.git

commit 3967c650e27b86e426bfd95463a7c471624f4da2
Author: Dylan Volz <Dy...@comcast.com>
AuthorDate: Thu May 31 21:13:51 2018 -0600

    refactor AddTenancyCheck into dbhelpers where it belongs
---
 .../traffic_ops_golang/dbhelpers/db_helpers.go     | 12 +++++++++
 .../deliveryservice/deliveryservicesv13.go         | 25 ++---------------
 traffic_ops/traffic_ops_golang/tenant/tenancy.go   | 31 ++++++++++++++++++++++
 3 files changed, 45 insertions(+), 23 deletions(-)

diff --git a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
index c396d74..8fed9e9 100644
--- a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
+++ b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
@@ -144,3 +144,31 @@ func FinishTxX(tx *sqlx.Tx, commit *bool) {
 	}
 	tx.Commit()
 }
+
+// AddTenancyCheck restricts an sqlx named-query WHERE clause to rows whose
+// tenantColumnName is one of tenantIDs.
+//
+// If where is empty a new clause is started from BaseWhere; otherwise the
+// restriction is appended with AND. The IDs are bound in queryValues under the
+// "accessibleTenants" key (a nil queryValues is replaced with a new map rather
+// than panicking). The updated clause and map are returned.
+//
+// tenantColumnName is concatenated into the SQL text verbatim, so it must be a
+// fixed column reference such as "ds.tenant_id", never user input.
+func AddTenancyCheck(where string, queryValues map[string]interface{}, tenantColumnName string, tenantIDs []int) (string, map[string]interface{}) {
+	// In sqlx named queries "::" escapes a literal ":", so "::::" reaches
+	// PostgreSQL as the "::" cast operator.
+	clause := tenantColumnName + " = ANY((:accessibleTenants)::::bigint[])"
+	if where == "" {
+		where = BaseWhere + " " + clause
+	} else {
+		where += " AND " + clause
+	}
+
+	if queryValues == nil {
+		queryValues = map[string]interface{}{}
+	}
+	queryValues["accessibleTenants"] = pq.Array(tenantIDs)
+
+	return where, queryValues
+}
diff --git a/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv13.go b/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv13.go
index 8b1b8b7..3d74e69 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv13.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv13.go
@@ -615,27 +615,6 @@ func filterAuthorized(dses []tc.DeliveryServiceNullableV13, user auth.CurrentUse
 	return newDSes, nil
 }
 
-func addTenancyCheck(where string, queryValues map[string]interface{}, user auth.CurrentUser, db *sqlx.DB) (string, map[string]interface{}, error) {
-	if where == "" {
-		where = dbhelpers.BaseWhere + " ds.tenant_id = ANY((:accessibleTenants)::::bigint[])"
-	} else {
-		where += " AND ds.tenant_id = ANY((:accessibleTenants)::::bigint[])"
-	}
-
-	tenants, err := tenant.GetUserTenantList(user, db)
-	if err != nil {
-		return "", queryValues, err
-	}
-
-	tenantIDs := make([]int, len(tenants))
-	for i, tenant := range tenants {
-		tenantIDs[i] = tenant.ID
-	}
-	queryValues["accessibleTenants"] = pq.Array(tenantIDs)
-
-	return where, queryValues, nil
-}
-
 func readGetDeliveryServices(params map[string]string, db *sqlx.DB, user auth.CurrentUser) ([]tc.DeliveryServiceNullableV13, []error, tc.ApiErrorType) {
 	if strings.HasSuffix(params["id"], ".json") {
 		params["id"] = params["id"][:len(params["id"])-len(".json")]
@@ -654,12 +633,12 @@ func readGetDeliveryServices(params map[string]string, db *sqlx.DB, user auth.Cu
 
 	if tenant.IsTenancyEnabled(db) {
 		log.Debugln("Tenancy is enabled")
-		var err error
-		where, queryValues, err = addTenancyCheck(where, queryValues, user, db)
+		tenantIDs, err := tenant.GetUserTenantIDList(user, db)
 		if err != nil {
 			log.Errorln("received error querying for user's tenants: " + err.Error())
 			return nil, []error{tc.DBError}, tc.SystemError
 		}
+		where, queryValues = dbhelpers.AddTenancyCheck(where, queryValues, "ds.tenant_id", tenantIDs)
 	}
 	query := selectQuery() + where + orderBy
 
diff --git a/traffic_ops/traffic_ops_golang/tenant/tenancy.go b/traffic_ops/traffic_ops_golang/tenant/tenancy.go
index b802483..b3298cd 100644
--- a/traffic_ops/traffic_ops_golang/tenant/tenancy.go
+++ b/traffic_ops/traffic_ops_golang/tenant/tenancy.go
@@ -122,6 +122,49 @@ func GetUserTenantList(user auth.CurrentUser, db *sqlx.DB) ([]Tenant, error) {
 	return tenants, nil
 }
 
+// GetUserTenantIDList returns the IDs of the user's own tenant and of every
+// tenant beneath it in the parent_id hierarchy, i.e. the tenant IDs the user
+// has access to.
+//
+// NOTE: This function does not consult the use_tenancy parameter (see
+// IsTenancyEnabled). If its result is used to enforce tenancy, the caller must
+// check that parameter first. The method IsResourceAuthorizedToUser checks the
+// use_tenancy parameter and should be used for this purpose in most cases.
+func GetUserTenantIDList(user auth.CurrentUser, db *sqlx.DB) ([]int, error) {
+	// Recursive CTE: start from the user's tenant and repeatedly add tenants
+	// whose parent_id is already in the result. UNION (rather than UNION ALL)
+	// discards duplicate rows, which also ends the recursion if the tenant
+	// hierarchy ever contains a cycle.
+	query := `WITH RECURSIVE q AS (SELECT id, name, active, parent_id FROM tenant WHERE id = $1
+	UNION SELECT t.id, t.name, t.active, t.parent_id  FROM tenant t JOIN q ON q.id = t.parent_id)
+	SELECT id FROM q;`
+
+	log.Debugln("\nQuery: ", query)
+
+	rows, err := db.Query(query, user.TenantID)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	tenantIDs := []int{}
+	for rows.Next() {
+		var tenantID int
+		if err := rows.Scan(&tenantID); err != nil {
+			return nil, err
+		}
+		tenantIDs = append(tenantIDs, tenantID)
+	}
+	// rows.Next reports false both at the end of the result set and on an
+	// iteration error; check Err so a mid-stream failure is not returned as a
+	// silently truncated list.
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+
+	return tenantIDs, nil
+}
+
 // IsTenancyEnabled returns true if tenancy is enabled or false otherwise
 func IsTenancyEnabled(db *sqlx.DB) bool {
 	query := `SELECT COALESCE(value::boolean,FALSE) AS value FROM parameter WHERE name = 'use_tenancy' AND config_file = 'global' UNION ALL SELECT FALSE FETCH FIRST 1 ROW ONLY`

-- 
To stop receiving notification emails like this one, please contact
rob@apache.org.