You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@trafficcontrol.apache.org by ra...@apache.org on 2021/05/12 22:57:23 UTC
[trafficcontrol] branch master updated: Call cancel functions to avoid context leaks (#5836)
This is an automated email from the ASF dual-hosted git repository.
rawlin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/trafficcontrol.git
The following commit(s) were added to refs/heads/master by this push:
new 49c6148 Call cancel functions to avoid context leaks (#5836)
49c6148 is described below
commit 49c6148652023d251524162d51a10db07f6e1e2a
Author: Zach Hoffman <zr...@apache.org>
AuthorDate: Wed May 12 16:57:08 2021 -0600
Call cancel functions to avoid context leaks (#5836)
* Implicitly skip JSON for InvalidationJobInput.dsid and InvalidationJobInput.ttl fields
* Enable already-satisfied checks for Go vet GitHub Action:
- httpresponse
- loopclosure
- structtag
- unmarshal
- unreachable
* Add CancelTx to APIInfo
* Call cancel functions to avoid context leaks
* Remove time.Duration typecast
* Call APIInfo.CancelTx() in APIInfo.Close()
* Do not cancel the context before it is used
* Call cancelTx() on all execution paths
* Return after handling an API error
---
.github/workflows/go.vet.yml | 6 -----
lib/go-tc/invalidationjobs.go | 4 ++--
traffic_ops/traffic_ops_golang/api/api.go | 9 +++++---
.../traffic_ops_golang/crconfig/config_test.go | 6 +++--
.../crconfig/deliveryservice_test.go | 12 ++++++----
.../crconfig/edgelocations_test.go | 3 ++-
.../traffic_ops_golang/crconfig/servers_test.go | 21 +++++++++++------
.../traffic_ops_golang/crconfig/snapshot_test.go | 6 +++--
.../traffic_ops_golang/crconfig/topologies_test.go | 3 ++-
.../dbhelpers/db_helpers_test.go | 3 ++-
.../traffic_ops_golang/deliveryservice/acme.go | 27 ++++++++++++++++++----
.../deliveryservice/acme_renew.go | 3 ++-
.../deliveryservice/autorenewcerts.go | 12 ++++++----
.../deliveryservice/deleteoldcerts.go | 2 +-
.../traffic_ops_golang/hwinfo/hwinfo_test.go | 3 ++-
.../traffic_ops_golang/login/logout_test.go | 3 ++-
.../server/servers_assignment_test.go | 3 ++-
.../systeminfo/system_info_test.go | 3 ++-
.../backends/riaksvc/riak_services_test.go | 3 ++-
19 files changed, 87 insertions(+), 45 deletions(-)
diff --git a/.github/workflows/go.vet.yml b/.github/workflows/go.vet.yml
index 156af50..f347c36 100644
--- a/.github/workflows/go.vet.yml
+++ b/.github/workflows/go.vet.yml
@@ -51,13 +51,7 @@ jobs:
run: |
skip_analyzers=(
-copylocks=false
- -httpresponse=false
- -loopclosure=false
- -lostcancel=false
-printf=false
- -structtag=false
-tests=false
- -unmarshal=false
- -unreachable=false
)
go vet "${skip_analyzers[@]}" ./...
diff --git a/lib/go-tc/invalidationjobs.go b/lib/go-tc/invalidationjobs.go
index 414ec97..5ceef96 100644
--- a/lib/go-tc/invalidationjobs.go
+++ b/lib/go-tc/invalidationjobs.go
@@ -79,8 +79,8 @@ type InvalidationJobInput struct {
// number
TTL *interface{} `json:"ttl"`
- dsid *uint `json:"-"`
- ttl *time.Duration `json:"-"`
+ dsid *uint
+ ttl *time.Duration
}
// UserInvalidationJobInput Represents legacy-style user input to the /user/current/jobs API endpoint.
diff --git a/traffic_ops/traffic_ops_golang/api/api.go b/traffic_ops/traffic_ops_golang/api/api.go
index 2c0f55c..3550f85 100644
--- a/traffic_ops/traffic_ops_golang/api/api.go
+++ b/traffic_ops/traffic_ops_golang/api/api.go
@@ -492,6 +492,7 @@ type APIInfo struct {
ReqID uint64
Version *Version
Tx *sqlx.Tx
+ CancelTx context.CancelFunc
Vault trafficvault.TrafficVault
Config *config.Config
request *http.Request
@@ -553,10 +554,10 @@ func NewInfo(r *http.Request, requiredParams []string, intParamNames []string) (
if userErr != nil || sysErr != nil {
return &APIInfo{Tx: &sqlx.Tx{}}, userErr, sysErr, errCode
}
- dbCtx, _ := context.WithTimeout(r.Context(), time.Duration(cfg.DBQueryTimeoutSeconds)*time.Second) //only place we could call cancel here is in APIInfo.Close(), which already will rollback the transaction (which is all cancel will do.)
- tx, err := db.BeginTxx(dbCtx, nil) // must be last, MUST not return an error if this succeeds, without closing the tx
+ dbCtx, cancelTx := context.WithTimeout(r.Context(), time.Duration(cfg.DBQueryTimeoutSeconds)*time.Second) //only place we could call cancel here is in APIInfo.Close(), which already will rollback the transaction (which is all cancel will do.)
+ tx, err := db.BeginTxx(dbCtx, nil) // must be last, MUST not return an error if this succeeds, without closing the tx
if err != nil {
- return &APIInfo{Tx: &sqlx.Tx{}}, userErr, errors.New("could not begin transaction: " + err.Error()), http.StatusInternalServerError
+ return &APIInfo{Tx: &sqlx.Tx{}, CancelTx: cancelTx}, userErr, errors.New("could not begin transaction: " + err.Error()), http.StatusInternalServerError
}
return &APIInfo{
Config: cfg,
@@ -566,6 +567,7 @@ func NewInfo(r *http.Request, requiredParams []string, intParamNames []string) (
IntParams: intParams,
User: user,
Tx: tx,
+ CancelTx: cancelTx,
Vault: tv,
request: r,
}, nil, nil, http.StatusOK
@@ -654,6 +656,7 @@ func (inf APIInfo) CheckPrecondition(query string, args ...interface{}) (int, er
//
// Close will commit the transaction, if it hasn't been rolled back.
func (inf *APIInfo) Close() {
+ defer inf.CancelTx()
if err := inf.Tx.Tx.Commit(); err != nil && err != sql.ErrTxDone {
log.Errorln("committing transaction: " + err.Error())
}
diff --git a/traffic_ops/traffic_ops_golang/crconfig/config_test.go b/traffic_ops/traffic_ops_golang/crconfig/config_test.go
index 0878265..816ace7 100644
--- a/traffic_ops/traffic_ops_golang/crconfig/config_test.go
+++ b/traffic_ops/traffic_ops_golang/crconfig/config_test.go
@@ -61,7 +61,8 @@ func TestGetConfigParams(t *testing.T) {
MockGetConfigParams(mock, expected, cdn)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
@@ -124,7 +125,8 @@ func TestMakeCRConfigConfig(t *testing.T) {
expected := ExpectedMakeCRConfigConfig(expectedGetConfigParams, dnssecEnabled)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
diff --git a/traffic_ops/traffic_ops_golang/crconfig/deliveryservice_test.go b/traffic_ops/traffic_ops_golang/crconfig/deliveryservice_test.go
index 0940044..3904834 100644
--- a/traffic_ops/traffic_ops_golang/crconfig/deliveryservice_test.go
+++ b/traffic_ops/traffic_ops_golang/crconfig/deliveryservice_test.go
@@ -221,7 +221,8 @@ func TestMakeDSes(t *testing.T) {
MockMakeDSes(mock, expected, cdn)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
@@ -288,7 +289,8 @@ func TestGetServerProfileParams(t *testing.T) {
MockGetServerProfileParams(mock, expected, cdn)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
@@ -376,7 +378,8 @@ func TestGetDSRegexesDomains(t *testing.T) {
MockGetDSRegexesDomains(mock, expectedMatchsets, expectedDomains, cdn)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
@@ -439,7 +442,8 @@ func TestGetStaticDNSEntries(t *testing.T) {
MockGetStaticDNSEntries(mock, expected, cdn)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
diff --git a/traffic_ops/traffic_ops_golang/crconfig/edgelocations_test.go b/traffic_ops/traffic_ops_golang/crconfig/edgelocations_test.go
index f7b14d3..c52dc7f 100644
--- a/traffic_ops/traffic_ops_golang/crconfig/edgelocations_test.go
+++ b/traffic_ops/traffic_ops_golang/crconfig/edgelocations_test.go
@@ -87,7 +87,8 @@ func TestMakeLocations(t *testing.T) {
MockMakeLocations(mock, expectedEdgeLocs, expectedRouterLocs, cdn)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
diff --git a/traffic_ops/traffic_ops_golang/crconfig/servers_test.go b/traffic_ops/traffic_ops_golang/crconfig/servers_test.go
index 7c7dc74..a30d30c 100644
--- a/traffic_ops/traffic_ops_golang/crconfig/servers_test.go
+++ b/traffic_ops/traffic_ops_golang/crconfig/servers_test.go
@@ -180,7 +180,8 @@ func TestGetServerParams(t *testing.T) {
MockGetServerParams(mock, expected, cdn)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
@@ -431,7 +432,8 @@ func TestGetAllServers(t *testing.T) {
MockGetAllServers(mock, expected, cdn, true, true)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
@@ -467,7 +469,8 @@ func TestGetAllServersNonService(t *testing.T) {
MockGetAllServers(mock, expected, cdn, true, false)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
@@ -518,7 +521,8 @@ func TestGetServerDSNames(t *testing.T) {
MockGetServerDSNames(mock, expected, cdn)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
@@ -586,7 +590,8 @@ func TestGetServerDSes(t *testing.T) {
MockGetServerDSes(mock, expected, cdn)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
@@ -628,7 +633,8 @@ func TestGetCDNInfo(t *testing.T) {
MockGetCDNInfo(mock, expectedDomain, expectedDNSSECEnabled, cdn)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
@@ -672,7 +678,8 @@ func TestGetCDNNameFromID(t *testing.T) {
MockGetCDNNameFromID(mock, expected, cdnID)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
diff --git a/traffic_ops/traffic_ops_golang/crconfig/snapshot_test.go b/traffic_ops/traffic_ops_golang/crconfig/snapshot_test.go
index 535637a..4a59a3b 100644
--- a/traffic_ops/traffic_ops_golang/crconfig/snapshot_test.go
+++ b/traffic_ops/traffic_ops_golang/crconfig/snapshot_test.go
@@ -69,7 +69,8 @@ func TestGetSnapshot(t *testing.T) {
MockGetSnapshot(mock, expected, cdn)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
@@ -121,7 +122,8 @@ func TestSnapshot(t *testing.T) {
crc.Stats.CDNName = &cdn
mock.ExpectBegin()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
diff --git a/traffic_ops/traffic_ops_golang/crconfig/topologies_test.go b/traffic_ops/traffic_ops_golang/crconfig/topologies_test.go
index 9c9b637..3a8a632 100644
--- a/traffic_ops/traffic_ops_golang/crconfig/topologies_test.go
+++ b/traffic_ops/traffic_ops_golang/crconfig/topologies_test.go
@@ -68,7 +68,8 @@ func TestMakeTops(t *testing.T) {
MockMakeTops(mock, expected)
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatal("creating transaction: ", err)
diff --git a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go
index 289a99f..b388e8e 100644
--- a/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go
+++ b/traffic_ops/traffic_ops_golang/dbhelpers/db_helpers_test.go
@@ -246,7 +246,8 @@ func TestGetServerInterfaces(t *testing.T) {
mock.ExpectBegin()
mockServerInterfaces(mock, cacheID, serverInterfaces)
- dbCtx, _ := context.WithTimeout(context.Background(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.Background(), time.Duration(10)*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
diff --git a/traffic_ops/traffic_ops_golang/deliveryservice/acme.go b/traffic_ops/traffic_ops_golang/deliveryservice/acme.go
index 4c4d85e..cab649e 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservice/acme.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservice/acme.go
@@ -159,10 +159,11 @@ func GenerateAcmeCertificates(w http.ResponseWriter, r *http.Request) {
api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("deliveryservice.GenerateAcmeCertificates: Traffic Vault is not configured"))
return
}
- ctx, _ := context.WithTimeout(r.Context(), AcmeTimeout)
+ ctx, cancelTx := context.WithTimeout(r.Context(), AcmeTimeout)
req := tc.DeliveryServiceAcmeSSLKeysReq{}
if err := api.Parse(r.Body, nil, &req); err != nil {
+ defer cancelTx()
api.HandleErr(w, r, nil, http.StatusBadRequest, fmt.Errorf("parsing request: %v", err), nil)
return
}
@@ -172,29 +173,35 @@ func GenerateAcmeCertificates(w http.ResponseWriter, r *http.Request) {
dsID, cdnName, ok, err := dbhelpers.GetDSIDAndCDNFromName(inf.Tx.Tx, *req.DeliveryService)
if err != nil {
+ defer cancelTx()
api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, fmt.Errorf("deliveryservice.GenerateLetsEncryptCertificates: getting DS ID from name: %v", err))
return
} else if !ok {
+ defer cancelTx()
api.HandleErr(w, r, inf.Tx.Tx, http.StatusNotFound, errors.New("no DS with name "+*req.DeliveryService), nil)
return
}
userErr, sysErr, errCode = tenant.CheckID(inf.Tx.Tx, inf.User, dsID)
if userErr != nil || sysErr != nil {
+ defer cancelTx()
api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
return
}
_, ok, err = dbhelpers.GetCDNIDFromName(inf.Tx.Tx, tc.CDNName(*req.CDN))
if err != nil {
+ defer cancelTx()
api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, fmt.Errorf("checking CDN existence: %v", err))
return
} else if !ok {
+ defer cancelTx()
api.HandleErr(w, r, inf.Tx.Tx, http.StatusNotFound, errors.New("cdn not found with name "+*req.CDN), nil)
return
}
if cdnName != tc.CDNName(*req.CDN) {
+ defer cancelTx()
api.HandleErr(w, r, inf.Tx.Tx, http.StatusBadRequest, errors.New("delivery service not in cdn"), nil)
return
}
@@ -204,7 +211,7 @@ func GenerateAcmeCertificates(w http.ResponseWriter, r *http.Request) {
api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
}
- go GetAcmeCertificates(inf.Config, req, ctx, inf.User, asyncStatusId, inf.Vault)
+ go GetAcmeCertificates(inf.Config, req, ctx, cancelTx, true, inf.User, asyncStatusId, inf.Vault)
var alerts tc.Alerts
alerts.AddAlert(tc.Alert{
@@ -230,7 +237,7 @@ func GenerateLetsEncryptCertificates(w http.ResponseWriter, r *http.Request) {
return
}
- ctx, _ := context.WithTimeout(r.Context(), AcmeTimeout)
+ ctx, cancelTx := context.WithTimeout(r.Context(), AcmeTimeout)
req := tc.DeliveryServiceAcmeSSLKeysReq{}
if req.AuthType == nil {
@@ -239,6 +246,7 @@ func GenerateLetsEncryptCertificates(w http.ResponseWriter, r *http.Request) {
}
if err := api.Parse(r.Body, nil, &req); err != nil {
+ defer cancelTx()
api.HandleErr(w, r, nil, http.StatusBadRequest, fmt.Errorf("parsing request: %v", err), nil)
return
}
@@ -248,29 +256,35 @@ func GenerateLetsEncryptCertificates(w http.ResponseWriter, r *http.Request) {
dsID, cdnName, ok, err := dbhelpers.GetDSIDAndCDNFromName(inf.Tx.Tx, *req.DeliveryService)
if err != nil {
+ defer cancelTx()
api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, fmt.Errorf("deliveryservice.GenerateLetsEncryptCertificates: getting DS ID from name: %v", err))
return
} else if !ok {
+ defer cancelTx()
api.HandleErr(w, r, inf.Tx.Tx, http.StatusNotFound, errors.New("no DS with name "+*req.DeliveryService), nil)
return
}
userErr, sysErr, errCode = tenant.CheckID(inf.Tx.Tx, inf.User, dsID)
if userErr != nil || sysErr != nil {
+ defer cancelTx()
api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
return
}
_, ok, err = dbhelpers.GetCDNIDFromName(inf.Tx.Tx, tc.CDNName(*req.CDN))
if err != nil {
+ defer cancelTx()
api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, fmt.Errorf("checking CDN existence: %v", err))
return
} else if !ok {
+ defer cancelTx()
api.HandleErr(w, r, inf.Tx.Tx, http.StatusNotFound, errors.New("cdn not found with name "+*req.CDN), nil)
return
}
if cdnName != tc.CDNName(*req.CDN) {
+ defer cancelTx()
api.HandleErr(w, r, inf.Tx.Tx, http.StatusBadRequest, errors.New("delivery service not in cdn"), nil)
return
}
@@ -280,7 +294,7 @@ func GenerateLetsEncryptCertificates(w http.ResponseWriter, r *http.Request) {
api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
}
- go GetAcmeCertificates(inf.Config, req, ctx, inf.User, asyncStatusId, inf.Vault)
+ go GetAcmeCertificates(inf.Config, req, ctx, cancelTx, true, inf.User, asyncStatusId, inf.Vault)
var alerts tc.Alerts
alerts.AddAlerts(api.CreateDeprecationAlerts(util.StrPtr(API_ACME_GENERATE_LE)))
@@ -294,8 +308,11 @@ func GenerateLetsEncryptCertificates(w http.ResponseWriter, r *http.Request) {
}
// GetAcmeCertificates gets or creates an ACME account based on the provider, then gets new certificates for the delivery service requested and saves them to Vault.
-func GetAcmeCertificates(cfg *config.Config, req tc.DeliveryServiceAcmeSSLKeysReq, ctx context.Context, currentUser *auth.CurrentUser, asyncStatusId int, tv trafficvault.TrafficVault) error {
+func GetAcmeCertificates(cfg *config.Config, req tc.DeliveryServiceAcmeSSLKeysReq, ctx context.Context, cancelTx context.CancelFunc, shouldCancelTx bool, currentUser *auth.CurrentUser, asyncStatusId int, tv trafficvault.TrafficVault) error {
defer func() {
+ if shouldCancelTx {
+ defer cancelTx()
+ }
if err := recover(); err != nil {
db, dbErr := api.GetDB(ctx)
if dbErr != nil {
diff --git a/traffic_ops/traffic_ops_golang/deliveryservice/acme_renew.go b/traffic_ops/traffic_ops_golang/deliveryservice/acme_renew.go
index 2e0288c..20c15df 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservice/acme_renew.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservice/acme_renew.go
@@ -56,7 +56,8 @@ func RenewAcmeCertificate(w http.ResponseWriter, r *http.Request) {
return
}
- ctx, _ := context.WithTimeout(r.Context(), AcmeTimeout)
+ ctx, cancelTx := context.WithTimeout(r.Context(), AcmeTimeout)
+ defer cancelTx()
userErr, sysErr, statusCode := renewAcmeCerts(inf.Config, xmlID, ctx, r.Context(), inf.User, inf.Vault)
if userErr != nil || sysErr != nil {
diff --git a/traffic_ops/traffic_ops_golang/deliveryservice/autorenewcerts.go b/traffic_ops/traffic_ops_golang/deliveryservice/autorenewcerts.go
index a557dff..87707a4 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservice/autorenewcerts.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservice/autorenewcerts.go
@@ -98,18 +98,21 @@ func renewCertificates(w http.ResponseWriter, r *http.Request, deprecated bool)
err := rows.Scan(&ds.XmlId, &ds.Version)
if err != nil {
api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, err)
+ return
}
existingCerts = append(existingCerts, ExistingCerts{Version: ds.Version, XmlId: ds.XmlId})
}
- ctx, _ := context.WithTimeout(r.Context(), AcmeTimeout*time.Duration(len(existingCerts)))
+ ctx, cancelTx := context.WithTimeout(r.Context(), AcmeTimeout*time.Duration(len(existingCerts)))
asyncStatusId, errCode, userErr, sysErr := api.InsertAsyncStatus(inf.Tx.Tx, "ACME async job has started.")
if userErr != nil || sysErr != nil {
+ defer cancelTx()
api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
+ return
}
- go RunAutorenewal(existingCerts, inf.Config, ctx, inf.User, asyncStatusId, inf.Vault)
+ go RunAutorenewal(existingCerts, inf.Config, ctx, cancelTx, inf.User, asyncStatusId, inf.Vault)
var alerts tc.Alerts
if deprecated {
@@ -125,7 +128,8 @@ func renewCertificates(w http.ResponseWriter, r *http.Request, deprecated bool)
api.WriteAlerts(w, r, http.StatusAccepted, alerts)
}
-func RunAutorenewal(existingCerts []ExistingCerts, cfg *config.Config, ctx context.Context, currentUser *auth.CurrentUser, asyncStatusId int, tv trafficvault.TrafficVault) {
+func RunAutorenewal(existingCerts []ExistingCerts, cfg *config.Config, ctx context.Context, cancelTx context.CancelFunc, currentUser *auth.CurrentUser, asyncStatusId int, tv trafficvault.TrafficVault) {
+ defer cancelTx()
db, err := api.GetDB(ctx)
if err != nil {
log.Errorf("Error getting db: %s", err.Error())
@@ -230,7 +234,7 @@ func RunAutorenewal(existingCerts []ExistingCerts, cfg *config.Config, ctx conte
},
}
- if err := GetAcmeCertificates(cfg, req, ctx, currentUser, 0, tv); err != nil {
+ if err := GetAcmeCertificates(cfg, req, ctx, nil, false, currentUser, 0, tv); err != nil {
dsExpInfo.Error = err
errorCount++
} else {
diff --git a/traffic_ops/traffic_ops_golang/deliveryservice/deleteoldcerts.go b/traffic_ops/traffic_ops_golang/deliveryservice/deleteoldcerts.go
index 330a48b..3946a4d 100644
--- a/traffic_ops/traffic_ops_golang/deliveryservice/deleteoldcerts.go
+++ b/traffic_ops/traffic_ops_golang/deliveryservice/deleteoldcerts.go
@@ -75,12 +75,12 @@ func deleteOldDSCerts(tx *sql.Tx, cdn tc.CDNName, tv trafficvault.TrafficVault)
// deleteOldDSCertsDB takes a db, and creates a transaction to pass to deleteOldDSCerts.
func deleteOldDSCertsDB(db *sql.DB, dbTimeout time.Duration, cdn tc.CDNName, tv trafficvault.TrafficVault) {
dbCtx, cancelTx := context.WithTimeout(context.Background(), dbTimeout)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
log.Errorln("Old Cert Deleter Job: beginning tx: " + err.Error())
return
}
- defer cancelTx()
txCommit := false
defer dbhelpers.CommitIf(tx, &txCommit)
if err := deleteOldDSCerts(tx, cdn, tv); err != nil {
diff --git a/traffic_ops/traffic_ops_golang/hwinfo/hwinfo_test.go b/traffic_ops/traffic_ops_golang/hwinfo/hwinfo_test.go
index d93c147..9295be4 100644
--- a/traffic_ops/traffic_ops_golang/hwinfo/hwinfo_test.go
+++ b/traffic_ops/traffic_ops_golang/hwinfo/hwinfo_test.go
@@ -82,7 +82,8 @@ func TestGetHWInfo(t *testing.T) {
v := map[string]string{"ServerId": "1"}
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTxx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
diff --git a/traffic_ops/traffic_ops_golang/login/logout_test.go b/traffic_ops/traffic_ops_golang/login/logout_test.go
index bf7cd3c..e605554 100644
--- a/traffic_ops/traffic_ops_golang/login/logout_test.go
+++ b/traffic_ops/traffic_ops_golang/login/logout_test.go
@@ -96,7 +96,8 @@ func TestLogout(t *testing.T) {
ctx = context.WithValue(ctx, api.PathParamsKey, map[string]string{})
var tv trafficvault.TrafficVault = &disabled.Disabled{}
ctx = context.WithValue(ctx, api.TrafficVaultContextKey, tv)
- ctx, _ = context.WithDeadline(ctx, time.Now().Add(24*time.Hour))
+ ctx, cancelTx := context.WithDeadline(ctx, time.Now().Add(24*time.Hour))
+ defer cancelTx()
req = req.WithContext(ctx)
req.AddCookie(cookie)
diff --git a/traffic_ops/traffic_ops_golang/server/servers_assignment_test.go b/traffic_ops/traffic_ops_golang/server/servers_assignment_test.go
index 2c62618..0aff38a 100644
--- a/traffic_ops/traffic_ops_golang/server/servers_assignment_test.go
+++ b/traffic_ops/traffic_ops_golang/server/servers_assignment_test.go
@@ -87,7 +87,8 @@ func TestAssignDsesToServer(t *testing.T) {
mock.ExpectExec("DELETE").WithArgs(pq.Array(delete)).WillReturnResult(sqlmock.NewResult(1, 3))
mock.ExpectCommit()
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
diff --git a/traffic_ops/traffic_ops_golang/systeminfo/system_info_test.go b/traffic_ops/traffic_ops_golang/systeminfo/system_info_test.go
index 6c60829..9cc34b6 100644
--- a/traffic_ops/traffic_ops_golang/systeminfo/system_info_test.go
+++ b/traffic_ops/traffic_ops_golang/systeminfo/system_info_test.go
@@ -95,7 +95,8 @@ func TestGetSystemInfo(t *testing.T) {
mock.ExpectBegin()
mock.ExpectQuery(`SELECT.*WHERE p.config_file = \$1`).WillReturnRows(rows)
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTxx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)
diff --git a/traffic_ops/traffic_ops_golang/trafficvault/backends/riaksvc/riak_services_test.go b/traffic_ops/traffic_ops_golang/trafficvault/backends/riaksvc/riak_services_test.go
index 925b4a6..af13070 100644
--- a/traffic_ops/traffic_ops_golang/trafficvault/backends/riaksvc/riak_services_test.go
+++ b/traffic_ops/traffic_ops_golang/trafficvault/backends/riaksvc/riak_services_test.go
@@ -182,7 +182,8 @@ func TestGetRiakCluster(t *testing.T) {
rows1 := sqlmock.NewRows([]string{"fqdn"})
rows1.AddRow("www.devnull.com")
- dbCtx, _ := context.WithTimeout(context.TODO(), time.Duration(10)*time.Second)
+ dbCtx, cancelTx := context.WithTimeout(context.TODO(), 10*time.Second)
+ defer cancelTx()
tx, err := db.BeginTx(dbCtx, nil)
if err != nil {
t.Fatalf("creating transaction: %v", err)