You are viewing a plain-text version of this content; the link to the canonical version was lost in the plain-text conversion.
Posted to commits@trafficcontrol.apache.org by ro...@apache.org on 2018/06/05 22:48:12 UTC

[incubator-trafficcontrol] branch master updated: Fix go delivery service API validation

This is an automated email from the ASF dual-hosted git repository.

rob pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-trafficcontrol.git


The following commit(s) were added to refs/heads/master by this push:
     new 82d60b8  Fix go delivery service API validation
82d60b8 is described below

commit 82d60b8096938c3a4306f8d395e948650f0abc9f
Author: Rawlin Peters <ra...@comcast.com>
AuthorDate: Tue Jun 5 16:01:56 2018 -0600

    Fix go delivery service API validation
    
    Validate v1.3 delivery service (DS) PUT requests, and fix the
    "required if type X" validation.
---
 .../deliveryservice/deliveryservicesv12.go         | 124 +++++++++++----------
 .../deliveryservice/deliveryservicesv13.go         |   5 +
 2 files changed, 70 insertions(+), 59 deletions(-)

diff --git a/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv12.go b/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv12.go
index 67eb3ce..7184c71 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv12.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv12.go
@@ -218,19 +218,17 @@ func validateV12(db *sqlx.DB, ds *tc.DeliveryServiceNullableV12) []error {
 		"typeId":              validation.Validate(ds.TypeID, validation.Required, validation.Min(1)),
 		"xmlId":               validation.Validate(ds.XMLID, noSpaces, noPeriods, validation.Length(1, 48)),
 	}
-	if errs != nil {
-		return tovalidate.ToErrors(errs)
+	toErrs := tovalidate.ToErrors(errs)
+	if fieldErrs := validateTypeFields(db, ds); len(fieldErrs) > 0 {
+		toErrs = append(toErrs, fieldErrs...)
 	}
-	errsResponse := validateTypeFields(db, ds)
-	if errsResponse != nil {
-		return errsResponse
+	if len(toErrs) > 0 {
+		return toErrs
 	}
-
 	return nil
 }
 
 func validateTypeFields(db *sqlx.DB, ds *tc.DeliveryServiceNullableV12) []error {
-	fmt.Printf("validateTypeFields\n")
 	// Validate the TypeName related fields below
 	var typeName string
 	var err error
@@ -238,43 +236,67 @@ func validateTypeFields(db *sqlx.DB, ds *tc.DeliveryServiceNullableV12) []error
 	HTTPRegexType := "^HTTP.*$"
 	SteeringRegexType := "^STEERING.*$"
 
-	if db != nil && ds == nil || ds.TypeID != nil {
-		typeID := *ds.TypeID
-		typeName, err = getTypeName(db, typeID)
-		if err != nil {
-			return []error{err}
-		}
+	if ds.TypeID == nil {
+		return []error{errors.New("missing typeID")}
 	}
 
-	if typeName != "" {
-		errs := validation.Errors{
-			"initialDispersion": validation.Validate(ds.InitialDispersion,
-				validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, typeName))),
-			"ipv6RoutingEnabled": validation.Validate(ds.IPV6RoutingEnabled,
-				validation.By(requiredIfMatchesTypeName([]string{SteeringRegexType, DNSRegexType, HTTPRegexType}, typeName))),
-			"missLat": validation.Validate(ds.MissLat,
-				validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, typeName))),
-			"missLong": validation.Validate(ds.MissLong,
-				validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, typeName))),
-			"multiSiteOrigin": validation.Validate(ds.MultiSiteOrigin,
-				validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, typeName))),
-			"orgServerFqdn": validation.Validate(ds.OrgServerFQDN,
-				validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, typeName))),
-			"protocol": validation.Validate(ds.Protocol,
-				validation.By(requiredIfMatchesTypeName([]string{SteeringRegexType, DNSRegexType, HTTPRegexType}, typeName))),
-			"qstringIgnore": validation.Validate(ds.QStringIgnore,
-				validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, typeName))),
-			"rangeRequestHandling": validation.Validate(ds.RangeRequestHandling,
-				validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, typeName))),
-		}
-		return tovalidate.ToErrors(errs)
+	typeName, ok, err := getTypeName(db, *ds.TypeID)
+	if err != nil {
+		return []error{err}
+	}
+	if !ok {
+		return []error{errors.New("type not found")}
+	}
+
+	errs := validation.Errors{
+		"initialDispersion": validation.Validate(ds.InitialDispersion,
+			validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, typeName))),
+		"ipv6RoutingEnabled": validation.Validate(ds.IPV6RoutingEnabled,
+			validation.By(requiredIfMatchesTypeName([]string{SteeringRegexType, DNSRegexType, HTTPRegexType}, typeName))),
+		"missLat": validation.Validate(ds.MissLat,
+			validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, typeName))),
+		"missLong": validation.Validate(ds.MissLong,
+			validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, typeName))),
+		"multiSiteOrigin": validation.Validate(ds.MultiSiteOrigin,
+			validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, typeName))),
+		"orgServerFqdn": validation.Validate(ds.OrgServerFQDN,
+			validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, typeName))),
+		"protocol": validation.Validate(ds.Protocol,
+			validation.By(requiredIfMatchesTypeName([]string{SteeringRegexType, DNSRegexType, HTTPRegexType}, typeName))),
+		"qstringIgnore": validation.Validate(ds.QStringIgnore,
+			validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, typeName))),
+		"rangeRequestHandling": validation.Validate(ds.RangeRequestHandling,
+			validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, typeName))),
+	}
+	toErrs := tovalidate.ToErrors(errs)
+	if len(toErrs) > 0 {
+		return toErrs
 	}
 	return nil
 }
 
 func requiredIfMatchesTypeName(patterns []string, typeName string) func(interface{}) error {
 	return func(value interface{}) error {
-
+		switch v := value.(type) {
+		case *int:
+			if v != nil {
+				return nil
+			}
+		case *bool:
+			if v != nil {
+				return nil
+			}
+		case *string:
+			if v != nil {
+				return nil
+			}
+		case *float64:
+			if v != nil {
+				return nil
+			}
+		default:
+			return fmt.Errorf("validation failure: unknown type %T", value)
+		}
 		pattern := strings.Join(patterns, "|")
 		var err error
 		var match bool
@@ -288,31 +310,15 @@ func requiredIfMatchesTypeName(patterns []string, typeName string) func(interfac
 	}
 }
 
-// TODO: drichardson - refactor to the types.go once implemented.
-func getTypeName(db *sqlx.DB, typeID int) (string, error) {
-
-	query := `SELECT name from type where id=$1`
-
-	var rows *sqlx.Rows
-	var err error
-
-	rows, err = db.Queryx(query, typeID)
-	if err != nil {
-		return "", err
-	}
-	defer rows.Close()
-
-	typeResults := []tc.Type{}
-	for rows.Next() {
-		var s tc.Type
-		if err = rows.StructScan(&s); err != nil {
-			return "", fmt.Errorf("getting Type: %v", err)
+func getTypeName(db *sqlx.DB, typeID int) (string, bool, error) {
+	name := ""
+	if err := db.QueryRow(`SELECT name from type where id=$1`, typeID).Scan(&name); err != nil {
+		if err == sql.ErrNoRows {
+			return "", false, nil
 		}
-		typeResults = append(typeResults, s)
+		return "", false, errors.New("querying type name: " + err.Error())
 	}
-
-	typeName := typeResults[0].Name
-	return typeName, err
+	return name, true, nil
 }
 
 func CreateV12(db *sqlx.DB, cfg config.Config) http.HandlerFunc {
diff --git a/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv13.go b/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv13.go
index 55727a7..6244acf 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv13.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv13.go
@@ -375,6 +375,11 @@ func UpdateV13(db *sqlx.DB, cfg config.Config) http.HandlerFunc {
 		}
 		ds.ID = &id
 
+		if errs := validateV13(db, &ds); len(errs) > 0 {
+			api.HandleErr(w, r, http.StatusBadRequest, errors.New("invalid request: "+util.JoinErrs(errs).Error()), nil)
+			return
+		}
+
 		if authorized, err := isTenantAuthorized(*user, db, &ds.DeliveryServiceNullableV12); err != nil {
 			api.HandleErr(w, r, http.StatusInternalServerError, nil, errors.New("checking tenant: "+err.Error()))
 			return

-- 
To stop receiving notification emails like this one, please contact
rob@apache.org.