You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@trafficcontrol.apache.org by de...@apache.org on 2018/07/24 22:09:54 UTC

[trafficcontrol] branch master updated: Fix TO Go API PUT/POST type checking

This is an automated email from the ASF dual-hosted git repository.

dewrich pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/trafficcontrol.git


The following commit(s) were added to refs/heads/master by this push:
     new 69cba53  Fix TO Go API PUT/POST type checking
69cba53 is described below

commit 69cba5347afb869c38b0159c664870d3a56f39fe
Author: Rawlin Peters <ra...@comcast.com>
AuthorDate: Thu Jul 19 18:04:39 2018 -0600

    Fix TO Go API PUT/POST type checking
    
    For entities that have a required type from the type table, validate
    the type.
    
    Fix the DS regex PUT/POST handlers to use transactions and some of the
    shared api.NewInfo functionality.
---
 lib/go-tc/deliveryservices.go                      |  35 +-----
 lib/go-tc/types.go                                 |  42 +++++++
 .../traffic_ops_golang/cachegroup/cachegroups.go   |   5 +
 .../cachegroup/cachegroups_test.go                 |  27 +++-
 .../deliveryservicesregexes.go                     | 140 ++++++++-------------
 traffic_ops/traffic_ops_golang/routes.go           |   4 +-
 traffic_ops/traffic_ops_golang/server/servers.go   |  21 +---
 .../staticdnsentry/staticdnsentry.go               |   4 +
 .../staticdnsentry/staticdnsentry_test.go          |  25 +++-
 9 files changed, 159 insertions(+), 144 deletions(-)

diff --git a/lib/go-tc/deliveryservices.go b/lib/go-tc/deliveryservices.go
index 384cffb..5a2e752 100644
--- a/lib/go-tc/deliveryservices.go
+++ b/lib/go-tc/deliveryservices.go
@@ -244,23 +244,6 @@ func (ds *DeliveryServiceNullableV12) Sanitize() {
 	}
 }
 
-// getTypeData returns the type's name and use_in_table, true/false if the query returned data, and any error
-func getTypeData(tx *sql.Tx, id int) (string, string, bool, error) {
-	name := ""
-	var useInTablePtr *string
-	if err := tx.QueryRow(`SELECT name, use_in_table from type where id=$1`, id).Scan(&name, &useInTablePtr); err != nil {
-		if err == sql.ErrNoRows {
-			return "", "", false, nil
-		}
-		return "", "", false, errors.New("querying type data: " + err.Error())
-	}
-	useInTable := ""
-	if useInTablePtr != nil {
-		useInTable = *useInTablePtr
-	}
-	return name, useInTable, true, nil
-}
-
 func requiredIfMatchesTypeName(patterns []string, typeName string) func(interface{}) error {
 	return func(value interface{}) error {
 		switch v := value.(type) {
@@ -296,30 +279,20 @@ func requiredIfMatchesTypeName(patterns []string, typeName string) func(interfac
 	}
 }
 
-// util.JoinErrs(errs).Error()
-
 func (ds *DeliveryServiceNullableV12) validateTypeFields(tx *sql.Tx) error {
 	// Validate the TypeName related fields below
-	typeName := ""
 	err := error(nil)
 	DNSRegexType := "^DNS.*$"
 	HTTPRegexType := "^HTTP.*$"
 	SteeringRegexType := "^STEERING.*$"
 	latitudeErr := "Must be a floating point number within the range +-90"
 	longitudeErr := "Must be a floating point number within the range +-180"
-	if ds.TypeID == nil {
-		return errors.New("missing type")
-	}
-	typeName, useInTable, ok, err := getTypeData(tx, *ds.TypeID)
+
+	typeName, err := ValidateTypeID(tx, ds.TypeID, "deliveryservice")
 	if err != nil {
-		return errors.New("getting type name: " + err.Error())
-	}
-	if !ok {
-		return errors.New("type not found")
-	}
-	if useInTable != "deliveryservice" {
-		return errors.New("type is not a valid deliveryservice type")
+		return err
 	}
+
 	errs := validation.Errors{
 		"initialDispersion": validation.Validate(ds.InitialDispersion,
 			validation.By(requiredIfMatchesTypeName([]string{HTTPRegexType}, typeName)),
diff --git a/lib/go-tc/types.go b/lib/go-tc/types.go
index d8887fc..206bc8f 100644
--- a/lib/go-tc/types.go
+++ b/lib/go-tc/types.go
@@ -1,5 +1,10 @@
 package tc
 
+import (
+	"database/sql"
+	"errors"
+)
+
 /*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
@@ -41,3 +46,40 @@ type TypeNullable struct {
 	Description *string    `json:"description" db:"description"`
 	UseInTable  *string    `json:"useInTable" db:"use_in_table"`
 }
+
+// GetTypeData looks up the type row with the given id and returns its name, its use_in_table value ("" when NULL), whether a row was found, and any query error.
+func GetTypeData(tx *sql.Tx, id int) (string, string, bool, error) {
+	name := ""
+	var useInTablePtr *string // scanned as a pointer so a NULL use_in_table is left as nil
+	if err := tx.QueryRow(`SELECT name, use_in_table from type where id=$1`, id).Scan(&name, &useInTablePtr); err != nil {
+		if err == sql.ErrNoRows {
+			return "", "", false, nil // a missing row is signalled through the bool, not as an error
+		}
+		return "", "", false, errors.New("querying type data: " + err.Error())
+	}
+	useInTable := ""
+	if useInTablePtr != nil {
+		useInTable = *useInTablePtr
+	}
+	return name, useInTable, true, nil
+}
+
+// ValidateTypeID checks that typeID is non-nil and refers to an existing type whose use_in_table equals
+// expectedUseInTable. It returns the type's name when valid; otherwise it returns "" and an error.
+func ValidateTypeID(tx *sql.Tx, typeID *int, expectedUseInTable string) (string, error) {
+	if typeID == nil {
+		return "", errors.New("missing type")
+	}
+
+	typeName, useInTable, ok, err := GetTypeData(tx, *typeID)
+	if err != nil {
+		return "", errors.New("validating type: " + err.Error())
+	}
+	if !ok { // GetTypeData found no row for this id
+		return "", errors.New("type not found")
+	}
+	if useInTable != expectedUseInTable { // the type exists but is registered for a different table
+		return "", errors.New("type is not a valid " + expectedUseInTable + " type")
+	}
+	return typeName, nil
+}
diff --git a/traffic_ops/traffic_ops_golang/cachegroup/cachegroups.go b/traffic_ops/traffic_ops_golang/cachegroup/cachegroups.go
index 1151bcc..22a7c4c 100644
--- a/traffic_ops/traffic_ops_golang/cachegroup/cachegroups.go
+++ b/traffic_ops/traffic_ops_golang/cachegroup/cachegroups.go
@@ -149,6 +149,11 @@ func IsValidParentCachegroupID(id *int) bool {
 
 // Validate fulfills the api.Validator interface
 func (cg TOCacheGroup) Validate() error {
+
+	if _, err := tc.ValidateTypeID(cg.ReqInfo.Tx.Tx, cg.TypeID, "cachegroup"); err != nil {
+		return err
+	}
+
 	validName := validation.NewStringRule(IsValidCacheGroupName, "invalid characters found - Use alphanumeric . or - or _ .")
 	validShortName := validation.NewStringRule(IsValidCacheGroupName, "invalid characters found - Use alphanumeric . or - or _ .")
 	latitudeErr := "Must be a floating point number within the range +-90"
diff --git a/traffic_ops/traffic_ops_golang/cachegroup/cachegroups_test.go b/traffic_ops/traffic_ops_golang/cachegroup/cachegroups_test.go
index f330c46..645f990 100644
--- a/traffic_ops/traffic_ops_golang/cachegroup/cachegroups_test.go
+++ b/traffic_ops/traffic_ops_golang/cachegroup/cachegroups_test.go
@@ -153,6 +153,24 @@ func TestInterfaces(t *testing.T) {
 }
 
 func TestValidate(t *testing.T) {
+	mockDB, mock, err := sqlmock.New()
+	if err != nil {
+		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+	}
+	defer mockDB.Close()
+
+	db := sqlx.NewDb(mockDB, "sqlmock")
+	defer db.Close()
+
+	rows := sqlmock.NewRows([]string{"name", "use_in_table"})
+	rows.AddRow("EDGE_LOC", "cachegroup")
+
+	mock.ExpectBegin()
+	mock.ExpectQuery("SELECT").WillReturnRows(rows)
+	tx := db.MustBegin()
+
+	reqInfo := api.APIInfo{Tx: tx, CommitTx: util.BoolPtr(false)}
+
 	// invalid name, shortname, loattude, and longitude
 	id := 1
 	nm := "not!a!valid!cachegroup"
@@ -162,7 +180,7 @@ func TestValidate(t *testing.T) {
 	ty := "EDGE_LOC"
 	ti := 6
 	lu := tc.TimeNoMod{Time: time.Now()}
-	c := TOCacheGroup{CacheGroupNullable: v13.CacheGroupNullable{ID: &id,
+	c := TOCacheGroup{ReqInfo: &reqInfo, CacheGroupNullable: v13.CacheGroupNullable{ID: &id,
 		Name:        &nm,
 		ShortName:   &sn,
 		Latitude:    &la,
@@ -184,12 +202,15 @@ func TestValidate(t *testing.T) {
 		t.Errorf("expected %s, got %s", expectedErrs, errs)
 	}
 
+	rows.AddRow("EDGE_LOC", "cachegroup")
+	mock.ExpectQuery("SELECT").WillReturnRows(rows)
+
 	//  valid name, shortName latitude, longitude
 	nm = "This.is.2.a-Valid---Cachegroup."
 	sn = `awesome-cachegroup`
 	la = 90.0
 	lo = 90.0
-	c = TOCacheGroup{CacheGroupNullable: v13.CacheGroupNullable{ID: &id,
+	c = TOCacheGroup{ReqInfo: &reqInfo, CacheGroupNullable: v13.CacheGroupNullable{ID: &id,
 		Name:        &nm,
 		ShortName:   &sn,
 		Latitude:    &la,
@@ -198,7 +219,7 @@ func TestValidate(t *testing.T) {
 		TypeID:      &ti,
 		LastUpdated: &lu,
 	}}
-	err := c.Validate()
+	err = c.Validate()
 	if err != nil {
 		t.Errorf("expected nil, got %s", err)
 	}
diff --git a/traffic_ops/traffic_ops_golang/deliveryservicesregexes/deliveryservicesregexes.go b/traffic_ops/traffic_ops_golang/deliveryservicesregexes/deliveryservicesregexes.go
index 5b8caa0..b8a515d 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservicesregexes/deliveryservicesregexes.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservicesregexes/deliveryservicesregexes.go
@@ -272,39 +272,29 @@ ORDER BY dsr.set_number ASC
 	}
 }
 
-func Post(dbx *sqlx.DB) http.HandlerFunc {
-	db := dbx.DB
+func Post() http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
-		handleErrs := tc.GetHandleErrorsFunc(w, r)
-		user, err := auth.GetCurrentUser(r.Context())
-		if err != nil {
-			handleErrs(http.StatusInternalServerError, errors.New("unable to retrieve current user from context: "+err.Error()))
-			return
-		}
-		params, err := api.GetCombinedParams(r)
-		if err != nil {
-			handleErrs(http.StatusInternalServerError, errors.New("unable to get parameters from request: "+err.Error()))
-			return
-		}
-		dsIDStr, ok := params["dsid"]
-		if !ok {
-			handleErrs(http.StatusInternalServerError, errors.New("no deliveryservice ID"))
-			return
-		}
-		dsID, err := strconv.Atoi(dsIDStr)
-		if err != nil {
-			handleErrs(http.StatusInternalServerError, errors.New("deliveryservice ID not an integer"))
+		defer r.Body.Close()
+		paramList := []string{"dsid"}
+		inf, userErr, sysErr, errCode := api.NewInfo(r, paramList, paramList)
+		if userErr != nil || sysErr != nil {
+			api.HandleErr(w, r, errCode, userErr, sysErr)
 			return
 		}
+		defer inf.Close()
+		tx := inf.Tx.Tx
+
+		handleErrs := tc.GetHandleErrorsFunc(w, r)
+		dsID := inf.IntParams["dsid"]
 		dsTenantID := 0
-		if err := db.QueryRow(`SELECT tenant_id from deliveryservice where id = $1`, dsID).Scan(&dsTenantID); err != nil {
+		if err := tx.QueryRow(`SELECT tenant_id from deliveryservice where id = $1`, dsID).Scan(&dsTenantID); err != nil {
 			if err != sql.ErrNoRows {
 				log.Errorln("getting deliveryservice name: " + err.Error())
 			}
 			handleErrs(http.StatusInternalServerError, err)
 			return
 		}
-		if ok, err := tenant.IsResourceAuthorizedToUser(dsTenantID, user, dbx); !ok {
+		if ok, err := tenant.IsResourceAuthorizedToUserTx(dsTenantID, inf.User, tx); !ok {
 			handleErrs(http.StatusInternalServerError, errors.New("unauthorized"))
 			return
 		} else if err != nil {
@@ -318,21 +308,26 @@ func Post(dbx *sqlx.DB) http.HandlerFunc {
 			return
 		}
 
+		if err := validateDSRegexType(tx, dsr.Type); err != nil {
+			handleErrs(http.StatusBadRequest, err)
+			return
+		}
+
 		regexID := 0
-		if err := db.QueryRow(`INSERT INTO regex (pattern, type) VALUES ($1, $2) RETURNING id`, dsr.Pattern, dsr.Type).Scan(&regexID); err != nil {
+		if err := tx.QueryRow(`INSERT INTO regex (pattern, type) VALUES ($1, $2) RETURNING id`, dsr.Pattern, dsr.Type).Scan(&regexID); err != nil {
 			log.Errorln("inserting regex: " + err.Error())
 			handleErrs(http.StatusInternalServerError, err)
 			return
 		}
 
-		if _, err := db.Exec(`INSERT INTO deliveryservice_regex (deliveryservice, regex, set_number) values ($1, $2, $3)`, dsID, regexID, dsr.SetNumber); err != nil {
+		if _, err := tx.Exec(`INSERT INTO deliveryservice_regex (deliveryservice, regex, set_number) values ($1, $2, $3)`, dsID, regexID, dsr.SetNumber); err != nil {
 			log.Errorln("inserting deliveryservice_regex: " + err.Error())
 			handleErrs(http.StatusInternalServerError, err)
 			return
 		}
 
 		typeName := ""
-		if err := db.QueryRow(`SELECT name from type where id = $1`, dsr.Type).Scan(&typeName); err != nil {
+		if err := tx.QueryRow(`SELECT name from type where id = $1`, dsr.Type).Scan(&typeName); err != nil {
 			if err != sql.ErrNoRows {
 				log.Errorln("getting regex type: " + err.Error())
 			}
@@ -347,66 +342,36 @@ func Post(dbx *sqlx.DB) http.HandlerFunc {
 			TypeName:  typeName,
 			SetNumber: dsr.SetNumber,
 		}
-		resp := struct {
-			Response tc.DeliveryServiceIDRegex `json:"response"`
-			tc.Alerts
-		}{respObj, tc.CreateAlerts(tc.SuccessLevel, "Delivery service regex creation was successful.")}
 
-		respBts, err := json.Marshal(&resp)
-		if err != nil {
-			handleErrs(http.StatusInternalServerError, errors.New("marshalling JSON: "+err.Error()))
-			return
-		}
-		w.Header().Set("Content-Type", "application/json")
-		w.Write(respBts)
+		*inf.CommitTx = true
+		api.WriteRespAlertObj(w, r, tc.SuccessLevel, "Delivery service regex creation was successful.", respObj)
 	}
 }
 
-func Put(dbx *sqlx.DB) http.HandlerFunc {
-	db := dbx.DB
+func Put() http.HandlerFunc {
 	return func(w http.ResponseWriter, r *http.Request) {
-		handleErrs := tc.GetHandleErrorsFunc(w, r)
-		user, err := auth.GetCurrentUser(r.Context())
-		if err != nil {
-			log.Errorf("unable to retrieve current user from context: %s", err)
-			handleErrs(http.StatusInternalServerError, err)
-			return
-		}
-		params, err := api.GetCombinedParams(r)
-		if err != nil {
-			log.Errorf("unable to get parameters from request: %s", err)
-			handleErrs(http.StatusInternalServerError, err)
-			return
-		}
-		dsIDStr, ok := params["dsid"]
-		if !ok {
-			handleErrs(http.StatusInternalServerError, err)
-			return
-		}
-		dsID, err := strconv.Atoi(dsIDStr)
-		if err != nil {
-			handleErrs(http.StatusInternalServerError, err)
-			return
-		}
-		regexIDStr, ok := params["regexid"]
-		if !ok {
-			handleErrs(http.StatusInternalServerError, errors.New("no regex ID"))
-			return
-		}
-		regexID, err := strconv.Atoi(regexIDStr)
-		if err != nil {
-			handleErrs(http.StatusInternalServerError, errors.New("Regex ID '"+regexIDStr+"' not an integer"))
+		defer r.Body.Close()
+		paramList := []string{"dsid", "regexid"}
+		inf, userErr, sysErr, errCode := api.NewInfo(r, paramList, paramList)
+		if userErr != nil || sysErr != nil {
+			api.HandleErr(w, r, errCode, userErr, sysErr)
 			return
 		}
+		defer inf.Close()
+		tx := inf.Tx.Tx
+
+		handleErrs := tc.GetHandleErrorsFunc(w, r)
+		dsID := inf.IntParams["dsid"]
+		regexID := inf.IntParams["regexid"]
 		dsTenantID := 0
-		if err := db.QueryRow(`SELECT tenant_id from deliveryservice where id = $1`, dsID).Scan(&dsTenantID); err != nil {
+		if err := tx.QueryRow(`SELECT tenant_id from deliveryservice where id = $1`, dsID).Scan(&dsTenantID); err != nil {
 			if err != sql.ErrNoRows {
 				log.Errorln("getting deliveryservice name: " + err.Error())
 			}
 			handleErrs(http.StatusInternalServerError, err)
 			return
 		}
-		if ok, err := tenant.IsResourceAuthorizedToUser(dsTenantID, user, dbx); !ok {
+		if ok, err := tenant.IsResourceAuthorizedToUserTx(dsTenantID, inf.User, tx); !ok {
 			handleErrs(http.StatusInternalServerError, errors.New("unauthorized"))
 			return
 		} else if err != nil {
@@ -419,18 +384,24 @@ func Put(dbx *sqlx.DB) http.HandlerFunc {
 			handleErrs(http.StatusInternalServerError, err)
 			return
 		}
-		if _, err := db.Exec(`UPDATE regex SET pattern=$1, type=$2 WHERE id=$3`, dsr.Pattern, dsr.Type, regexID); err != nil {
+
+		if err := validateDSRegexType(tx, dsr.Type); err != nil {
+			handleErrs(http.StatusBadRequest, err)
+			return
+		}
+
+		if _, err := tx.Exec(`UPDATE regex SET pattern=$1, type=$2 WHERE id=$3`, dsr.Pattern, dsr.Type, regexID); err != nil {
 			log.Errorln("deliveryservicesregexes.Put: updating regex: " + err.Error())
 			handleErrs(http.StatusInternalServerError, errors.New("server error"))
 			return
 		}
-		if _, err := db.Exec(`UPDATE deliveryservice_regex SET set_number=$1 WHERE deliveryservice=$2 AND regex=$3`, dsr.SetNumber, dsID, regexID); err != nil {
+		if _, err := tx.Exec(`UPDATE deliveryservice_regex SET set_number=$1 WHERE deliveryservice=$2 AND regex=$3`, dsr.SetNumber, dsID, regexID); err != nil {
 			log.Errorln("deliveryservicesregexes.Put: updating deliveryservice_regex: " + err.Error())
 			handleErrs(http.StatusInternalServerError, errors.New("server error"))
 			return
 		}
 		typeName := ""
-		if err := db.QueryRow(`SELECT name from type where id = $1`, dsr.Type).Scan(&typeName); err != nil {
+		if err := tx.QueryRow(`SELECT name from type where id = $1`, dsr.Type).Scan(&typeName); err != nil {
 			if err != sql.ErrNoRows {
 				log.Errorln("getting regex type: " + err.Error())
 			}
@@ -444,20 +415,17 @@ func Put(dbx *sqlx.DB) http.HandlerFunc {
 			TypeName:  typeName,
 			SetNumber: dsr.SetNumber,
 		}
-		resp := struct {
-			Response tc.DeliveryServiceIDRegex `json:"response"`
-			tc.Alerts
-		}{respObj, tc.CreateAlerts(tc.SuccessLevel, "Delivery service regex creation was successful.")}
-		respBts, err := json.Marshal(&resp)
-		if err != nil {
-			handleErrs(http.StatusInternalServerError, errors.New("marshalling JSON: "+err.Error()))
-			return
-		}
-		w.Header().Set("Content-Type", "application/json")
-		w.Write(respBts)
+
+		*inf.CommitTx = true
+		api.WriteRespAlertObj(w, r, tc.SuccessLevel, "Delivery service regex creation was successful.", respObj)
 	}
 }
 
+func validateDSRegexType(tx *sql.Tx, typeID int) error { // returns an error unless typeID is a type whose use_in_table is "regex"
+	_, err := tc.ValidateTypeID(tx, &typeID, "regex") // the returned type name is not needed here
+	return err
+}
+
 func Delete(dbx *sqlx.DB) http.HandlerFunc {
 	db := dbx.DB
 	return func(w http.ResponseWriter, r *http.Request) {
diff --git a/traffic_ops/traffic_ops_golang/routes.go b/traffic_ops/traffic_ops_golang/routes.go
index 6550c53..13e0470 100644
--- a/traffic_ops/traffic_ops_golang/routes.go
+++ b/traffic_ops/traffic_ops_golang/routes.go
@@ -308,8 +308,8 @@ func Routes(d ServerData) ([]Route, []RawRoute, http.Handler, error) {
 		{1.1, http.MethodGet, `deliveryservices_regexes/?(\.json)?$`, deliveryservicesregexes.Get(d.DB), auth.PrivLevelReadOnly, Authenticated, nil},
 		{1.1, http.MethodGet, `deliveryservices/{dsid}/regexes/?(\.json)?$`, deliveryservicesregexes.DSGet(d.DB), auth.PrivLevelReadOnly, Authenticated, nil},
 		{1.1, http.MethodGet, `deliveryservices/{dsid}/regexes/{regexid}?(\.json)?$`, deliveryservicesregexes.DSGetID(d.DB), auth.PrivLevelReadOnly, Authenticated, nil},
-		{1.1, http.MethodPost, `deliveryservices/{dsid}/regexes/?(\.json)?$`, deliveryservicesregexes.Post(d.DB), auth.PrivLevelOperations, Authenticated, nil},
-		{1.1, http.MethodPut, `deliveryservices/{dsid}/regexes/{regexid}?(\.json)?$`, deliveryservicesregexes.Put(d.DB), auth.PrivLevelOperations, Authenticated, nil},
+		{1.1, http.MethodPost, `deliveryservices/{dsid}/regexes/?(\.json)?$`, deliveryservicesregexes.Post(), auth.PrivLevelOperations, Authenticated, nil},
+		{1.1, http.MethodPut, `deliveryservices/{dsid}/regexes/{regexid}?(\.json)?$`, deliveryservicesregexes.Put(), auth.PrivLevelOperations, Authenticated, nil},
 		{1.1, http.MethodDelete, `deliveryservices/{dsid}/regexes/{regexid}?(\.json)?$`, deliveryservicesregexes.Delete(d.DB), auth.PrivLevelOperations, Authenticated, nil},
 
 		//Servers
diff --git a/traffic_ops/traffic_ops_golang/server/servers.go b/traffic_ops/traffic_ops_golang/server/servers.go
index 91ef619..e766f0f 100644
--- a/traffic_ops/traffic_ops_golang/server/servers.go
+++ b/traffic_ops/traffic_ops_golang/server/servers.go
@@ -33,7 +33,7 @@ import (
 	"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/auth"
 	"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/dbhelpers"
 
-	validation "github.com/go-ozzo/ozzo-validation"
+	"github.com/go-ozzo/ozzo-validation"
 	"github.com/go-ozzo/ozzo-validation/is"
 	"github.com/jmoiron/sqlx"
 	"github.com/lib/pq"
@@ -108,24 +108,11 @@ func (server *TOServer) Validate() error {
 		return util.JoinErrs(errs)
 	}
 
-	rows, err := server.ReqInfo.Tx.Query("select use_in_table from type where id=$1", server.TypeID)
-	if err != nil {
-		log.Error.Printf("could not execute select use_in_table from type: %s\n", err)
-		return tc.DBError
-	}
-	defer rows.Close()
-	var useInTable string
-	for rows.Next() {
-		if err := rows.Scan(&useInTable); err != nil {
-			log.Error.Printf("could not scan use_in_table from type: %s\n", err)
-			return tc.DBError
-		}
-	}
-	if useInTable != "server" {
-		errs = append(errs, errors.New("invalid server type"))
+	if _, err := tc.ValidateTypeID(server.ReqInfo.Tx.Tx, server.TypeID, "server"); err != nil {
+		return err
 	}
 
-	rows, err = server.ReqInfo.Tx.Query("select cdn from profile where id=$1", server.ProfileID)
+	rows, err := server.ReqInfo.Tx.Query("select cdn from profile where id=$1", server.ProfileID)
 	if err != nil {
 		log.Error.Printf("could not execute select cdnID from profile: %s\n", err)
 		errs = append(errs, tc.DBError)
diff --git a/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry.go b/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry.go
index 744514b..70b089b 100644
--- a/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry.go
+++ b/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry.go
@@ -81,6 +81,10 @@ func (staticDNSEntry *TOStaticDNSEntry) SetKeys(keys map[string]interface{}) {
 
 // Validate fulfills the api.Validator interface
 func (staticDNSEntry TOStaticDNSEntry) Validate() error {
+	if _, err := tc.ValidateTypeID(staticDNSEntry.ReqInfo.Tx.Tx, &staticDNSEntry.TypeID, "staticdnsentry"); err != nil {
+		return err
+	}
+
 	errs := validation.Errors{
 		"host":              validation.Validate(staticDNSEntry.Host, validation.Required, is.DNSName),
 		"address":           validation.Validate(staticDNSEntry.Address, validation.Required, is.Host),
diff --git a/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry_test.go b/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry_test.go
index 6d556bf..e773fd1 100644
--- a/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry_test.go
+++ b/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry_test.go
@@ -28,6 +28,8 @@ import (
 	util "github.com/apache/trafficcontrol/lib/go-util"
 	"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/api"
 	"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/test"
+	"github.com/jmoiron/sqlx"
+	"gopkg.in/DATA-DOG/go-sqlmock.v1"
 )
 
 func TestFuncs(t *testing.T) {
@@ -68,8 +70,25 @@ func TestInterfaces(t *testing.T) {
 }
 
 func TestValidate(t *testing.T) {
+	mockDB, mock, err := sqlmock.New()
+	if err != nil {
+		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+	}
+	defer mockDB.Close()
+
+	db := sqlx.NewDb(mockDB, "sqlmock")
+	defer db.Close()
+
+	rows := sqlmock.NewRows([]string{"name", "use_in_table"})
+	rows.AddRow("A_RECORD", "staticdnsentry")
+
+	mock.ExpectBegin()
+	mock.ExpectQuery("SELECT").WillReturnRows(rows)
+	tx := db.MustBegin()
+
+	reqInfo := api.APIInfo{Tx: tx, CommitTx: util.BoolPtr(false)}
 	// invalid name, empty domainname
-	ts := TOStaticDNSEntry{}
+	ts := TOStaticDNSEntry{ReqInfo: &reqInfo}
 	errs := util.JoinErrsStr(test.SortErrors(test.SplitErrors(ts.Validate())))
 
 	expectedErrs := util.JoinErrsStr([]error{
@@ -83,8 +102,4 @@ func TestValidate(t *testing.T) {
 	if !reflect.DeepEqual(expectedErrs, errs) {
 		t.Errorf("expected %s, GOT %s", expectedErrs, errs)
 	}
-	//if err != nil {
-	//t.Errorf("expected nil, got %s", err)
-	//}
-
 }