You are viewing a plain-text version of this message; the canonical version is the original posting in the commits@trafficcontrol.apache.org mailing-list archive.
Posted to commits@trafficcontrol.apache.org by de...@apache.org on 2018/07/24 22:09:54 UTC
[trafficcontrol] branch master updated: Fix TO Go API PUT/POST type checking
This is an automated email from the ASF dual-hosted git repository.
dewrich pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/trafficcontrol.git
The following commit(s) were added to refs/heads/master by this push:
new 69cba53 Fix TO Go API PUT/POST type checking
69cba53 is described below
commit 69cba5347afb869c38b0159c664870d3a56f39fe
Author: Rawlin Peters <ra...@comcast.com>
AuthorDate: Thu Jul 19 18:04:39 2018 -0600
Fix TO Go API PUT/POST type checking
For entities that have a required type from the type table, validate
the type.
Fix the DS regex PUT/POST handlers to use transactions and some of the
shared api.NewInfo functionality.
---
lib/go-tc/deliveryservices.go | 35 +-----
lib/go-tc/types.go | 42 +++++++
.../traffic_ops_golang/cachegroup/cachegroups.go | 5 +
.../cachegroup/cachegroups_test.go | 27 +++-
.../deliveryservicesregexes.go | 140 ++++++++-------------
traffic_ops/traffic_ops_golang/routes.go | 4 +-
traffic_ops/traffic_ops_golang/server/servers.go | 21 +---
.../staticdnsentry/staticdnsentry.go | 4 +
.../staticdnsentry/staticdnsentry_test.go | 25 +++-
9 files changed, 159 insertions(+), 144 deletions(-)
diff --git a/lib/go-tc/deliveryservices.go b/lib/go-tc/deliveryservices.go
index 384cffb..5a2e752 100644
--- a/lib/go-tc/deliveryservices.go
+++ b/lib/go-tc/deliveryservices.go
@@ -244,23 +244,6 @@ func (ds *DeliveryServiceNullableV12) Sanitize() {
}
}
-// getTypeData returns the type's name and use_in_table, true/false if the query returned data, and any error
-func getTypeData(tx *sql.Tx, id int) (string, string, bool, error) {
- name := ""
- var useInTablePtr *string
- if err := tx.QueryRow(`SELECT name, use_in_table from type where id=$1`, id).Scan(&name, &useInTablePtr); err != nil {
- if err == sql.ErrNoRows {
- return "", "", false, nil
- }
- return "", "", false, errors.New("querying type data: " + err.Error())
- }
- useInTable := ""
- if useInTablePtr != nil {
- useInTable = *useInTablePtr
- }
- return name, useInTable, true, nil
-}
-
func requiredIfMatchesTypeName(patterns []string, typeName string) func(interface{}) error {
return func(value interface{}) error {
switch v := value.(type) {
@@ -296,30 +279,20 @@ func requiredIfMatchesTypeName(patterns []string, typeName string) func(interfac
}
}
-// util.JoinErrs(errs).Error()
-
func (ds *DeliveryServiceNullableV12) validateTypeFields(tx *sql.Tx) error {
// Validate the TypeName related fields below
- typeName := ""
err := error(nil)
DNSRegexType := "^DNS.*$"
HTTPRegexType := "^HTTP.*$"
SteeringRegexType := "^STEERING.*$"
latitudeErr := "Must be a floating point number within the range +-90"
longitudeErr := "Must be a floating point number within the range +-180"
- if ds.TypeID == nil {
- return errors.New("missing type")
- }
- typeName, useInTable, ok, err := getTypeData(tx, *ds.TypeID)
+
+ typeName, err := ValidateTypeID(tx, ds.TypeID, "deliveryservice")
if err != nil {
- return errors.New("getting type name: " + err.Error())
- }
- if !ok {
- return errors.New("type not found")
- }
- if useInTable != "deliveryservice" {
- return errors.New("type is not a valid deliveryservice type")
+ return err
}
+
errs := validation.Errors{
"initialDispersion": validation.Validate(ds.InitialDispersion,
validation.By(requiredIfMatchesTypeName([]string{HTTPRegexType}, typeName)),
diff --git a/lib/go-tc/types.go b/lib/go-tc/types.go
index d8887fc..206bc8f 100644
--- a/lib/go-tc/types.go
+++ b/lib/go-tc/types.go
@@ -1,5 +1,10 @@
package tc
+import (
+ "database/sql"
+ "errors"
+)
+
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
@@ -41,3 +46,40 @@ type TypeNullable struct {
Description *string `json:"description" db:"description"`
UseInTable *string `json:"useInTable" db:"use_in_table"`
}
+
+// GetTypeData returns the type's name and use_in_table, true/false if the query returned data, and any error
+func GetTypeData(tx *sql.Tx, id int) (string, string, bool, error) {
+ name := ""
+ var useInTablePtr *string
+ if err := tx.QueryRow(`SELECT name, use_in_table from type where id=$1`, id).Scan(&name, &useInTablePtr); err != nil {
+ if err == sql.ErrNoRows {
+ return "", "", false, nil
+ }
+ return "", "", false, errors.New("querying type data: " + err.Error())
+ }
+ useInTable := ""
+ if useInTablePtr != nil {
+ useInTable = *useInTablePtr
+ }
+ return name, useInTable, true, nil
+}
+
+// ValidateTypeID validates that the typeID references a type with the expected use_in_table string and
+// returns "" and an error if the typeID is invalid. If valid, the type's name is returned.
+func ValidateTypeID(tx *sql.Tx, typeID *int, expectedUseInTable string) (string, error) {
+ if typeID == nil {
+ return "", errors.New("missing type")
+ }
+
+ typeName, useInTable, ok, err := GetTypeData(tx, *typeID)
+ if err != nil {
+ return "", errors.New("validating type: " + err.Error())
+ }
+ if !ok {
+ return "", errors.New("type not found")
+ }
+ if useInTable != expectedUseInTable {
+ return "", errors.New("type is not a valid " + expectedUseInTable + " type")
+ }
+ return typeName, nil
+}
diff --git a/traffic_ops/traffic_ops_golang/cachegroup/cachegroups.go b/traffic_ops/traffic_ops_golang/cachegroup/cachegroups.go
index 1151bcc..22a7c4c 100644
--- a/traffic_ops/traffic_ops_golang/cachegroup/cachegroups.go
+++ b/traffic_ops/traffic_ops_golang/cachegroup/cachegroups.go
@@ -149,6 +149,11 @@ func IsValidParentCachegroupID(id *int) bool {
// Validate fulfills the api.Validator interface
func (cg TOCacheGroup) Validate() error {
+
+ if _, err := tc.ValidateTypeID(cg.ReqInfo.Tx.Tx, cg.TypeID, "cachegroup"); err != nil {
+ return err
+ }
+
validName := validation.NewStringRule(IsValidCacheGroupName, "invalid characters found - Use alphanumeric . or - or _ .")
validShortName := validation.NewStringRule(IsValidCacheGroupName, "invalid characters found - Use alphanumeric . or - or _ .")
latitudeErr := "Must be a floating point number within the range +-90"
diff --git a/traffic_ops/traffic_ops_golang/cachegroup/cachegroups_test.go b/traffic_ops/traffic_ops_golang/cachegroup/cachegroups_test.go
index f330c46..645f990 100644
--- a/traffic_ops/traffic_ops_golang/cachegroup/cachegroups_test.go
+++ b/traffic_ops/traffic_ops_golang/cachegroup/cachegroups_test.go
@@ -153,6 +153,24 @@ func TestInterfaces(t *testing.T) {
}
func TestValidate(t *testing.T) {
+ mockDB, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer mockDB.Close()
+
+ db := sqlx.NewDb(mockDB, "sqlmock")
+ defer db.Close()
+
+ rows := sqlmock.NewRows([]string{"name", "use_in_table"})
+ rows.AddRow("EDGE_LOC", "cachegroup")
+
+ mock.ExpectBegin()
+ mock.ExpectQuery("SELECT").WillReturnRows(rows)
+ tx := db.MustBegin()
+
+ reqInfo := api.APIInfo{Tx: tx, CommitTx: util.BoolPtr(false)}
+
// invalid name, shortname, loattude, and longitude
id := 1
nm := "not!a!valid!cachegroup"
@@ -162,7 +180,7 @@ func TestValidate(t *testing.T) {
ty := "EDGE_LOC"
ti := 6
lu := tc.TimeNoMod{Time: time.Now()}
- c := TOCacheGroup{CacheGroupNullable: v13.CacheGroupNullable{ID: &id,
+ c := TOCacheGroup{ReqInfo: &reqInfo, CacheGroupNullable: v13.CacheGroupNullable{ID: &id,
Name: &nm,
ShortName: &sn,
Latitude: &la,
@@ -184,12 +202,15 @@ func TestValidate(t *testing.T) {
t.Errorf("expected %s, got %s", expectedErrs, errs)
}
+ rows.AddRow("EDGE_LOC", "cachegroup")
+ mock.ExpectQuery("SELECT").WillReturnRows(rows)
+
// valid name, shortName latitude, longitude
nm = "This.is.2.a-Valid---Cachegroup."
sn = `awesome-cachegroup`
la = 90.0
lo = 90.0
- c = TOCacheGroup{CacheGroupNullable: v13.CacheGroupNullable{ID: &id,
+ c = TOCacheGroup{ReqInfo: &reqInfo, CacheGroupNullable: v13.CacheGroupNullable{ID: &id,
Name: &nm,
ShortName: &sn,
Latitude: &la,
@@ -198,7 +219,7 @@ func TestValidate(t *testing.T) {
TypeID: &ti,
LastUpdated: &lu,
}}
- err := c.Validate()
+ err = c.Validate()
if err != nil {
t.Errorf("expected nil, got %s", err)
}
diff --git a/traffic_ops/traffic_ops_golang/deliveryservicesregexes/deliveryservicesregexes.go b/traffic_ops/traffic_ops_golang/deliveryservicesregexes/deliveryservicesregexes.go
index 5b8caa0..b8a515d 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservicesregexes/deliveryservicesregexes.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservicesregexes/deliveryservicesregexes.go
@@ -272,39 +272,29 @@ ORDER BY dsr.set_number ASC
}
}
-func Post(dbx *sqlx.DB) http.HandlerFunc {
- db := dbx.DB
+func Post() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
- handleErrs := tc.GetHandleErrorsFunc(w, r)
- user, err := auth.GetCurrentUser(r.Context())
- if err != nil {
- handleErrs(http.StatusInternalServerError, errors.New("unable to retrieve current user from context: "+err.Error()))
- return
- }
- params, err := api.GetCombinedParams(r)
- if err != nil {
- handleErrs(http.StatusInternalServerError, errors.New("unable to get parameters from request: "+err.Error()))
- return
- }
- dsIDStr, ok := params["dsid"]
- if !ok {
- handleErrs(http.StatusInternalServerError, errors.New("no deliveryservice ID"))
- return
- }
- dsID, err := strconv.Atoi(dsIDStr)
- if err != nil {
- handleErrs(http.StatusInternalServerError, errors.New("deliveryservice ID not an integer"))
+ defer r.Body.Close()
+ paramList := []string{"dsid"}
+ inf, userErr, sysErr, errCode := api.NewInfo(r, paramList, paramList)
+ if userErr != nil || sysErr != nil {
+ api.HandleErr(w, r, errCode, userErr, sysErr)
return
}
+ defer inf.Close()
+ tx := inf.Tx.Tx
+
+ handleErrs := tc.GetHandleErrorsFunc(w, r)
+ dsID := inf.IntParams["dsid"]
dsTenantID := 0
- if err := db.QueryRow(`SELECT tenant_id from deliveryservice where id = $1`, dsID).Scan(&dsTenantID); err != nil {
+ if err := tx.QueryRow(`SELECT tenant_id from deliveryservice where id = $1`, dsID).Scan(&dsTenantID); err != nil {
if err != sql.ErrNoRows {
log.Errorln("getting deliveryservice name: " + err.Error())
}
handleErrs(http.StatusInternalServerError, err)
return
}
- if ok, err := tenant.IsResourceAuthorizedToUser(dsTenantID, user, dbx); !ok {
+ if ok, err := tenant.IsResourceAuthorizedToUserTx(dsTenantID, inf.User, tx); !ok {
handleErrs(http.StatusInternalServerError, errors.New("unauthorized"))
return
} else if err != nil {
@@ -318,21 +308,26 @@ func Post(dbx *sqlx.DB) http.HandlerFunc {
return
}
+ if err := validateDSRegexType(tx, dsr.Type); err != nil {
+ handleErrs(http.StatusBadRequest, err)
+ return
+ }
+
regexID := 0
- if err := db.QueryRow(`INSERT INTO regex (pattern, type) VALUES ($1, $2) RETURNING id`, dsr.Pattern, dsr.Type).Scan(&regexID); err != nil {
+ if err := tx.QueryRow(`INSERT INTO regex (pattern, type) VALUES ($1, $2) RETURNING id`, dsr.Pattern, dsr.Type).Scan(&regexID); err != nil {
log.Errorln("inserting regex: " + err.Error())
handleErrs(http.StatusInternalServerError, err)
return
}
- if _, err := db.Exec(`INSERT INTO deliveryservice_regex (deliveryservice, regex, set_number) values ($1, $2, $3)`, dsID, regexID, dsr.SetNumber); err != nil {
+ if _, err := tx.Exec(`INSERT INTO deliveryservice_regex (deliveryservice, regex, set_number) values ($1, $2, $3)`, dsID, regexID, dsr.SetNumber); err != nil {
log.Errorln("inserting deliveryservice_regex: " + err.Error())
handleErrs(http.StatusInternalServerError, err)
return
}
typeName := ""
- if err := db.QueryRow(`SELECT name from type where id = $1`, dsr.Type).Scan(&typeName); err != nil {
+ if err := tx.QueryRow(`SELECT name from type where id = $1`, dsr.Type).Scan(&typeName); err != nil {
if err != sql.ErrNoRows {
log.Errorln("getting regex type: " + err.Error())
}
@@ -347,66 +342,36 @@ func Post(dbx *sqlx.DB) http.HandlerFunc {
TypeName: typeName,
SetNumber: dsr.SetNumber,
}
- resp := struct {
- Response tc.DeliveryServiceIDRegex `json:"response"`
- tc.Alerts
- }{respObj, tc.CreateAlerts(tc.SuccessLevel, "Delivery service regex creation was successful.")}
- respBts, err := json.Marshal(&resp)
- if err != nil {
- handleErrs(http.StatusInternalServerError, errors.New("marshalling JSON: "+err.Error()))
- return
- }
- w.Header().Set("Content-Type", "application/json")
- w.Write(respBts)
+ *inf.CommitTx = true
+ api.WriteRespAlertObj(w, r, tc.SuccessLevel, "Delivery service regex creation was successful.", respObj)
}
}
-func Put(dbx *sqlx.DB) http.HandlerFunc {
- db := dbx.DB
+func Put() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
- handleErrs := tc.GetHandleErrorsFunc(w, r)
- user, err := auth.GetCurrentUser(r.Context())
- if err != nil {
- log.Errorf("unable to retrieve current user from context: %s", err)
- handleErrs(http.StatusInternalServerError, err)
- return
- }
- params, err := api.GetCombinedParams(r)
- if err != nil {
- log.Errorf("unable to get parameters from request: %s", err)
- handleErrs(http.StatusInternalServerError, err)
- return
- }
- dsIDStr, ok := params["dsid"]
- if !ok {
- handleErrs(http.StatusInternalServerError, err)
- return
- }
- dsID, err := strconv.Atoi(dsIDStr)
- if err != nil {
- handleErrs(http.StatusInternalServerError, err)
- return
- }
- regexIDStr, ok := params["regexid"]
- if !ok {
- handleErrs(http.StatusInternalServerError, errors.New("no regex ID"))
- return
- }
- regexID, err := strconv.Atoi(regexIDStr)
- if err != nil {
- handleErrs(http.StatusInternalServerError, errors.New("Regex ID '"+regexIDStr+"' not an integer"))
+ defer r.Body.Close()
+ paramList := []string{"dsid", "regexid"}
+ inf, userErr, sysErr, errCode := api.NewInfo(r, paramList, paramList)
+ if userErr != nil || sysErr != nil {
+ api.HandleErr(w, r, errCode, userErr, sysErr)
return
}
+ defer inf.Close()
+ tx := inf.Tx.Tx
+
+ handleErrs := tc.GetHandleErrorsFunc(w, r)
+ dsID := inf.IntParams["dsid"]
+ regexID := inf.IntParams["regexid"]
dsTenantID := 0
- if err := db.QueryRow(`SELECT tenant_id from deliveryservice where id = $1`, dsID).Scan(&dsTenantID); err != nil {
+ if err := tx.QueryRow(`SELECT tenant_id from deliveryservice where id = $1`, dsID).Scan(&dsTenantID); err != nil {
if err != sql.ErrNoRows {
log.Errorln("getting deliveryservice name: " + err.Error())
}
handleErrs(http.StatusInternalServerError, err)
return
}
- if ok, err := tenant.IsResourceAuthorizedToUser(dsTenantID, user, dbx); !ok {
+ if ok, err := tenant.IsResourceAuthorizedToUserTx(dsTenantID, inf.User, tx); !ok {
handleErrs(http.StatusInternalServerError, errors.New("unauthorized"))
return
} else if err != nil {
@@ -419,18 +384,24 @@ func Put(dbx *sqlx.DB) http.HandlerFunc {
handleErrs(http.StatusInternalServerError, err)
return
}
- if _, err := db.Exec(`UPDATE regex SET pattern=$1, type=$2 WHERE id=$3`, dsr.Pattern, dsr.Type, regexID); err != nil {
+
+ if err := validateDSRegexType(tx, dsr.Type); err != nil {
+ handleErrs(http.StatusBadRequest, err)
+ return
+ }
+
+ if _, err := tx.Exec(`UPDATE regex SET pattern=$1, type=$2 WHERE id=$3`, dsr.Pattern, dsr.Type, regexID); err != nil {
log.Errorln("deliveryservicesregexes.Put: updating regex: " + err.Error())
handleErrs(http.StatusInternalServerError, errors.New("server error"))
return
}
- if _, err := db.Exec(`UPDATE deliveryservice_regex SET set_number=$1 WHERE deliveryservice=$2 AND regex=$3`, dsr.SetNumber, dsID, regexID); err != nil {
+ if _, err := tx.Exec(`UPDATE deliveryservice_regex SET set_number=$1 WHERE deliveryservice=$2 AND regex=$3`, dsr.SetNumber, dsID, regexID); err != nil {
log.Errorln("deliveryservicesregexes.Put: updating deliveryservice_regex: " + err.Error())
handleErrs(http.StatusInternalServerError, errors.New("server error"))
return
}
typeName := ""
- if err := db.QueryRow(`SELECT name from type where id = $1`, dsr.Type).Scan(&typeName); err != nil {
+ if err := tx.QueryRow(`SELECT name from type where id = $1`, dsr.Type).Scan(&typeName); err != nil {
if err != sql.ErrNoRows {
log.Errorln("getting regex type: " + err.Error())
}
@@ -444,20 +415,17 @@ func Put(dbx *sqlx.DB) http.HandlerFunc {
TypeName: typeName,
SetNumber: dsr.SetNumber,
}
- resp := struct {
- Response tc.DeliveryServiceIDRegex `json:"response"`
- tc.Alerts
- }{respObj, tc.CreateAlerts(tc.SuccessLevel, "Delivery service regex creation was successful.")}
- respBts, err := json.Marshal(&resp)
- if err != nil {
- handleErrs(http.StatusInternalServerError, errors.New("marshalling JSON: "+err.Error()))
- return
- }
- w.Header().Set("Content-Type", "application/json")
- w.Write(respBts)
+
+ *inf.CommitTx = true
+ api.WriteRespAlertObj(w, r, tc.SuccessLevel, "Delivery service regex creation was successful.", respObj)
}
}
+func validateDSRegexType(tx *sql.Tx, typeID int) error {
+ _, err := tc.ValidateTypeID(tx, &typeID, "regex")
+ return err
+}
+
func Delete(dbx *sqlx.DB) http.HandlerFunc {
db := dbx.DB
return func(w http.ResponseWriter, r *http.Request) {
diff --git a/traffic_ops/traffic_ops_golang/routes.go b/traffic_ops/traffic_ops_golang/routes.go
index 6550c53..13e0470 100644
--- a/traffic_ops/traffic_ops_golang/routes.go
+++ b/traffic_ops/traffic_ops_golang/routes.go
@@ -308,8 +308,8 @@ func Routes(d ServerData) ([]Route, []RawRoute, http.Handler, error) {
{1.1, http.MethodGet, `deliveryservices_regexes/?(\.json)?$`, deliveryservicesregexes.Get(d.DB), auth.PrivLevelReadOnly, Authenticated, nil},
{1.1, http.MethodGet, `deliveryservices/{dsid}/regexes/?(\.json)?$`, deliveryservicesregexes.DSGet(d.DB), auth.PrivLevelReadOnly, Authenticated, nil},
{1.1, http.MethodGet, `deliveryservices/{dsid}/regexes/{regexid}?(\.json)?$`, deliveryservicesregexes.DSGetID(d.DB), auth.PrivLevelReadOnly, Authenticated, nil},
- {1.1, http.MethodPost, `deliveryservices/{dsid}/regexes/?(\.json)?$`, deliveryservicesregexes.Post(d.DB), auth.PrivLevelOperations, Authenticated, nil},
- {1.1, http.MethodPut, `deliveryservices/{dsid}/regexes/{regexid}?(\.json)?$`, deliveryservicesregexes.Put(d.DB), auth.PrivLevelOperations, Authenticated, nil},
+ {1.1, http.MethodPost, `deliveryservices/{dsid}/regexes/?(\.json)?$`, deliveryservicesregexes.Post(), auth.PrivLevelOperations, Authenticated, nil},
+ {1.1, http.MethodPut, `deliveryservices/{dsid}/regexes/{regexid}?(\.json)?$`, deliveryservicesregexes.Put(), auth.PrivLevelOperations, Authenticated, nil},
{1.1, http.MethodDelete, `deliveryservices/{dsid}/regexes/{regexid}?(\.json)?$`, deliveryservicesregexes.Delete(d.DB), auth.PrivLevelOperations, Authenticated, nil},
//Servers
diff --git a/traffic_ops/traffic_ops_golang/server/servers.go b/traffic_ops/traffic_ops_golang/server/servers.go
index 91ef619..e766f0f 100644
--- a/traffic_ops/traffic_ops_golang/server/servers.go
+++ b/traffic_ops/traffic_ops_golang/server/servers.go
@@ -33,7 +33,7 @@ import (
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/auth"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/dbhelpers"
- validation "github.com/go-ozzo/ozzo-validation"
+ "github.com/go-ozzo/ozzo-validation"
"github.com/go-ozzo/ozzo-validation/is"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
@@ -108,24 +108,11 @@ func (server *TOServer) Validate() error {
return util.JoinErrs(errs)
}
- rows, err := server.ReqInfo.Tx.Query("select use_in_table from type where id=$1", server.TypeID)
- if err != nil {
- log.Error.Printf("could not execute select use_in_table from type: %s\n", err)
- return tc.DBError
- }
- defer rows.Close()
- var useInTable string
- for rows.Next() {
- if err := rows.Scan(&useInTable); err != nil {
- log.Error.Printf("could not scan use_in_table from type: %s\n", err)
- return tc.DBError
- }
- }
- if useInTable != "server" {
- errs = append(errs, errors.New("invalid server type"))
+ if _, err := tc.ValidateTypeID(server.ReqInfo.Tx.Tx, server.TypeID, "server"); err != nil {
+ return err
}
- rows, err = server.ReqInfo.Tx.Query("select cdn from profile where id=$1", server.ProfileID)
+ rows, err := server.ReqInfo.Tx.Query("select cdn from profile where id=$1", server.ProfileID)
if err != nil {
log.Error.Printf("could not execute select cdnID from profile: %s\n", err)
errs = append(errs, tc.DBError)
diff --git a/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry.go b/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry.go
index 744514b..70b089b 100644
--- a/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry.go
+++ b/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry.go
@@ -81,6 +81,10 @@ func (staticDNSEntry *TOStaticDNSEntry) SetKeys(keys map[string]interface{}) {
// Validate fulfills the api.Validator interface
func (staticDNSEntry TOStaticDNSEntry) Validate() error {
+ if _, err := tc.ValidateTypeID(staticDNSEntry.ReqInfo.Tx.Tx, &staticDNSEntry.TypeID, "staticdnsentry"); err != nil {
+ return err
+ }
+
errs := validation.Errors{
"host": validation.Validate(staticDNSEntry.Host, validation.Required, is.DNSName),
"address": validation.Validate(staticDNSEntry.Address, validation.Required, is.Host),
diff --git a/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry_test.go b/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry_test.go
index 6d556bf..e773fd1 100644
--- a/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry_test.go
+++ b/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry_test.go
@@ -28,6 +28,8 @@ import (
util "github.com/apache/trafficcontrol/lib/go-util"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/api"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/test"
+ "github.com/jmoiron/sqlx"
+ "gopkg.in/DATA-DOG/go-sqlmock.v1"
)
func TestFuncs(t *testing.T) {
@@ -68,8 +70,25 @@ func TestInterfaces(t *testing.T) {
}
func TestValidate(t *testing.T) {
+ mockDB, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer mockDB.Close()
+
+ db := sqlx.NewDb(mockDB, "sqlmock")
+ defer db.Close()
+
+ rows := sqlmock.NewRows([]string{"name", "use_in_table"})
+ rows.AddRow("A_RECORD", "staticdnsentry")
+
+ mock.ExpectBegin()
+ mock.ExpectQuery("SELECT").WillReturnRows(rows)
+ tx := db.MustBegin()
+
+ reqInfo := api.APIInfo{Tx: tx, CommitTx: util.BoolPtr(false)}
// invalid name, empty domainname
- ts := TOStaticDNSEntry{}
+ ts := TOStaticDNSEntry{ReqInfo: &reqInfo}
errs := util.JoinErrsStr(test.SortErrors(test.SplitErrors(ts.Validate())))
expectedErrs := util.JoinErrsStr([]error{
@@ -83,8 +102,4 @@ func TestValidate(t *testing.T) {
if !reflect.DeepEqual(expectedErrs, errs) {
t.Errorf("expected %s, GOT %s", expectedErrs, errs)
}
- //if err != nil {
- //t.Errorf("expected nil, got %s", err)
- //}
-
}