rob05c closed pull request #2375: Fix go delivery service API validation
URL: https://github.com/apache/incubator-trafficcontrol/pull/2375
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git 
a/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv12.go 
b/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv12.go
index 67eb3ce5f..7184c7137 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv12.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv12.go
@@ -218,19 +218,17 @@ func validateV12(db *sqlx.DB, ds 
*tc.DeliveryServiceNullableV12) []error {
                "typeId":              validation.Validate(ds.TypeID, 
validation.Required, validation.Min(1)),
                "xmlId":               validation.Validate(ds.XMLID, noSpaces, 
noPeriods, validation.Length(1, 48)),
        }
-       if errs != nil {
-               return tovalidate.ToErrors(errs)
+       toErrs := tovalidate.ToErrors(errs)
+       if fieldErrs := validateTypeFields(db, ds); len(fieldErrs) > 0 {
+               toErrs = append(toErrs, fieldErrs...)
        }
-       errsResponse := validateTypeFields(db, ds)
-       if errsResponse != nil {
-               return errsResponse
+       if len(toErrs) > 0 {
+               return toErrs
        }
-
        return nil
 }
 
 func validateTypeFields(db *sqlx.DB, ds *tc.DeliveryServiceNullableV12) 
[]error {
-       fmt.Printf("validateTypeFields\n")
        // Validate the TypeName related fields below
        var typeName string
        var err error
@@ -238,43 +236,67 @@ func validateTypeFields(db *sqlx.DB, ds 
*tc.DeliveryServiceNullableV12) []error
        HTTPRegexType := "^HTTP.*$"
        SteeringRegexType := "^STEERING.*$"
 
-       if db != nil && ds == nil || ds.TypeID != nil {
-               typeID := *ds.TypeID
-               typeName, err = getTypeName(db, typeID)
-               if err != nil {
-                       return []error{err}
-               }
+       if ds.TypeID == nil {
+               return []error{errors.New("missing typeID")}
        }
 
-       if typeName != "" {
-               errs := validation.Errors{
-                       "initialDispersion": 
validation.Validate(ds.InitialDispersion,
-                               
validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, 
typeName))),
-                       "ipv6RoutingEnabled": 
validation.Validate(ds.IPV6RoutingEnabled,
-                               
validation.By(requiredIfMatchesTypeName([]string{SteeringRegexType, 
DNSRegexType, HTTPRegexType}, typeName))),
-                       "missLat": validation.Validate(ds.MissLat,
-                               
validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, 
typeName))),
-                       "missLong": validation.Validate(ds.MissLong,
-                               
validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, 
typeName))),
-                       "multiSiteOrigin": 
validation.Validate(ds.MultiSiteOrigin,
-                               
validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, 
typeName))),
-                       "orgServerFqdn": validation.Validate(ds.OrgServerFQDN,
-                               
validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, 
typeName))),
-                       "protocol": validation.Validate(ds.Protocol,
-                               
validation.By(requiredIfMatchesTypeName([]string{SteeringRegexType, 
DNSRegexType, HTTPRegexType}, typeName))),
-                       "qstringIgnore": validation.Validate(ds.QStringIgnore,
-                               
validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, 
typeName))),
-                       "rangeRequestHandling": 
validation.Validate(ds.RangeRequestHandling,
-                               
validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, 
typeName))),
-               }
-               return tovalidate.ToErrors(errs)
+       typeName, ok, err := getTypeName(db, *ds.TypeID)
+       if err != nil {
+               return []error{err}
+       }
+       if !ok {
+               return []error{errors.New("type not found")}
+       }
+
+       errs := validation.Errors{
+               "initialDispersion": validation.Validate(ds.InitialDispersion,
+                       
validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, 
typeName))),
+               "ipv6RoutingEnabled": validation.Validate(ds.IPV6RoutingEnabled,
+                       
validation.By(requiredIfMatchesTypeName([]string{SteeringRegexType, 
DNSRegexType, HTTPRegexType}, typeName))),
+               "missLat": validation.Validate(ds.MissLat,
+                       
validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, 
typeName))),
+               "missLong": validation.Validate(ds.MissLong,
+                       
validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, 
typeName))),
+               "multiSiteOrigin": validation.Validate(ds.MultiSiteOrigin,
+                       
validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, 
typeName))),
+               "orgServerFqdn": validation.Validate(ds.OrgServerFQDN,
+                       
validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, 
typeName))),
+               "protocol": validation.Validate(ds.Protocol,
+                       
validation.By(requiredIfMatchesTypeName([]string{SteeringRegexType, 
DNSRegexType, HTTPRegexType}, typeName))),
+               "qstringIgnore": validation.Validate(ds.QStringIgnore,
+                       
validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, 
typeName))),
+               "rangeRequestHandling": 
validation.Validate(ds.RangeRequestHandling,
+                       
validation.By(requiredIfMatchesTypeName([]string{DNSRegexType, HTTPRegexType}, 
typeName))),
+       }
+       toErrs := tovalidate.ToErrors(errs)
+       if len(toErrs) > 0 {
+               return toErrs
        }
        return nil
 }
 
 func requiredIfMatchesTypeName(patterns []string, typeName string) 
func(interface{}) error {
        return func(value interface{}) error {
-
+               switch v := value.(type) {
+               case *int:
+                       if v != nil {
+                               return nil
+                       }
+               case *bool:
+                       if v != nil {
+                               return nil
+                       }
+               case *string:
+                       if v != nil {
+                               return nil
+                       }
+               case *float64:
+                       if v != nil {
+                               return nil
+                       }
+               default:
+                       return fmt.Errorf("validation failure: unknown type 
%T", value)
+               }
                pattern := strings.Join(patterns, "|")
                var err error
                var match bool
@@ -288,31 +310,15 @@ func requiredIfMatchesTypeName(patterns []string, 
typeName string) func(interfac
        }
 }
 
-// TODO: drichardson - refactor to the types.go once implemented.
-func getTypeName(db *sqlx.DB, typeID int) (string, error) {
-
-       query := `SELECT name from type where id=$1`
-
-       var rows *sqlx.Rows
-       var err error
-
-       rows, err = db.Queryx(query, typeID)
-       if err != nil {
-               return "", err
-       }
-       defer rows.Close()
-
-       typeResults := []tc.Type{}
-       for rows.Next() {
-               var s tc.Type
-               if err = rows.StructScan(&s); err != nil {
-                       return "", fmt.Errorf("getting Type: %v", err)
+func getTypeName(db *sqlx.DB, typeID int) (string, bool, error) {
+       name := ""
+       if err := db.QueryRow(`SELECT name from type where id=$1`, 
typeID).Scan(&name); err != nil {
+               if err == sql.ErrNoRows {
+                       return "", false, nil
                }
-               typeResults = append(typeResults, s)
+               return "", false, errors.New("querying type name: " + 
err.Error())
        }
-
-       typeName := typeResults[0].Name
-       return typeName, err
+       return name, true, nil
 }
 
 func CreateV12(db *sqlx.DB, cfg config.Config) http.HandlerFunc {
diff --git 
a/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv13.go 
b/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv13.go
index 55727a77e..6244acf8a 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv13.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservice/deliveryservicesv13.go
@@ -375,6 +375,11 @@ func UpdateV13(db *sqlx.DB, cfg config.Config) 
http.HandlerFunc {
                }
                ds.ID = &id
 
+               if errs := validateV13(db, &ds); len(errs) > 0 {
+                       api.HandleErr(w, r, http.StatusBadRequest, 
errors.New("invalid request: "+util.JoinErrs(errs).Error()), nil)
+                       return
+               }
+
                if authorized, err := isTenantAuthorized(*user, db, 
&ds.DeliveryServiceNullableV12); err != nil {
                        api.HandleErr(w, r, http.StatusInternalServerError, 
nil, errors.New("checking tenant: "+err.Error()))
                        return


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to