You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@trafficcontrol.apache.org by oc...@apache.org on 2020/03/26 21:16:26 UTC
[trafficcontrol] branch master updated: updated Lets Encrypt
endpoint to perform checks (#4540)
This is an automated email from the ASF dual-hosted git repository.
ocket8888 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/trafficcontrol.git
The following commit(s) were added to refs/heads/master by this push:
new c0af35f updated Lets Encrypt endpoint to perform checks (#4540)
c0af35f is described below
commit c0af35f0c82d3cb22e168a63372d3730e2af1670
Author: mattjackson220 <33...@users.noreply.github.com>
AuthorDate: Thu Mar 26 15:16:17 2020 -0600
updated Lets Encrypt endpoint to perform checks (#4540)
* updated Lets Encrypt endpoint to perform checks
* update per comment
* updated godoc
Co-authored-by: mjacks258 <ma...@comcast.com>
---
.../traffic_ops_golang/dbhelpers/db_helpers.go | 30 +++++
.../dbhelpers/db_helpers_test.go | 121 +++++++++++++++++++++
.../deliveryservice/letsencryptcert.go | 31 ++++++
3 files changed, 182 insertions(+)
diff --git a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
index 0a7f732..aa9af11 100644
--- a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
+++ b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers.go
@@ -324,6 +324,24 @@ WHERE ds.id = $1
return name, cdn, true, nil
}
+// GetDSIDAndCDNFromName returns the delivery service ID and cdn name given from the delivery service name, whether a result existed, and any error.
+func GetDSIDAndCDNFromName(tx *sql.Tx, xmlID string) (int, tc.CDNName, bool, error) {
+ dsId := 0
+ cdn := tc.CDNName("")
+ if err := tx.QueryRow(`
+SELECT ds.id, cdn.name
+FROM deliveryservice as ds
+JOIN cdn on cdn.id = ds.cdn_id
+WHERE ds.xml_id = $1
+`, xmlID).Scan(&dsId, &cdn); err != nil {
+ if err == sql.ErrNoRows {
+ return dsId, tc.CDNName(""), false, nil
+ }
+ return dsId, tc.CDNName(""), false, errors.New("querying delivery service name: " + err.Error())
+ }
+ return dsId, cdn, true, nil
+}
+
// GetFederationResolversByFederationID fetches all of the federation resolvers currently assigned to a federation.
// In the event of an error, it will return an empty slice and the error.
func GetFederationResolversByFederationID(tx *sql.Tx, fedID int) ([]tc.FederationResolver, error) {
@@ -463,6 +481,18 @@ func GetCDNNameFromID(tx *sql.Tx, id int64) (tc.CDNName, bool, error) {
return tc.CDNName(name), true, nil
}
+// GetCDNIDFromName returns the ID of the CDN if a CDN with the name exists
+func GetCDNIDFromName(tx *sql.Tx, name tc.CDNName) (int, bool, error) {
+ id := 0
+ if err := tx.QueryRow(`SELECT id FROM cdn WHERE name = $1`, name).Scan(&id); err != nil {
+ if err == sql.ErrNoRows {
+ return id, false, nil
+ }
+ return id, false, errors.New("querying CDN ID: " + err.Error())
+ }
+ return id, true, nil
+}
+
// GetCDNDomainFromName returns the domain, whether the cdn exists, and any error.
func GetCDNDomainFromName(tx *sql.Tx, cdnName tc.CDNName) (string, bool, error) {
domain := ""
diff --git a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go
index 8ec02a3..3826d7a 100644
--- a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go
+++ b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go
@@ -140,3 +140,124 @@ func TestGetCacheGroupByName(t *testing.T) {
}
}
+
+func TestGetDSIDAndCDNFromName(t *testing.T) {
+ var testCases = []struct {
+ description string
+ storageError error
+ found bool
+ }{
+ {
+ description: "Success: DS ID and CDN Name found",
+ storageError: nil,
+ found: true,
+ },
+ {
+ description: "Failure: DS ID or CDN Name not found",
+ storageError: nil,
+ found: false,
+ },
+ {
+ description: "Failure: Storage error getting DS ID or CDN Name",
+ storageError: errors.New("error getting the delivery service ID or the CDN name"),
+ found: false,
+ },
+ }
+ for _, testCase := range testCases {
+ t.Run(testCase.description, func(t *testing.T) {
+ t.Log("Starting test scenario: ", testCase.description)
+ mockDB, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer mockDB.Close()
+ db := sqlx.NewDb(mockDB, "sqlmock")
+ defer db.Close()
+ rows := sqlmock.NewRows([]string{
+ "id",
+ "name",
+ })
+ mock.ExpectBegin()
+ if testCase.storageError != nil {
+ mock.ExpectQuery("SELECT").WillReturnError(testCase.storageError)
+ } else {
+ if testCase.found {
+ rows = rows.AddRow(1, "testCdn")
+ }
+ mock.ExpectQuery("SELECT").WillReturnRows(rows)
+ }
+ mock.ExpectCommit()
+ _, _, exists, err := GetDSIDAndCDNFromName(db.MustBegin().Tx, "testDs")
+ if testCase.storageError != nil && err == nil {
+ t.Errorf("Storage error expected: received no storage error")
+ }
+ if testCase.storageError == nil && err != nil {
+ t.Errorf("Storage error not expected: received storage error")
+ }
+ if testCase.found != exists {
+ t.Errorf("Expected return exists: %t, actual %t", testCase.found, exists)
+ }
+ })
+ }
+
+}
+
+func TestGetCDNIDFromName(t *testing.T) {
+ var testCases = []struct {
+ description string
+ storageError error
+ found bool
+ }{
+ {
+ description: "Success: CDN ID found",
+ storageError: nil,
+ found: true,
+ },
+ {
+ description: "Failure: CDN ID not found",
+ storageError: nil,
+ found: false,
+ },
+ {
+ description: "Failure: Storage error getting CDN ID",
+ storageError: errors.New("error getting the CDN ID"),
+ found: false,
+ },
+ }
+ for _, testCase := range testCases {
+ t.Run(testCase.description, func(t *testing.T) {
+ t.Log("Starting test scenario: ", testCase.description)
+ mockDB, mock, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer mockDB.Close()
+ db := sqlx.NewDb(mockDB, "sqlmock")
+ defer db.Close()
+ rows := sqlmock.NewRows([]string{
+ "id",
+ })
+ mock.ExpectBegin()
+ if testCase.storageError != nil {
+ mock.ExpectQuery("SELECT").WillReturnError(testCase.storageError)
+ } else {
+ if testCase.found {
+ rows = rows.AddRow(1)
+ }
+ mock.ExpectQuery("SELECT").WillReturnRows(rows)
+ }
+ mock.ExpectCommit()
+ _, exists, err := GetCDNIDFromName(db.MustBegin().Tx, "testCdn")
+ if testCase.storageError != nil && err == nil {
+ t.Errorf("Storage error expected: received no storage error")
+ }
+ if testCase.storageError == nil && err != nil {
+ t.Errorf("Storage error not expected: received storage error")
+ }
+ if testCase.found != exists {
+ t.Errorf("Expected return exists: %t, actual %t", testCase.found, exists)
+ }
+ })
+ }
+
+}
diff --git a/traffic_ops/traffic_ops_golang/deliveryservice/letsencryptcert.go b/traffic_ops/traffic_ops_golang/deliveryservice/letsencryptcert.go
index 64e1aff..fc7ff49 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservice/letsencryptcert.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservice/letsencryptcert.go
@@ -38,7 +38,9 @@ import (
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/api"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/auth"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/config"
+ "github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/dbhelpers"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/riaksvc"
+ "github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/tenant"
"github.com/go-acme/lego/certcrypto"
"github.com/go-acme/lego/certificate"
"github.com/go-acme/lego/challenge"
@@ -149,6 +151,35 @@ func GenerateLetsEncryptCertificates(w http.ResponseWriter, r *http.Request) {
req.DeliveryService = req.Key
}
+ dsID, cdnName, ok, err := dbhelpers.GetDSIDAndCDNFromName(inf.Tx.Tx, *req.DeliveryService)
+ if err != nil {
+ api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("deliveryservice.GenerateLetsEncryptCertificates: getting DS ID from name "+err.Error()))
+ return
+ } else if !ok {
+ api.HandleErr(w, r, inf.Tx.Tx, http.StatusNotFound, errors.New("no DS with name "+*req.DeliveryService), nil)
+ return
+ }
+
+ userErr, sysErr, errCode = tenant.CheckID(inf.Tx.Tx, inf.User, dsID)
+ if userErr != nil || sysErr != nil {
+ api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
+ return
+ }
+
+ _, ok, err = dbhelpers.GetCDNIDFromName(inf.Tx.Tx, tc.CDNName(*req.CDN))
+ if err != nil {
+ api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("checking CDN existence: "+err.Error()))
+ return
+ } else if !ok {
+ api.HandleErr(w, r, inf.Tx.Tx, http.StatusNotFound, errors.New("cdn not found with name "+*req.CDN), nil)
+ return
+ }
+
+ if cdnName != tc.CDNName(*req.CDN) {
+ api.HandleErr(w, r, inf.Tx.Tx, http.StatusBadRequest, errors.New("delivery service not in cdn"), nil)
+ return
+ }
+
go GetLetsEncryptCertificates(inf.Config, req, ctx, inf.User)
api.WriteRespAlert(w, r, tc.InfoLevel, "Beginning async call to Let's Encrypt for "+*req.DeliveryService+". This may take a few minutes.")