You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@trafficcontrol.apache.org by GitBox <gi...@apache.org> on 2018/09/27 17:42:13 UTC

[GitHub] dg4prez closed pull request #2882: Optimize tenancy check in Origins read query

dg4prez closed pull request #2882: Optimize tenancy check in Origins read query
URL: https://github.com/apache/trafficcontrol/pull/2882
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as GitHub would not otherwise display it):

diff --git a/traffic_ops/traffic_ops_golang/origin/origins.go b/traffic_ops/traffic_ops_golang/origin/origins.go
index 0df598f48..4c5f8fc00 100644
--- a/traffic_ops/traffic_ops_golang/origin/origins.go
+++ b/traffic_ops/traffic_ops_golang/origin/origins.go
@@ -141,57 +141,15 @@ func (origin *TOOrigin) IsTenantAuthorized(user *auth.CurrentUser) (bool, error)
 	return true, nil
 }
 
-// filterAuthorized will filter a slice of Origins based upon tenant. It assumes that tenancy is enabled
-func filterAuthorized(origins []tc.Origin, user *auth.CurrentUser, tx *sqlx.Tx) ([]tc.Origin, error) {
-	newOrigins := []tc.Origin{}
-	for _, origin := range origins {
-		if origin.TenantID == nil {
-			if origin.ID == nil {
-				return nil, errors.New("isResourceAuthorized for origin with nil ID: no tenant ID")
-			} else {
-				return nil, fmt.Errorf("isResourceAuthorized for origin %d: no tenant ID", *origin.ID)
-			}
-		}
-		// TODO add/use a helper func to make a single SQL call, for performance
-		ok, err := tenant.IsResourceAuthorizedToUserTx(*origin.TenantID, user, tx.Tx)
-		if err != nil {
-			if origin.ID == nil {
-				return nil, errors.New("isResourceAuthorized for origin with nil ID: " + err.Error())
-			} else {
-				return nil, fmt.Errorf("isResourceAuthorized for origin %d: "+err.Error(), *origin.ID)
-			}
-		}
-		if !ok {
-			continue
-		}
-		newOrigins = append(newOrigins, origin)
-	}
-	return newOrigins, nil
-}
-
 func (origin *TOOrigin) Read() ([]interface{}, error, error, int) {
 	returnable := []interface{}{}
 
-	privLevel := origin.ReqInfo.User.PrivLevel
-
-	origins, errs, errType := getOrigins(origin.ReqInfo.Params, origin.ReqInfo.Tx, privLevel)
+	origins, errs, errType := getOrigins(origin.ReqInfo.Params, origin.ReqInfo.Tx, origin.ReqInfo.User)
 	if len(errs) > 0 {
 		userErr, sysErr, errCode := api.TypeErrsToAPIErr(errs, errType)
 		return nil, userErr, sysErr, errCode
 	}
 
-	var err error
-	tenancyEnabled, err := tenant.IsTenancyEnabledTx(origin.ReqInfo.Tx.Tx)
-	if err != nil {
-		return nil, nil, errors.New("origin read: checking tenancy: " + err.Error()), http.StatusInternalServerError
-	}
-	if tenancyEnabled {
-		origins, err = filterAuthorized(origins, origin.ReqInfo.User, origin.ReqInfo.Tx)
-		if err != nil {
-			return nil, nil, errors.New("origin read: filtering authorized: " + err.Error()), http.StatusInternalServerError
-		}
-	}
-
 	for _, origin := range origins {
 		returnable = append(returnable, origin)
 	}
@@ -199,7 +157,7 @@ func (origin *TOOrigin) Read() ([]interface{}, error, error, int) {
 	return returnable, nil, nil, http.StatusOK
 }
 
-func getOrigins(params map[string]string, tx *sqlx.Tx, privLevel int) ([]tc.Origin, []error, tc.ApiErrorType) {
+func getOrigins(params map[string]string, tx *sqlx.Tx, user *auth.CurrentUser) ([]tc.Origin, []error, tc.ApiErrorType) {
 	var rows *sqlx.Rows
 	var err error
 
@@ -221,6 +179,13 @@ func getOrigins(params map[string]string, tx *sqlx.Tx, privLevel int) ([]tc.Orig
 		return nil, errs, tc.DataConflictError
 	}
 
+	tenantIDs, err := tenant.GetUserTenantIDListTx(tx.Tx, user.TenantID)
+	if err != nil {
+		log.Errorln("received error querying for user's tenants: " + err.Error())
+		return nil, []error{tc.DBError}, tc.SystemError
+	}
+	where, queryValues = dbhelpers.AddTenancyCheck(where, queryValues, "o.tenant", tenantIDs)
+
 	query := selectQuery() + where + orderBy
 	log.Debugln("Query is ", query)
 
diff --git a/traffic_ops/traffic_ops_golang/origin/origins_test.go b/traffic_ops/traffic_ops_golang/origin/origins_test.go
index 35368c5c4..75c47824a 100644
--- a/traffic_ops/traffic_ops_golang/origin/origins_test.go
+++ b/traffic_ops/traffic_ops_golang/origin/origins_test.go
@@ -87,10 +87,10 @@ func TestReadOrigins(t *testing.T) {
 
 	testOrigins := getTestOrigins()
 	cols := test.ColsFromStructByTag("db", tc.Origin{})
-	rows := sqlmock.NewRows(cols)
+	originRows := sqlmock.NewRows(cols)
 
 	for _, to := range testOrigins {
-		rows = rows.AddRow(
+		originRows = originRows.AddRow(
 			to.Cachegroup,
 			to.CachegroupID,
 			to.Coordinate,
@@ -112,11 +112,17 @@ func TestReadOrigins(t *testing.T) {
 			to.TenantID,
 		)
 	}
+
+	tenantRows := sqlmock.NewRows([]string{"id"})
+	tenantRows.AddRow(1)
+
 	mock.ExpectBegin()
-	mock.ExpectQuery("SELECT").WillReturnRows(rows)
+	mock.ExpectQuery("WITH").WillReturnRows(tenantRows)
+	mock.ExpectQuery("SELECT").WillReturnRows(originRows)
 	v := map[string]string{}
 
-	origins, errs, errType := getOrigins(v, db.MustBegin(), auth.PrivLevelAdmin)
+	testUser := auth.CurrentUser{TenantID: 1}
+	origins, errs, errType := getOrigins(v, db.MustBegin(), &testUser)
 	log.Debugln("%v-->", origins)
 	if len(errs) > 0 {
 		t.Errorf("getOrigins expected: no errors, actual: %v with error type: %s", errs, errType.String())


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services