dewrich closed pull request #2572: Fix TO Go API PUT/POST type checking
URL: https://github.com/apache/trafficcontrol/pull/2572
This is a PR merged from a forked repository.
Because GitHub hides the original diff once a pull request from a fork is
merged, the diff is reproduced below for the sake of provenance:
diff --git a/lib/go-tc/deliveryservices.go b/lib/go-tc/deliveryservices.go
index 1a5d1e4fa..0dd56ca48 100644
--- a/lib/go-tc/deliveryservices.go
+++ b/lib/go-tc/deliveryservices.go
@@ -241,23 +241,6 @@ func (ds *DeliveryServiceNullableV12) Sanitize() {
}
}
-// getTypeData returns the type's name and use_in_table, true/false if the
query returned data, and any error
-func getTypeData(tx *sql.Tx, id int) (string, string, bool, error) {
- name := ""
- var useInTablePtr *string
- if err := tx.QueryRow(`SELECT name, use_in_table from type where
id=$1`, id).Scan(&name, &useInTablePtr); err != nil {
- if err == sql.ErrNoRows {
- return "", "", false, nil
- }
- return "", "", false, errors.New("querying type data: " +
err.Error())
- }
- useInTable := ""
- if useInTablePtr != nil {
- useInTable = *useInTablePtr
- }
- return name, useInTable, true, nil
-}
-
func requiredIfMatchesTypeName(patterns []string, typeName string)
func(interface{}) error {
return func(value interface{}) error {
switch v := value.(type) {
@@ -293,30 +276,20 @@ func requiredIfMatchesTypeName(patterns []string,
typeName string) func(interfac
}
}
-// util.JoinErrs(errs).Error()
-
func (ds *DeliveryServiceNullableV12) validateTypeFields(tx *sql.Tx) error {
// Validate the TypeName related fields below
- typeName := ""
err := error(nil)
DNSRegexType := "^DNS.*$"
HTTPRegexType := "^HTTP.*$"
SteeringRegexType := "^STEERING.*$"
latitudeErr := "Must be a floating point number within the range +-90"
longitudeErr := "Must be a floating point number within the range +-180"
- if ds.TypeID == nil {
- return errors.New("missing type")
- }
- typeName, useInTable, ok, err := getTypeData(tx, *ds.TypeID)
+
+ typeName, err := ValidateTypeID(tx, ds.TypeID, "deliveryservice")
if err != nil {
- return errors.New("getting type name: " + err.Error())
- }
- if !ok {
- return errors.New("type not found")
- }
- if useInTable != "deliveryservice" {
- return errors.New("type is not a valid deliveryservice type")
+ return err
}
+
errs := validation.Errors{
"initialDispersion": validation.Validate(ds.InitialDispersion,
validation.By(requiredIfMatchesTypeName([]string{HTTPRegexType}, typeName)),
diff --git a/lib/go-tc/types.go b/lib/go-tc/types.go
index d8887fc50..206bc8f03 100644
--- a/lib/go-tc/types.go
+++ b/lib/go-tc/types.go
@@ -1,5 +1,10 @@
package tc
+import (
+ "database/sql"
+ "errors"
+)
+
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
@@ -41,3 +46,40 @@ type TypeNullable struct {
Description *string `json:"description" db:"description"`
UseInTable *string `json:"useInTable" db:"use_in_table"`
}
+
+// GetTypeData returns the type's name and use_in_table, true/false if the
query returned data, and any error
+func GetTypeData(tx *sql.Tx, id int) (string, string, bool, error) {
+ name := ""
+ var useInTablePtr *string
+ if err := tx.QueryRow(`SELECT name, use_in_table from type where
id=$1`, id).Scan(&name, &useInTablePtr); err != nil {
+ if err == sql.ErrNoRows {
+ return "", "", false, nil
+ }
+ return "", "", false, errors.New("querying type data: " +
err.Error())
+ }
+ useInTable := ""
+ if useInTablePtr != nil {
+ useInTable = *useInTablePtr
+ }
+ return name, useInTable, true, nil
+}
+
+// ValidateTypeID validates that the typeID references a type with the
expected use_in_table string and
+// returns "" and an error if the typeID is invalid. If valid, the type's name
is returned.
+func ValidateTypeID(tx *sql.Tx, typeID *int, expectedUseInTable string)
(string, error) {
+ if typeID == nil {
+ return "", errors.New("missing type")
+ }
+
+ typeName, useInTable, ok, err := GetTypeData(tx, *typeID)
+ if err != nil {
+ return "", errors.New("validating type: " + err.Error())
+ }
+ if !ok {
+ return "", errors.New("type not found")
+ }
+ if useInTable != expectedUseInTable {
+ return "", errors.New("type is not a valid " +
expectedUseInTable + " type")
+ }
+ return typeName, nil
+}
diff --git a/traffic_ops/traffic_ops_golang/cachegroup/cachegroups.go
b/traffic_ops/traffic_ops_golang/cachegroup/cachegroups.go
index 6a008fcf1..a4c939eb9 100644
--- a/traffic_ops/traffic_ops_golang/cachegroup/cachegroups.go
+++ b/traffic_ops/traffic_ops_golang/cachegroup/cachegroups.go
@@ -149,6 +149,11 @@ func IsValidParentCachegroupID(id *int) bool {
// Validate fulfills the api.Validator interface
func (cg TOCacheGroup) Validate() error {
+
+ if _, err := tc.ValidateTypeID(cg.ReqInfo.Tx.Tx, cg.TypeID,
"cachegroup"); err != nil {
+ return err
+ }
+
validName := validation.NewStringRule(IsValidCacheGroupName, "invalid
characters found - Use alphanumeric . or - or _ .")
validShortName := validation.NewStringRule(IsValidCacheGroupName,
"invalid characters found - Use alphanumeric . or - or _ .")
latitudeErr := "Must be a floating point number within the range +-90"
diff --git a/traffic_ops/traffic_ops_golang/cachegroup/cachegroups_test.go
b/traffic_ops/traffic_ops_golang/cachegroup/cachegroups_test.go
index f330c461f..645f990e3 100644
--- a/traffic_ops/traffic_ops_golang/cachegroup/cachegroups_test.go
+++ b/traffic_ops/traffic_ops_golang/cachegroup/cachegroups_test.go
@@ -153,6 +153,24 @@ func TestInterfaces(t *testing.T) {
}
func TestValidate(t *testing.T) {
+ mockDB, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub
database connection", err)
+ }
+ defer mockDB.Close()
+
+ db := sqlx.NewDb(mockDB, "sqlmock")
+ defer db.Close()
+
+ rows := sqlmock.NewRows([]string{"name", "use_in_table"})
+ rows.AddRow("EDGE_LOC", "cachegroup")
+
+ mock.ExpectBegin()
+ mock.ExpectQuery("SELECT").WillReturnRows(rows)
+ tx := db.MustBegin()
+
+ reqInfo := api.APIInfo{Tx: tx, CommitTx: util.BoolPtr(false)}
+
// invalid name, shortname, loattude, and longitude
id := 1
nm := "not!a!valid!cachegroup"
@@ -162,7 +180,7 @@ func TestValidate(t *testing.T) {
ty := "EDGE_LOC"
ti := 6
lu := tc.TimeNoMod{Time: time.Now()}
- c := TOCacheGroup{CacheGroupNullable: v13.CacheGroupNullable{ID: &id,
+ c := TOCacheGroup{ReqInfo: &reqInfo, CacheGroupNullable:
v13.CacheGroupNullable{ID: &id,
Name: &nm,
ShortName: &sn,
Latitude: &la,
@@ -184,12 +202,15 @@ func TestValidate(t *testing.T) {
t.Errorf("expected %s, got %s", expectedErrs, errs)
}
+ rows.AddRow("EDGE_LOC", "cachegroup")
+ mock.ExpectQuery("SELECT").WillReturnRows(rows)
+
// valid name, shortName latitude, longitude
nm = "This.is.2.a-Valid---Cachegroup."
sn = `awesome-cachegroup`
la = 90.0
lo = 90.0
- c = TOCacheGroup{CacheGroupNullable: v13.CacheGroupNullable{ID: &id,
+ c = TOCacheGroup{ReqInfo: &reqInfo, CacheGroupNullable:
v13.CacheGroupNullable{ID: &id,
Name: &nm,
ShortName: &sn,
Latitude: &la,
@@ -198,7 +219,7 @@ func TestValidate(t *testing.T) {
TypeID: &ti,
LastUpdated: &lu,
}}
- err := c.Validate()
+ err = c.Validate()
if err != nil {
t.Errorf("expected nil, got %s", err)
}
diff --git
a/traffic_ops/traffic_ops_golang/deliveryservicesregexes/deliveryservicesregexes.go
b/traffic_ops/traffic_ops_golang/deliveryservicesregexes/deliveryservicesregexes.go
index 5b8caa07e..b8a515d35 100644
---
a/traffic_ops/traffic_ops_golang/deliveryservicesregexes/deliveryservicesregexes.go
+++
b/traffic_ops/traffic_ops_golang/deliveryservicesregexes/deliveryservicesregexes.go
@@ -272,39 +272,29 @@ ORDER BY dsr.set_number ASC
}
}
-func Post(dbx *sqlx.DB) http.HandlerFunc {
- db := dbx.DB
+func Post() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
- handleErrs := tc.GetHandleErrorsFunc(w, r)
- user, err := auth.GetCurrentUser(r.Context())
- if err != nil {
- handleErrs(http.StatusInternalServerError,
errors.New("unable to retrieve current user from context: "+err.Error()))
- return
- }
- params, err := api.GetCombinedParams(r)
- if err != nil {
- handleErrs(http.StatusInternalServerError,
errors.New("unable to get parameters from request: "+err.Error()))
- return
- }
- dsIDStr, ok := params["dsid"]
- if !ok {
- handleErrs(http.StatusInternalServerError,
errors.New("no deliveryservice ID"))
- return
- }
- dsID, err := strconv.Atoi(dsIDStr)
- if err != nil {
- handleErrs(http.StatusInternalServerError,
errors.New("deliveryservice ID not an integer"))
+ defer r.Body.Close()
+ paramList := []string{"dsid"}
+ inf, userErr, sysErr, errCode := api.NewInfo(r, paramList,
paramList)
+ if userErr != nil || sysErr != nil {
+ api.HandleErr(w, r, errCode, userErr, sysErr)
return
}
+ defer inf.Close()
+ tx := inf.Tx.Tx
+
+ handleErrs := tc.GetHandleErrorsFunc(w, r)
+ dsID := inf.IntParams["dsid"]
dsTenantID := 0
- if err := db.QueryRow(`SELECT tenant_id from deliveryservice
where id = $1`, dsID).Scan(&dsTenantID); err != nil {
+ if err := tx.QueryRow(`SELECT tenant_id from deliveryservice
where id = $1`, dsID).Scan(&dsTenantID); err != nil {
if err != sql.ErrNoRows {
log.Errorln("getting deliveryservice name: " +
err.Error())
}
handleErrs(http.StatusInternalServerError, err)
return
}
- if ok, err := tenant.IsResourceAuthorizedToUser(dsTenantID,
user, dbx); !ok {
+ if ok, err := tenant.IsResourceAuthorizedToUserTx(dsTenantID,
inf.User, tx); !ok {
handleErrs(http.StatusInternalServerError,
errors.New("unauthorized"))
return
} else if err != nil {
@@ -318,21 +308,26 @@ func Post(dbx *sqlx.DB) http.HandlerFunc {
return
}
+ if err := validateDSRegexType(tx, dsr.Type); err != nil {
+ handleErrs(http.StatusBadRequest, err)
+ return
+ }
+
regexID := 0
- if err := db.QueryRow(`INSERT INTO regex (pattern, type) VALUES
($1, $2) RETURNING id`, dsr.Pattern, dsr.Type).Scan(&regexID); err != nil {
+ if err := tx.QueryRow(`INSERT INTO regex (pattern, type) VALUES
($1, $2) RETURNING id`, dsr.Pattern, dsr.Type).Scan(&regexID); err != nil {
log.Errorln("inserting regex: " + err.Error())
handleErrs(http.StatusInternalServerError, err)
return
}
- if _, err := db.Exec(`INSERT INTO deliveryservice_regex
(deliveryservice, regex, set_number) values ($1, $2, $3)`, dsID, regexID,
dsr.SetNumber); err != nil {
+ if _, err := tx.Exec(`INSERT INTO deliveryservice_regex
(deliveryservice, regex, set_number) values ($1, $2, $3)`, dsID, regexID,
dsr.SetNumber); err != nil {
log.Errorln("inserting deliveryservice_regex: " +
err.Error())
handleErrs(http.StatusInternalServerError, err)
return
}
typeName := ""
- if err := db.QueryRow(`SELECT name from type where id = $1`,
dsr.Type).Scan(&typeName); err != nil {
+ if err := tx.QueryRow(`SELECT name from type where id = $1`,
dsr.Type).Scan(&typeName); err != nil {
if err != sql.ErrNoRows {
log.Errorln("getting regex type: " +
err.Error())
}
@@ -347,66 +342,36 @@ func Post(dbx *sqlx.DB) http.HandlerFunc {
TypeName: typeName,
SetNumber: dsr.SetNumber,
}
- resp := struct {
- Response tc.DeliveryServiceIDRegex `json:"response"`
- tc.Alerts
- }{respObj, tc.CreateAlerts(tc.SuccessLevel, "Delivery service
regex creation was successful.")}
- respBts, err := json.Marshal(&resp)
- if err != nil {
- handleErrs(http.StatusInternalServerError,
errors.New("marshalling JSON: "+err.Error()))
- return
- }
- w.Header().Set("Content-Type", "application/json")
- w.Write(respBts)
+ *inf.CommitTx = true
+ api.WriteRespAlertObj(w, r, tc.SuccessLevel, "Delivery service
regex creation was successful.", respObj)
}
}
-func Put(dbx *sqlx.DB) http.HandlerFunc {
- db := dbx.DB
+func Put() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
- handleErrs := tc.GetHandleErrorsFunc(w, r)
- user, err := auth.GetCurrentUser(r.Context())
- if err != nil {
- log.Errorf("unable to retrieve current user from
context: %s", err)
- handleErrs(http.StatusInternalServerError, err)
- return
- }
- params, err := api.GetCombinedParams(r)
- if err != nil {
- log.Errorf("unable to get parameters from request: %s",
err)
- handleErrs(http.StatusInternalServerError, err)
- return
- }
- dsIDStr, ok := params["dsid"]
- if !ok {
- handleErrs(http.StatusInternalServerError, err)
- return
- }
- dsID, err := strconv.Atoi(dsIDStr)
- if err != nil {
- handleErrs(http.StatusInternalServerError, err)
- return
- }
- regexIDStr, ok := params["regexid"]
- if !ok {
- handleErrs(http.StatusInternalServerError,
errors.New("no regex ID"))
- return
- }
- regexID, err := strconv.Atoi(regexIDStr)
- if err != nil {
- handleErrs(http.StatusInternalServerError,
errors.New("Regex ID '"+regexIDStr+"' not an integer"))
+ defer r.Body.Close()
+ paramList := []string{"dsid", "regexid"}
+ inf, userErr, sysErr, errCode := api.NewInfo(r, paramList,
paramList)
+ if userErr != nil || sysErr != nil {
+ api.HandleErr(w, r, errCode, userErr, sysErr)
return
}
+ defer inf.Close()
+ tx := inf.Tx.Tx
+
+ handleErrs := tc.GetHandleErrorsFunc(w, r)
+ dsID := inf.IntParams["dsid"]
+ regexID := inf.IntParams["regexid"]
dsTenantID := 0
- if err := db.QueryRow(`SELECT tenant_id from deliveryservice
where id = $1`, dsID).Scan(&dsTenantID); err != nil {
+ if err := tx.QueryRow(`SELECT tenant_id from deliveryservice
where id = $1`, dsID).Scan(&dsTenantID); err != nil {
if err != sql.ErrNoRows {
log.Errorln("getting deliveryservice name: " +
err.Error())
}
handleErrs(http.StatusInternalServerError, err)
return
}
- if ok, err := tenant.IsResourceAuthorizedToUser(dsTenantID,
user, dbx); !ok {
+ if ok, err := tenant.IsResourceAuthorizedToUserTx(dsTenantID,
inf.User, tx); !ok {
handleErrs(http.StatusInternalServerError,
errors.New("unauthorized"))
return
} else if err != nil {
@@ -419,18 +384,24 @@ func Put(dbx *sqlx.DB) http.HandlerFunc {
handleErrs(http.StatusInternalServerError, err)
return
}
- if _, err := db.Exec(`UPDATE regex SET pattern=$1, type=$2
WHERE id=$3`, dsr.Pattern, dsr.Type, regexID); err != nil {
+
+ if err := validateDSRegexType(tx, dsr.Type); err != nil {
+ handleErrs(http.StatusBadRequest, err)
+ return
+ }
+
+ if _, err := tx.Exec(`UPDATE regex SET pattern=$1, type=$2
WHERE id=$3`, dsr.Pattern, dsr.Type, regexID); err != nil {
log.Errorln("deliveryservicesregexes.Put: updating
regex: " + err.Error())
handleErrs(http.StatusInternalServerError,
errors.New("server error"))
return
}
- if _, err := db.Exec(`UPDATE deliveryservice_regex SET
set_number=$1 WHERE deliveryservice=$2 AND regex=$3`, dsr.SetNumber, dsID,
regexID); err != nil {
+ if _, err := tx.Exec(`UPDATE deliveryservice_regex SET
set_number=$1 WHERE deliveryservice=$2 AND regex=$3`, dsr.SetNumber, dsID,
regexID); err != nil {
log.Errorln("deliveryservicesregexes.Put: updating
deliveryservice_regex: " + err.Error())
handleErrs(http.StatusInternalServerError,
errors.New("server error"))
return
}
typeName := ""
- if err := db.QueryRow(`SELECT name from type where id = $1`,
dsr.Type).Scan(&typeName); err != nil {
+ if err := tx.QueryRow(`SELECT name from type where id = $1`,
dsr.Type).Scan(&typeName); err != nil {
if err != sql.ErrNoRows {
log.Errorln("getting regex type: " +
err.Error())
}
@@ -444,20 +415,17 @@ func Put(dbx *sqlx.DB) http.HandlerFunc {
TypeName: typeName,
SetNumber: dsr.SetNumber,
}
- resp := struct {
- Response tc.DeliveryServiceIDRegex `json:"response"`
- tc.Alerts
- }{respObj, tc.CreateAlerts(tc.SuccessLevel, "Delivery service
regex creation was successful.")}
- respBts, err := json.Marshal(&resp)
- if err != nil {
- handleErrs(http.StatusInternalServerError,
errors.New("marshalling JSON: "+err.Error()))
- return
- }
- w.Header().Set("Content-Type", "application/json")
- w.Write(respBts)
+
+ *inf.CommitTx = true
+ api.WriteRespAlertObj(w, r, tc.SuccessLevel, "Delivery service
regex creation was successful.", respObj)
}
}
+func validateDSRegexType(tx *sql.Tx, typeID int) error {
+ _, err := tc.ValidateTypeID(tx, &typeID, "regex")
+ return err
+}
+
func Delete(dbx *sqlx.DB) http.HandlerFunc {
db := dbx.DB
return func(w http.ResponseWriter, r *http.Request) {
diff --git a/traffic_ops/traffic_ops_golang/routes.go
b/traffic_ops/traffic_ops_golang/routes.go
index baab83ab9..6394ec21b 100644
--- a/traffic_ops/traffic_ops_golang/routes.go
+++ b/traffic_ops/traffic_ops_golang/routes.go
@@ -309,8 +309,8 @@ func Routes(d ServerData) ([]Route, []RawRoute,
http.Handler, error) {
{1.1, http.MethodGet, `deliveryservices_regexes/?(\.json)?$`,
deliveryservicesregexes.Get(d.DB), auth.PrivLevelReadOnly, Authenticated, nil},
{1.1, http.MethodGet,
`deliveryservices/{dsid}/regexes/?(\.json)?$`,
deliveryservicesregexes.DSGet(d.DB), auth.PrivLevelReadOnly, Authenticated,
nil},
{1.1, http.MethodGet,
`deliveryservices/{dsid}/regexes/{regexid}?(\.json)?$`,
deliveryservicesregexes.DSGetID(d.DB), auth.PrivLevelReadOnly, Authenticated,
nil},
- {1.1, http.MethodPost,
`deliveryservices/{dsid}/regexes/?(\.json)?$`,
deliveryservicesregexes.Post(d.DB), auth.PrivLevelOperations, Authenticated,
nil},
- {1.1, http.MethodPut,
`deliveryservices/{dsid}/regexes/{regexid}?(\.json)?$`,
deliveryservicesregexes.Put(d.DB), auth.PrivLevelOperations, Authenticated,
nil},
+ {1.1, http.MethodPost,
`deliveryservices/{dsid}/regexes/?(\.json)?$`, deliveryservicesregexes.Post(),
auth.PrivLevelOperations, Authenticated, nil},
+ {1.1, http.MethodPut,
`deliveryservices/{dsid}/regexes/{regexid}?(\.json)?$`,
deliveryservicesregexes.Put(), auth.PrivLevelOperations, Authenticated, nil},
{1.1, http.MethodDelete,
`deliveryservices/{dsid}/regexes/{regexid}?(\.json)?$`,
deliveryservicesregexes.Delete(d.DB), auth.PrivLevelOperations, Authenticated,
nil},
//Servers
diff --git a/traffic_ops/traffic_ops_golang/server/servers.go
b/traffic_ops/traffic_ops_golang/server/servers.go
index 91ef61967..e766f0fa5 100644
--- a/traffic_ops/traffic_ops_golang/server/servers.go
+++ b/traffic_ops/traffic_ops_golang/server/servers.go
@@ -33,7 +33,7 @@ import (
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/auth"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/dbhelpers"
- validation "github.com/go-ozzo/ozzo-validation"
+ "github.com/go-ozzo/ozzo-validation"
"github.com/go-ozzo/ozzo-validation/is"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
@@ -108,24 +108,11 @@ func (server *TOServer) Validate() error {
return util.JoinErrs(errs)
}
- rows, err := server.ReqInfo.Tx.Query("select use_in_table from type
where id=$1", server.TypeID)
- if err != nil {
- log.Error.Printf("could not execute select use_in_table from
type: %s\n", err)
- return tc.DBError
- }
- defer rows.Close()
- var useInTable string
- for rows.Next() {
- if err := rows.Scan(&useInTable); err != nil {
- log.Error.Printf("could not scan use_in_table from
type: %s\n", err)
- return tc.DBError
- }
- }
- if useInTable != "server" {
- errs = append(errs, errors.New("invalid server type"))
+ if _, err := tc.ValidateTypeID(server.ReqInfo.Tx.Tx, server.TypeID,
"server"); err != nil {
+ return err
}
- rows, err = server.ReqInfo.Tx.Query("select cdn from profile where
id=$1", server.ProfileID)
+ rows, err := server.ReqInfo.Tx.Query("select cdn from profile where
id=$1", server.ProfileID)
if err != nil {
log.Error.Printf("could not execute select cdnID from profile:
%s\n", err)
errs = append(errs, tc.DBError)
diff --git a/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry.go
b/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry.go
index 5f213b006..bcadeae87 100644
--- a/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry.go
+++ b/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry.go
@@ -80,6 +80,10 @@ func (staticDNSEntry *TOStaticDNSEntry) SetKeys(keys
map[string]interface{}) {
// Validate fulfills the api.Validator interface
func (staticDNSEntry TOStaticDNSEntry) Validate() error {
+ if _, err := tc.ValidateTypeID(staticDNSEntry.ReqInfo.Tx.Tx,
&staticDNSEntry.TypeID, "staticdnsentry"); err != nil {
+ return err
+ }
+
errs := validation.Errors{
"host": validation.Validate(staticDNSEntry.Host,
validation.Required),
"address": validation.Validate(staticDNSEntry.Address,
validation.Required),
diff --git
a/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry_test.go
b/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry_test.go
index 122062a48..af227dfad 100644
--- a/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry_test.go
+++ b/traffic_ops/traffic_ops_golang/staticdnsentry/staticdnsentry_test.go
@@ -28,6 +28,8 @@ import (
util "github.com/apache/trafficcontrol/lib/go-util"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/api"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/test"
+ "github.com/jmoiron/sqlx"
+ "gopkg.in/DATA-DOG/go-sqlmock.v1"
)
func TestFuncs(t *testing.T) {
@@ -68,8 +70,25 @@ func TestInterfaces(t *testing.T) {
}
func TestValidate(t *testing.T) {
+ mockDB, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub
database connection", err)
+ }
+ defer mockDB.Close()
+
+ db := sqlx.NewDb(mockDB, "sqlmock")
+ defer db.Close()
+
+ rows := sqlmock.NewRows([]string{"name", "use_in_table"})
+ rows.AddRow("A_RECORD", "staticdnsentry")
+
+ mock.ExpectBegin()
+ mock.ExpectQuery("SELECT").WillReturnRows(rows)
+ tx := db.MustBegin()
+
+ reqInfo := api.APIInfo{Tx: tx, CommitTx: util.BoolPtr(false)}
// invalid name, empty domainname
- ts := TOStaticDNSEntry{}
+ ts := TOStaticDNSEntry{ReqInfo: &reqInfo}
errs :=
util.JoinErrsStr(test.SortErrors(test.SplitErrors(ts.Validate())))
expectedErrs := util.JoinErrsStr([]error{
@@ -83,8 +102,4 @@ func TestValidate(t *testing.T) {
if !reflect.DeepEqual(expectedErrs, errs) {
t.Errorf("expected %s, GOT %s", expectedErrs, errs)
}
- //if err != nil {
- //t.Errorf("expected nil, got %s", err)
- //}
-
}
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services