This is an automated email from the ASF dual-hosted git repository.

ocket8888 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/trafficcontrol.git


The following commit(s) were added to refs/heads/master by this push:
     new 74bc7f2  Fix API (deliveryserviceserver and deliveryservices/dsName/servers) should not assign Server from different CDN to Delivery Service (#4754)
74bc7f2 is described below

commit 74bc7f2b12ce715aa6e44323af815ea9a68c43fa
Author: Srijeet Chatterjee <[email protected]>
AuthorDate: Thu Jun 11 10:55:59 2020 -0600

    Fix API (deliveryserviceserver and deliveryservices/dsName/servers) should not assign Server from different CDN to Delivery Service (#4754)
    
    * Fix API (deliveryserviceserver and deliveryservices/dsName/servers) should not assign Server from different CDN to Delivery Service
    
    * Formatting
    
    * Code review
---
 .../traffic_ops_golang/dbhelpers/db_helpers.go     | 24 ++++++----
 .../deliveryservice/servers/servers.go             | 33 ++++++++------
 .../deliveryservice/servers/servers_test.go        | 53 ++++++++++++++++++++++
 3 files changed, 87 insertions(+), 23 deletions(-)

diff --git a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go 
b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
index 8ee6c28..cdb46f0 100644
--- a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
+++ b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
@@ -644,15 +644,18 @@ func GetServerNameFromID(tx *sql.Tx, id int) (string, 
bool, error) {
        return name, true, nil
 }
 
-type ServerHostNameAndType struct {
+type ServerHostNameCDNIDAndType struct {
        HostName string
+       CDNID    int
        Type     string
 }
 
-func GetServerHostNamesAndTypesFromIDs(tx *sql.Tx, ids []int) 
([]ServerHostNameAndType, error) {
+// GetServerHostNamesAndTypesFromIDs returns the server's hostname, cdn ID and 
associated type name
+func GetServerHostNamesAndTypesFromIDs(tx *sql.Tx, ids []int) 
([]ServerHostNameCDNIDAndType, error) {
        qry := `
 SELECT
   s.host_name,
+  s.cdn_id,
   t.name
 FROM
   server s JOIN type t ON s.type = t.id
@@ -665,10 +668,10 @@ WHERE
        }
        defer log.Close(rows, "error closing rows")
 
-       servers := []ServerHostNameAndType{}
+       servers := []ServerHostNameCDNIDAndType{}
        for rows.Next() {
-               s := ServerHostNameAndType{}
-               if err := rows.Scan(&s.HostName, &s.Type); err != nil {
+               s := ServerHostNameCDNIDAndType{}
+               if err := rows.Scan(&s.HostName, &s.CDNID, &s.Type); err != nil 
{
                        return nil, errors.New("scanning server host name and 
type: " + err.Error())
                }
                servers = append(servers, s)
@@ -676,11 +679,12 @@ WHERE
        return servers, nil
 }
 
-// GetServerTypesFromHostNames returns the host names and types of the given 
server host names or an error if any occur.
-func GetServerTypesFromHostNames(tx *sql.Tx, hostNames []string) 
([]ServerHostNameAndType, error) {
+// GetServerTypesCdnIdFromHostNames returns the host names, server cdn and 
types of the given server host names or an error if any occur.
+func GetServerTypesCdnIdFromHostNames(tx *sql.Tx, hostNames []string) 
([]ServerHostNameCDNIDAndType, error) {
        qry := `
 SELECT
   s.host_name,
+  s.cdn_id,
   t.name
 FROM
   server s JOIN type t ON s.type = t.id
@@ -693,10 +697,10 @@ WHERE
        }
        defer log.Close(rows, "error closing rows")
 
-       servers := []ServerHostNameAndType{}
+       servers := []ServerHostNameCDNIDAndType{}
        for rows.Next() {
-               s := ServerHostNameAndType{}
-               if err := rows.Scan(&s.HostName, &s.Type); err != nil {
+               s := ServerHostNameCDNIDAndType{}
+               if err := rows.Scan(&s.HostName, &s.CDNID, &s.Type); err != nil 
{
                        return nil, errors.New("scanning server host name and 
type: " + err.Error())
                }
                servers = append(servers, s)
diff --git a/traffic_ops/traffic_ops_golang/deliveryservice/servers/servers.go 
b/traffic_ops/traffic_ops_golang/deliveryservice/servers/servers.go
index 191fb17..99d4bc4 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservice/servers/servers.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservice/servers/servers.go
@@ -319,19 +319,18 @@ func GetReplaceHandler(w http.ResponseWriter, r 
*http.Request) {
                api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
                return
        }
-       serverNamesAndTypes, err := 
dbhelpers.GetServerHostNamesAndTypesFromIDs(inf.Tx.Tx, servers)
+       serverNamesCdnIdAndTypes, err := 
dbhelpers.GetServerHostNamesAndTypesFromIDs(inf.Tx.Tx, servers)
        if err != nil {
                api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, 
err, nil)
                return
        }
-
-       userErr = ValidateDSSAssignments(ds, serverNamesAndTypes)
+       userErr = ValidateDSSAssignments(ds, serverNamesCdnIdAndTypes)
        if userErr != nil {
                api.HandleErr(w, r, inf.Tx.Tx, http.StatusBadRequest, userErr, 
nil)
                return
        }
 
-       usrErr, sysErr, status := ValidateServerCapabilities(ds.ID, 
serverNamesAndTypes, inf.Tx.Tx)
+       usrErr, sysErr, status := ValidateServerCapabilities(ds.ID, 
serverNamesCdnIdAndTypes, inf.Tx.Tx)
        if usrErr != nil || sysErr != nil {
                api.HandleErr(w, r, inf.Tx.Tx, status, usrErr, sysErr)
                return
@@ -401,19 +400,19 @@ func GetCreateHandler(w http.ResponseWriter, r 
*http.Request) {
        payload.XmlId = dsName
        serverNames := payload.ServerNames
 
-       serverNamesAndTypes, err := 
dbhelpers.GetServerTypesFromHostNames(inf.Tx.Tx, serverNames)
+       serverNamesCdnIdAndTypes, err := 
dbhelpers.GetServerTypesCdnIdFromHostNames(inf.Tx.Tx, serverNames)
        if err != nil {
                api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, 
err, nil)
                return
        }
 
-       userErr = ValidateDSSAssignments(ds, serverNamesAndTypes)
+       userErr = ValidateDSSAssignments(ds, serverNamesCdnIdAndTypes)
        if userErr != nil {
                api.HandleErr(w, r, inf.Tx.Tx, http.StatusBadRequest, userErr, 
nil)
                return
        }
 
-       usrErr, sysErr, status := ValidateServerCapabilities(ds.ID, 
serverNamesAndTypes, inf.Tx.Tx)
+       usrErr, sysErr, status := ValidateServerCapabilities(ds.ID, 
serverNamesCdnIdAndTypes, inf.Tx.Tx)
        if usrErr != nil || sysErr != nil {
                api.HandleErr(w, r, inf.Tx.Tx, status, usrErr, sysErr)
                return
@@ -445,8 +444,13 @@ func GetCreateHandler(w http.ResponseWriter, r 
*http.Request) {
 }
 
 // ValidateDSSAssignments returns an error if the given servers cannot be 
assigned to the given delivery service.
-func ValidateDSSAssignments(ds DSInfo, servers 
[]dbhelpers.ServerHostNameAndType) error {
+func ValidateDSSAssignments(ds DSInfo, servers 
[]dbhelpers.ServerHostNameCDNIDAndType) error {
        if ds.Topology == nil {
+               for _, s := range servers {
+                       if ds.CDNID != nil && s.CDNID != *ds.CDNID {
+                               return errors.New("server and delivery service 
CDNs do not match")
+                       }
+               }
                return nil
        }
        for _, s := range servers {
@@ -458,7 +462,7 @@ func ValidateDSSAssignments(ds DSInfo, servers 
[]dbhelpers.ServerHostNameAndType
 }
 
 // ValidateServerCapabilities checks that the delivery service's requirements 
are met by each server to be assigned.
-func ValidateServerCapabilities(dsID int, serverNamesAndTypes 
[]dbhelpers.ServerHostNameAndType, tx *sql.Tx) (error, error, int) {
+func ValidateServerCapabilities(dsID int, serverNamesAndTypes 
[]dbhelpers.ServerHostNameCDNIDAndType, tx *sql.Tx) (error, error, int) {
        nonOriginServerNames := []string{}
        for _, s := range serverNamesAndTypes {
                if strings.HasPrefix(s.Type, tc.EdgeTypePrefix) {
@@ -682,6 +686,7 @@ type DSInfo struct {
        CacheURL             *string
        MaxOriginConnections *int
        Topology             *string
+       CDNID                *int
 }
 
 // GetDSInfo loads the DeliveryService fields needed by Delivery Service 
Servers from the database, from the ID. Returns the data, whether the delivery 
service was found, and any error.
@@ -696,7 +701,8 @@ SELECT
   ds.signing_algorithm,
   ds.cacheurl,
   ds.max_origin_connections,
-  ds.topology
+  ds.topology,
+  ds.cdn_id
 FROM
   deliveryservice ds
   JOIN type tp ON ds.type = tp.id
@@ -704,7 +710,7 @@ WHERE
   ds.id = $1
 `
        di := DSInfo{ID: id}
-       if err := tx.QueryRow(qry, id).Scan(&di.Name, &di.Type, 
&di.EdgeHeaderRewrite, &di.MidHeaderRewrite, &di.RegexRemap, 
&di.SigningAlgorithm, &di.CacheURL, &di.MaxOriginConnections, &di.Topology); 
err != nil {
+       if err := tx.QueryRow(qry, id).Scan(&di.Name, &di.Type, 
&di.EdgeHeaderRewrite, &di.MidHeaderRewrite, &di.RegexRemap, 
&di.SigningAlgorithm, &di.CacheURL, &di.MaxOriginConnections, &di.Topology, 
&di.CDNID); err != nil {
                if err == sql.ErrNoRows {
                        return DSInfo{}, false, nil
                }
@@ -726,7 +732,8 @@ SELECT
   ds.signing_algorithm,
   ds.cacheurl,
   ds.max_origin_connections,
-  ds.topology
+  ds.topology,
+  ds.cdn_id
 FROM
   deliveryservice ds
   JOIN type tp ON ds.type = tp.id
@@ -734,7 +741,7 @@ WHERE
   ds.xml_id = $1
 `
        di := DSInfo{Name: dsName}
-       if err := tx.QueryRow(qry, dsName).Scan(&di.ID, &di.Type, 
&di.EdgeHeaderRewrite, &di.MidHeaderRewrite, &di.RegexRemap, 
&di.SigningAlgorithm, &di.CacheURL, &di.MaxOriginConnections, &di.Topology); 
err != nil {
+       if err := tx.QueryRow(qry, dsName).Scan(&di.ID, &di.Type, 
&di.EdgeHeaderRewrite, &di.MidHeaderRewrite, &di.RegexRemap, 
&di.SigningAlgorithm, &di.CacheURL, &di.MaxOriginConnections, &di.Topology, 
&di.CDNID); err != nil {
                if err == sql.ErrNoRows {
                        return DSInfo{}, false, nil
                }
diff --git 
a/traffic_ops/traffic_ops_golang/deliveryservice/servers/servers_test.go 
b/traffic_ops/traffic_ops_golang/deliveryservice/servers/servers_test.go
new file mode 100644
index 0000000..319dfc5
--- /dev/null
+++ b/traffic_ops/traffic_ops_golang/deliveryservice/servers/servers_test.go
@@ -0,0 +1,53 @@
+package servers
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import (
+       
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/dbhelpers"
+       "testing"
+)
+
+func TestValidateDSSAssignments(t *testing.T) {
+       expected := `server and delivery service CDNs do not match`
+       cdnID := 1
+       ds := DSInfo{
+               ID:    0,
+               CDNID: &cdnID,
+       }
+       var servers []dbhelpers.ServerHostNameCDNIDAndType
+       server := dbhelpers.ServerHostNameCDNIDAndType{
+               HostName: "serverHost",
+               CDNID:    0,
+               Type:     "",
+       }
+       servers = append(servers, server)
+       userErr := ValidateDSSAssignments(ds, servers)
+       if userErr == nil {
+               t.Fatalf("Expected user error with mismatching ds and server 
CDN IDs, got no error instead")
+       }
+       if userErr.Error() != expected {
+               t.Errorf("Expected error details %v, got %v", expected, 
userErr.Error())
+       }
+       servers[0].CDNID = 1
+       userErr = ValidateDSSAssignments(ds, servers)
+       if userErr != nil {
+               t.Fatalf("Expected no user error, got %v", userErr.Error())
+       }
+}

Reply via email to