This is an automated email from the ASF dual-hosted git repository.

ocket8888 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/trafficcontrol.git


The following commit(s) were added to refs/heads/master by this push:
     new c0af35f  updated Lets Encrypt endpoint to perform checks (#4540)
c0af35f is described below

commit c0af35f0c82d3cb22e168a63372d3730e2af1670
Author: mattjackson220 <[email protected]>
AuthorDate: Thu Mar 26 15:16:17 2020 -0600

    updated Let's Encrypt endpoint to perform checks (#4540)
    
    * updated Let's Encrypt endpoint to perform checks
    
    * update per comment
    
    * updated godoc
    
    Co-authored-by: mjacks258 <[email protected]>
---
 .../traffic_ops_golang/dbhelpers/db_helpers.go     |  30 +++++
 .../dbhelpers/db_helpers_test.go                   | 121 +++++++++++++++++++++
 .../deliveryservice/letsencryptcert.go             |  31 ++++++
 3 files changed, 182 insertions(+)

diff --git a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go 
b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
index 0a7f732..aa9af11 100644
--- a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
+++ b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
@@ -324,6 +324,24 @@ WHERE ds.id = $1
        return name, cdn, true, nil
 }
 
+// GetDSIDAndCDNFromName returns the delivery service ID and cdn name given 
from the delivery service name, whether a result existed, and any error.
+func GetDSIDAndCDNFromName(tx *sql.Tx, xmlID string) (int, tc.CDNName, bool, 
error) {
+       dsId := 0
+       cdn := tc.CDNName("")
+       if err := tx.QueryRow(`
+SELECT ds.id, cdn.name
+FROM deliveryservice as ds
+JOIN cdn on cdn.id = ds.cdn_id
+WHERE ds.xml_id = $1
+`, xmlID).Scan(&dsId, &cdn); err != nil {
+               if err == sql.ErrNoRows {
+                       return dsId, tc.CDNName(""), false, nil
+               }
+               return dsId, tc.CDNName(""), false, errors.New("querying 
delivery service name: " + err.Error())
+       }
+       return dsId, cdn, true, nil
+}
+
 // GetFederationResolversByFederationID fetches all of the federation 
resolvers currently assigned to a federation.
 // In the event of an error, it will return an empty slice and the error.
 func GetFederationResolversByFederationID(tx *sql.Tx, fedID int) 
([]tc.FederationResolver, error) {
@@ -463,6 +481,18 @@ func GetCDNNameFromID(tx *sql.Tx, id int64) (tc.CDNName, 
bool, error) {
        return tc.CDNName(name), true, nil
 }
 
+// GetCDNIDFromName returns the ID of the CDN if a CDN with the name exists
+func GetCDNIDFromName(tx *sql.Tx, name tc.CDNName) (int, bool, error) {
+       id := 0
+       if err := tx.QueryRow(`SELECT id FROM cdn WHERE name = $1`, 
name).Scan(&id); err != nil {
+               if err == sql.ErrNoRows {
+                       return id, false, nil
+               }
+               return id, false, errors.New("querying CDN ID: " + err.Error())
+       }
+       return id, true, nil
+}
+
 // GetCDNDomainFromName returns the domain, whether the cdn exists, and any 
error.
 func GetCDNDomainFromName(tx *sql.Tx, cdnName tc.CDNName) (string, bool, 
error) {
        domain := ""
diff --git a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go 
b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go
index 8ec02a3..3826d7a 100644
--- a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go
+++ b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go
@@ -140,3 +140,124 @@ func TestGetCacheGroupByName(t *testing.T) {
        }
 
 }
+
+func TestGetDSIDAndCDNFromName(t *testing.T) {
+       var testCases = []struct {
+               description  string
+               storageError error
+               found        bool
+       }{
+               {
+                       description:  "Success: DS ID and CDN Name found",
+                       storageError: nil,
+                       found:        true,
+               },
+               {
+                       description:  "Failure: DS ID or CDN Name not found",
+                       storageError: nil,
+                       found:        false,
+               },
+               {
+                       description:  "Failure: Storage error getting DS ID or 
CDN Name",
+                       storageError: errors.New("error getting the delivery 
service ID or the CDN name"),
+                       found:        false,
+               },
+       }
+       for _, testCase := range testCases {
+               t.Run(testCase.description, func(t *testing.T) {
+                       t.Log("Starting test scenario: ", testCase.description)
+                       mockDB, mock, err := sqlmock.New()
+                       if err != nil {
+                               t.Fatalf("an error '%s' was not expected when 
opening a stub database connection", err)
+                       }
+                       defer mockDB.Close()
+                       db := sqlx.NewDb(mockDB, "sqlmock")
+                       defer db.Close()
+                       rows := sqlmock.NewRows([]string{
+                               "id",
+                               "name",
+                       })
+                       mock.ExpectBegin()
+                       if testCase.storageError != nil {
+                               
mock.ExpectQuery("SELECT").WillReturnError(testCase.storageError)
+                       } else {
+                               if testCase.found {
+                                       rows = rows.AddRow(1, "testCdn")
+                               }
+                               mock.ExpectQuery("SELECT").WillReturnRows(rows)
+                       }
+                       mock.ExpectCommit()
+                       _, _, exists, err := 
GetDSIDAndCDNFromName(db.MustBegin().Tx, "testDs")
+                       if testCase.storageError != nil && err == nil {
+                               t.Errorf("Storage error expected: received no 
storage error")
+                       }
+                       if testCase.storageError == nil && err != nil {
+                               t.Errorf("Storage error not expected: received 
storage error")
+                       }
+                       if testCase.found != exists {
+                               t.Errorf("Expected return exists: %t, actual 
%t", testCase.found, exists)
+                       }
+               })
+       }
+
+}
+
+func TestGetCDNIDFromName(t *testing.T) {
+       var testCases = []struct {
+               description  string
+               storageError error
+               found        bool
+       }{
+               {
+                       description:  "Success: CDN ID found",
+                       storageError: nil,
+                       found:        true,
+               },
+               {
+                       description:  "Failure: CDN ID not found",
+                       storageError: nil,
+                       found:        false,
+               },
+               {
+                       description:  "Failure: Storage error getting CDN ID",
+                       storageError: errors.New("error getting the CDN ID"),
+                       found:        false,
+               },
+       }
+       for _, testCase := range testCases {
+               t.Run(testCase.description, func(t *testing.T) {
+                       t.Log("Starting test scenario: ", testCase.description)
+                       mockDB, mock, err := sqlmock.New()
+                       if err != nil {
+                               t.Fatalf("an error '%s' was not expected when 
opening a stub database connection", err)
+                       }
+                       defer mockDB.Close()
+                       db := sqlx.NewDb(mockDB, "sqlmock")
+                       defer db.Close()
+                       rows := sqlmock.NewRows([]string{
+                               "id",
+                       })
+                       mock.ExpectBegin()
+                       if testCase.storageError != nil {
+                               
mock.ExpectQuery("SELECT").WillReturnError(testCase.storageError)
+                       } else {
+                               if testCase.found {
+                                       rows = rows.AddRow(1)
+                               }
+                               mock.ExpectQuery("SELECT").WillReturnRows(rows)
+                       }
+                       mock.ExpectCommit()
+                       _, exists, err := GetCDNIDFromName(db.MustBegin().Tx, 
"testCdn")
+                       if testCase.storageError != nil && err == nil {
+                               t.Errorf("Storage error expected: received no 
storage error")
+                       }
+                       if testCase.storageError == nil && err != nil {
+                               t.Errorf("Storage error not expected: received 
storage error")
+                       }
+                       if testCase.found != exists {
+                               t.Errorf("Expected return exists: %t, actual 
%t", testCase.found, exists)
+                       }
+               })
+       }
+
+}
diff --git a/traffic_ops/traffic_ops_golang/deliveryservice/letsencryptcert.go 
b/traffic_ops/traffic_ops_golang/deliveryservice/letsencryptcert.go
index 64e1aff..fc7ff49 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservice/letsencryptcert.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservice/letsencryptcert.go
@@ -38,7 +38,9 @@ import (
        "github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/api"
        "github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/auth"
        "github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/config"
+       
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/dbhelpers"
        
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/riaksvc"
+       "github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/tenant"
        "github.com/go-acme/lego/certcrypto"
        "github.com/go-acme/lego/certificate"
        "github.com/go-acme/lego/challenge"
@@ -149,6 +151,35 @@ func GenerateLetsEncryptCertificates(w 
http.ResponseWriter, r *http.Request) {
                req.DeliveryService = req.Key
        }
 
+       dsID, cdnName, ok, err := dbhelpers.GetDSIDAndCDNFromName(inf.Tx.Tx, 
*req.DeliveryService)
+       if err != nil {
+               api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, 
nil, errors.New("deliveryservice.GenerateLetsEncryptCertificates: getting DS ID 
from name "+err.Error()))
+               return
+       } else if !ok {
+               api.HandleErr(w, r, inf.Tx.Tx, http.StatusNotFound, 
errors.New("no DS with name "+*req.DeliveryService), nil)
+               return
+       }
+
+       userErr, sysErr, errCode = tenant.CheckID(inf.Tx.Tx, inf.User, dsID)
+       if userErr != nil || sysErr != nil {
+               api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
+               return
+       }
+
+       _, ok, err = dbhelpers.GetCDNIDFromName(inf.Tx.Tx, tc.CDNName(*req.CDN))
+       if err != nil {
+               api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, 
nil, errors.New("checking CDN existence: "+err.Error()))
+               return
+       } else if !ok {
+               api.HandleErr(w, r, inf.Tx.Tx, http.StatusNotFound, 
errors.New("cdn not found with name "+*req.CDN), nil)
+               return
+       }
+
+       if cdnName != tc.CDNName(*req.CDN) {
+               api.HandleErr(w, r, inf.Tx.Tx, http.StatusBadRequest, 
errors.New("delivery service not in cdn"), nil)
+               return
+       }
+
        go GetLetsEncryptCertificates(inf.Config, req, ctx, inf.User)
 
        api.WriteRespAlert(w, r, tc.InfoLevel, "Beginning async call to Let's 
Encrypt for "+*req.DeliveryService+".  This may take a few minutes.")

Reply via email to