You are viewing a plain text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@trafficcontrol.apache.org by oc...@apache.org on 2020/03/26 21:16:26 UTC

[trafficcontrol] branch master updated: updated Lets Encrypt endpoint to perform checks (#4540)

This is an automated email from the ASF dual-hosted git repository.

ocket8888 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/trafficcontrol.git


The following commit(s) were added to refs/heads/master by this push:
     new c0af35f  updated Lets Encrypt endpoint to perform checks (#4540)
c0af35f is described below

commit c0af35f0c82d3cb22e168a63372d3730e2af1670
Author: mattjackson220 <33...@users.noreply.github.com>
AuthorDate: Thu Mar 26 15:16:17 2020 -0600

    updated Lets Encrypt endpoint to perform checks (#4540)
    
    * updated Lets Encrypt endpoint to perform checks
    
    * update per comment
    
    * updated godoc
    
    Co-authored-by: mjacks258 <ma...@comcast.com>
---
 .../traffic_ops_golang/dbhelpers/db_helpers.go     |  30 +++++
 .../dbhelpers/db_helpers_test.go                   | 121 +++++++++++++++++++++
 .../deliveryservice/letsencryptcert.go             |  31 ++++++
 3 files changed, 182 insertions(+)

diff --git a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
index 0a7f732..aa9af11 100644
--- a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
+++ b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
@@ -324,6 +324,24 @@ WHERE ds.id = $1
 	return name, cdn, true, nil
 }
 
+// GetDSIDAndCDNFromName returns the ID and CDN name of the delivery service whose XMLID is xmlID, whether such a delivery service exists, and any error; the ID and CDN name are zero values unless it exists.
+func GetDSIDAndCDNFromName(tx *sql.Tx, xmlID string) (int, tc.CDNName, bool, error) {
+	dsID := 0
+	cdn := tc.CDNName("")
+	if err := tx.QueryRow(`
+SELECT ds.id, cdn.name
+FROM deliveryservice as ds
+JOIN cdn on cdn.id = ds.cdn_id
+WHERE ds.xml_id = $1
+`, xmlID).Scan(&dsID, &cdn); err != nil {
+		if err == sql.ErrNoRows {
+			return 0, tc.CDNName(""), false, nil
+		}
+		return 0, tc.CDNName(""), false, errors.New("querying delivery service ID and CDN name by XMLID: " + err.Error())
+	}
+	return dsID, cdn, true, nil
+}
+
 // GetFederationResolversByFederationID fetches all of the federation resolvers currently assigned to a federation.
 // In the event of an error, it will return an empty slice and the error.
 func GetFederationResolversByFederationID(tx *sql.Tx, fedID int) ([]tc.FederationResolver, error) {
@@ -463,6 +481,18 @@ func GetCDNNameFromID(tx *sql.Tx, id int64) (tc.CDNName, bool, error) {
 	return tc.CDNName(name), true, nil
 }
 
+// GetCDNIDFromName returns the ID of the CDN if a CDN with the name exists
+func GetCDNIDFromName(tx *sql.Tx, name tc.CDNName) (int, bool, error) {
+	id := 0
+	if err := tx.QueryRow(`SELECT id FROM cdn WHERE name = $1`, name).Scan(&id); err != nil {
+		if err == sql.ErrNoRows {
+			return id, false, nil
+		}
+		return id, false, errors.New("querying CDN ID: " + err.Error())
+	}
+	return id, true, nil
+}
+
 // GetCDNDomainFromName returns the domain, whether the cdn exists, and any error.
 func GetCDNDomainFromName(tx *sql.Tx, cdnName tc.CDNName) (string, bool, error) {
 	domain := ""
diff --git a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go
index 8ec02a3..3826d7a 100644
--- a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go
+++ b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go
@@ -140,3 +140,124 @@ func TestGetCacheGroupByName(t *testing.T) {
 	}
 
 }
+
+func TestGetDSIDAndCDNFromName(t *testing.T) {
+	var testCases = []struct {
+		description  string
+		storageError error
+		found        bool
+	}{
+		{
+			description:  "Success: DS ID and CDN Name found",
+			storageError: nil,
+			found:        true,
+		},
+		{
+			description:  "Failure: DS ID or CDN Name not found",
+			storageError: nil,
+			found:        false,
+		},
+		{
+			description:  "Failure: Storage error getting DS ID or CDN Name",
+			storageError: errors.New("error getting the delivery service ID or the CDN name"),
+			found:        false,
+		},
+	}
+	for _, testCase := range testCases {
+		t.Run(testCase.description, func(t *testing.T) {
+			mockDB, mock, err := sqlmock.New()
+			if err != nil {
+				t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+			}
+			// db embeds mockDB, so closing db also closes the mock connection.
+			db := sqlx.NewDb(mockDB, "sqlmock")
+			defer db.Close()
+			rows := sqlmock.NewRows([]string{"id", "name"})
+			mock.ExpectBegin()
+			if testCase.storageError != nil {
+				mock.ExpectQuery("SELECT").WillReturnError(testCase.storageError)
+			} else {
+				if testCase.found {
+					rows = rows.AddRow(1, "testCdn")
+				}
+				mock.ExpectQuery("SELECT").WillReturnRows(rows)
+			}
+			dsID, cdn, exists, err := GetDSIDAndCDNFromName(db.MustBegin().Tx, "testDs")
+			if testCase.storageError != nil && err == nil {
+				t.Errorf("Storage error expected: received no storage error")
+			}
+			if testCase.storageError == nil && err != nil {
+				t.Errorf("Storage error not expected: received storage error")
+			}
+			if testCase.found != exists {
+				t.Errorf("Expected return exists: %t, actual %t", testCase.found, exists)
+			}
+			if testCase.found && (dsID != 1 || cdn != "testCdn") {
+				t.Errorf("Expected DS ID 1 and CDN name testCdn, actual %d and %s", dsID, cdn)
+			}
+			if mockErr := mock.ExpectationsWereMet(); mockErr != nil {
+				t.Errorf("Unfulfilled sqlmock expectations: %v", mockErr)
+			}
+		})
+	}
+}
+
+func TestGetCDNIDFromName(t *testing.T) {
+	var testCases = []struct {
+		description  string
+		storageError error
+		found        bool
+	}{
+		{
+			description:  "Success: CDN ID found",
+			storageError: nil,
+			found:        true,
+		},
+		{
+			description:  "Failure: CDN ID not found",
+			storageError: nil,
+			found:        false,
+		},
+		{
+			description:  "Failure: Storage error getting CDN ID",
+			storageError: errors.New("error getting the CDN ID"),
+			found:        false,
+		},
+	}
+	for _, testCase := range testCases {
+		t.Run(testCase.description, func(t *testing.T) {
+			mockDB, mock, err := sqlmock.New()
+			if err != nil {
+				t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+			}
+			db := sqlx.NewDb(mockDB, "sqlmock")
+			defer db.Close()
+			rows := sqlmock.NewRows([]string{"id"})
+			mock.ExpectBegin()
+			if testCase.storageError != nil {
+				mock.ExpectQuery("SELECT").WillReturnError(testCase.storageError)
+			} else {
+				if testCase.found {
+					rows = rows.AddRow(1)
+				}
+				mock.ExpectQuery("SELECT").WillReturnRows(rows)
+			}
+			id, exists, err := GetCDNIDFromName(db.MustBegin().Tx, "testCdn")
+			if testCase.storageError != nil && err == nil {
+				t.Errorf("Storage error expected: received no storage error")
+			}
+			if testCase.storageError == nil && err != nil {
+				t.Errorf("Storage error not expected: received storage error")
+			}
+			if testCase.found != exists {
+				t.Errorf("Expected return exists: %t, actual %t", testCase.found, exists)
+			}
+			if testCase.found && id != 1 {
+				t.Errorf("Expected CDN ID 1, actual %d", id)
+			}
+			if mockErr := mock.ExpectationsWereMet(); mockErr != nil {
+				t.Errorf("Unfulfilled sqlmock expectations: %v", mockErr)
+			}
+		})
+	}
+}
diff --git a/traffic_ops/traffic_ops_golang/deliveryservice/letsencryptcert.go b/traffic_ops/traffic_ops_golang/deliveryservice/letsencryptcert.go
index 64e1aff..fc7ff49 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservice/letsencryptcert.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservice/letsencryptcert.go
@@ -38,7 +38,9 @@ import (
 	"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/api"
 	"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/auth"
 	"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/config"
+	"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/dbhelpers"
 	"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/riaksvc"
+	"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/tenant"
 	"github.com/go-acme/lego/certcrypto"
 	"github.com/go-acme/lego/certificate"
 	"github.com/go-acme/lego/challenge"
@@ -149,6 +151,35 @@ func GenerateLetsEncryptCertificates(w http.ResponseWriter, r *http.Request) {
 		req.DeliveryService = req.Key
 	}
 
+	dsID, cdnName, ok, err := dbhelpers.GetDSIDAndCDNFromName(inf.Tx.Tx, *req.DeliveryService)
+	if err != nil {
+		api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("deliveryservice.GenerateLetsEncryptCertificates: getting DS ID from name "+err.Error()))
+		return
+	} else if !ok {
+		api.HandleErr(w, r, inf.Tx.Tx, http.StatusNotFound, errors.New("no DS with name "+*req.DeliveryService), nil)
+		return
+	}
+
+	userErr, sysErr, errCode = tenant.CheckID(inf.Tx.Tx, inf.User, dsID)
+	if userErr != nil || sysErr != nil {
+		api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
+		return
+	}
+
+	_, ok, err = dbhelpers.GetCDNIDFromName(inf.Tx.Tx, tc.CDNName(*req.CDN))
+	if err != nil {
+		api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("checking CDN existence: "+err.Error()))
+		return
+	} else if !ok {
+		api.HandleErr(w, r, inf.Tx.Tx, http.StatusNotFound, errors.New("cdn not found with name "+*req.CDN), nil)
+		return
+	}
+
+	if cdnName != tc.CDNName(*req.CDN) {
+		api.HandleErr(w, r, inf.Tx.Tx, http.StatusBadRequest, errors.New("delivery service not in cdn"), nil)
+		return
+	}
+
 	go GetLetsEncryptCertificates(inf.Config, req, ctx, inf.User)
 
 	api.WriteRespAlert(w, r, tc.InfoLevel, "Beginning async call to Let's Encrypt for "+*req.DeliveryService+".  This may take a few minutes.")