Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Contextify all the goroutines and clean shutdowns #244

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"encoding/json"
"fmt"
"html/template"
Expand Down Expand Up @@ -29,7 +30,7 @@ const cacheScanTime = time.Minute

// Type for performing checks against an input domain. Returns
// a DomainResult object from the checker.
type checkPerformer func(API, string) (checker.DomainResult, error)
type checkPerformer func(context.Context, API, string) (checker.DomainResult, error)

// API is the HTTP API that this service provides.
// All requests respond with an APIResponse JSON, with fields:
Expand Down Expand Up @@ -89,7 +90,7 @@ func (api *API) wrapper(handler apiHandler) func(w http.ResponseWriter, r *http.
}
}

func defaultCheck(api API, domain string) (checker.DomainResult, error) {
func defaultCheck(ctx context.Context, api API, domain string) (checker.DomainResult, error) {
policyChan := models.Domain{Name: domain}.AsyncPolicyListCheck(api.Database, api.List)
c := checker.Checker{
Cache: &checker.ScanCache{
Expand All @@ -98,7 +99,7 @@ func defaultCheck(api API, domain string) (checker.DomainResult, error) {
},
Timeout: 3 * time.Second,
}
result := c.CheckDomain(domain, nil)
result := c.CheckDomain(ctx, domain, nil)
policyResult := <-policyChan
result.ExtraResults["policylist"] = &policyResult
return result, nil
Expand Down Expand Up @@ -135,7 +136,7 @@ func (api API) Scan(r *http.Request) APIResponse {
}
}
// 1. Conduct scan via starttls-checker
scanData, err := api.CheckDomain(api, domain)
scanData, err := api.CheckDomain(r.Context(), api, domain)
if err != nil {
return APIResponse{StatusCode: http.StatusInternalServerError, Message: err.Error()}
}
Expand Down
3 changes: 2 additions & 1 deletion checker/checker.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package checker

import (
"context"
"net"
"time"
)
Expand All @@ -22,7 +23,7 @@ type Checker struct {

// CheckHostname defines the function that should be used to check each hostname.
// If nil, FullCheckHostname (all hostname checks) will be used.
CheckHostname func(string, string, time.Duration) HostnameResult
CheckHostname func(context.Context, string, string, time.Duration) HostnameResult

// checkMTASTSOverride is used to mock MTA-STS checks.
checkMTASTSOverride func(string, map[string]HostnameResult) *MTASTSResult
Expand Down
3 changes: 2 additions & 1 deletion checker/cmd/starttls-check/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"bufio"
"context"
"encoding/csv"
"encoding/json"
"flag"
Expand Down Expand Up @@ -55,7 +56,7 @@ func main() {

if *domain != "" {
// Handle single domain and return
result := c.CheckDomain(*domain, nil)
result := c.CheckDomain(context.Background(), *domain, nil)
resultHandler.HandleDomain(result)
os.Exit(0)
}
Expand Down
16 changes: 8 additions & 8 deletions checker/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ func (d DomainResult) setStatus(status DomainStatus) DomainResult {
return d
}

func lookupMXWithTimeout(domain string, timeout time.Duration) ([]*net.MX, error) {
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
func lookupMXWithTimeout(ctx context.Context, domain string, timeout time.Duration) ([]*net.MX, error) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
var r net.Resolver
return r.LookupMX(ctx, domain)
}

// lookupHostnames retrieves the MX hostnames associated with a domain.
func (c *Checker) lookupHostnames(domain string) ([]string, error) {
func (c *Checker) lookupHostnames(ctx context.Context, domain string) ([]string, error) {
domainASCII, err := idna.ToASCII(domain)
if err != nil {
return nil, fmt.Errorf("domain name %s couldn't be converted to ASCII", domain)
Expand All @@ -81,7 +81,7 @@ func (c *Checker) lookupHostnames(domain string) ([]string, error) {
if c.lookupMXOverride != nil {
mxs, err = c.lookupMXOverride(domain)
} else {
mxs, err = lookupMXWithTimeout(domainASCII, c.timeout())
mxs, err = lookupMXWithTimeout(ctx, domainASCII, c.timeout())
}
if err != nil || len(mxs) == 0 {
return nil, fmt.Errorf("No MX records found")
Expand All @@ -104,7 +104,7 @@ func (c *Checker) lookupHostnames(domain string) ([]string, error) {
// `domain` is the mail domain to perform the lookup on.
// `expectedHostnames` is the list of expected hostnames.
// If `expectedHostnames` is nil, we don't validate the DNS lookup.
func (c *Checker) CheckDomain(domain string, expectedHostnames []string) DomainResult {
func (c *Checker) CheckDomain(ctx context.Context, domain string, expectedHostnames []string) DomainResult {
result := DomainResult{
Domain: domain,
MxHostnames: expectedHostnames,
Expand All @@ -114,20 +114,20 @@ func (c *Checker) CheckDomain(domain string, expectedHostnames []string) DomainR
// 1. Look up hostnames
// 2. Perform and aggregate checks from those hostnames.
// 3. Set a summary message.
hostnames, err := c.lookupHostnames(domain)
hostnames, err := c.lookupHostnames(ctx, domain)
if err != nil {
return result.setStatus(DomainCouldNotConnect)
}
checkedHostnames := make([]string, 0)
for _, hostname := range hostnames {
hostnameResult := c.checkHostname(domain, hostname)
hostnameResult := c.checkHostname(ctx, domain, hostname)
result.HostnameResults[hostname] = hostnameResult
if hostnameResult.couldConnect() {
checkedHostnames = append(checkedHostnames, hostname)
}
}
result.PreferredHostnames = checkedHostnames
result.MTASTSResult = c.checkMTASTS(domain, result.HostnameResults)
result.MTASTSResult = c.checkMTASTS(ctx, domain, result.HostnameResults)

// Derive Domain code from Hostname results.
if len(checkedHostnames) == 0 {
Expand Down
5 changes: 3 additions & 2 deletions checker/domain_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package checker

import (
"context"
"fmt"
"net"
"testing"
Expand Down Expand Up @@ -59,7 +60,7 @@ func mockLookupMX(domain string) ([]*net.MX, error) {
return result, nil
}

func mockCheckHostname(domain string, hostname string, _ time.Duration) HostnameResult {
func mockCheckHostname(_ context.Context, domain string, hostname string, _ time.Duration) HostnameResult {
if result, ok := hostnameResults[hostname]; ok {
return HostnameResult{
Result: &result,
Expand Down Expand Up @@ -120,7 +121,7 @@ func performTestsWithCacheTimeout(t *testing.T, tests []domainTestCase, cacheExp
if test.expectedHostnames == nil {
test.expectedHostnames = mxLookup[test.domain]
}
got := c.CheckDomain(test.domain, test.expectedHostnames).Status
got := c.CheckDomain(context.Background(), test.domain, test.expectedHostnames).Status
test.check(t, got)
}
}
Expand Down
39 changes: 21 additions & 18 deletions checker/hostname.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package checker

import (
"context"
"crypto/tls"
"crypto/x509"
"net"
Expand Down Expand Up @@ -70,11 +71,13 @@ func getThisHostname() string {

// Performs an SMTP dial with a short timeout.
// https://github.com/golang/go/issues/16436
func smtpDialWithTimeout(hostname string, timeout time.Duration) (*smtp.Client, error) {
func smtpDialWithTimeout(ctx context.Context, hostname string, timeout time.Duration) (*smtp.Client, error) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
if _, _, err := net.SplitHostPort(hostname); err != nil {
hostname += ":25"
}
conn, err := net.DialTimeout("tcp", hostname, timeout)
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", hostname)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -160,13 +163,13 @@ func tlsConfigForCipher(ciphers []uint16) tls.Config {
}

// Checks to see that insecure ciphers are disabled.
func checkTLSCipher(hostname string, timeout time.Duration) *Result {
func checkTLSCipher(ctx context.Context, hostname string, timeout time.Duration) *Result {
result := MakeResult("cipher")
badCiphers := []uint16{
tls.TLS_RSA_WITH_RC4_128_SHA,
tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA}
client, err := smtpDialWithTimeout(hostname, timeout)
client, err := smtpDialWithTimeout(ctx, hostname, timeout)
if err != nil {
return result.Error("Could not establish connection with hostname %s", hostname)
}
Expand All @@ -179,7 +182,7 @@ func checkTLSCipher(hostname string, timeout time.Duration) *Result {
return result.Success()
}

func checkTLSVersion(client *smtp.Client, hostname string, timeout time.Duration) *Result {
func checkTLSVersion(ctx context.Context, client *smtp.Client, hostname string, timeout time.Duration) *Result {
result := MakeResult(Version)

// Check the TLS version of the existing connection.
Expand All @@ -193,7 +196,7 @@ func checkTLSVersion(client *smtp.Client, hostname string, timeout time.Duration
}

// Attempt to connect with an old SSL version.
client, err := smtpDialWithTimeout(hostname, timeout)
client, err := smtpDialWithTimeout(ctx, hostname, timeout)
if err != nil {
return result.Error("Could not establish connection: %v", err)
}
Expand All @@ -212,26 +215,26 @@ func checkTLSVersion(client *smtp.Client, hostname string, timeout time.Duration

// checkHostname returns the result of c.CheckHostname or FullCheckHostname,
// using or updating the Checker's cache.
func (c *Checker) checkHostname(domain string, hostname string) HostnameResult {
func (c *Checker) checkHostname(ctx context.Context, domain string, hostname string) HostnameResult {
check := c.CheckHostname
if check == nil {
// If CheckHostname hasn't been set, default to the full set of checks.
check = FullCheckHostname
}

if c.Cache == nil {
return check(domain, hostname, c.timeout())
return check(ctx, domain, hostname, c.timeout())
}
hostnameResult, err := c.Cache.GetHostnameScan(hostname)
if err != nil {
hostnameResult = check(domain, hostname, c.timeout())
hostnameResult = check(ctx, domain, hostname, c.timeout())
c.Cache.PutHostnameScan(hostname, hostnameResult)
}
return hostnameResult
}

// NoopCheckHostname returns a fake error result containing `domain` and `hostname`.
func NoopCheckHostname(domain string, hostname string, _ time.Duration) HostnameResult {
func NoopCheckHostname(ctx context.Context, domain string, hostname string, _ time.Duration) HostnameResult {
r := HostnameResult{
Domain: domain,
Hostname: hostname,
Expand All @@ -244,8 +247,8 @@ func NoopCheckHostname(domain string, hostname string, _ time.Duration) Hostname
// FullCheckHostname performs a series of checks against a hostname for an email domain.
// `domain` is the mail domain that this server serves email for.
// `hostname` is the hostname for this server.
func FullCheckHostname(domain string, hostname string, timeout time.Duration) HostnameResult {
result := HostnameResult{
func FullCheckHostname(ctx context.Context, domain string, hostname string, timeout time.Duration) HostnameResult {
result := &HostnameResult{
Domain: domain,
Hostname: hostname,
Result: MakeResult("hostnames"),
Expand All @@ -254,23 +257,23 @@ func FullCheckHostname(domain string, hostname string, timeout time.Duration) Ho

// Connect to the SMTP server and use that connection to perform as many checks as possible.
connectivityResult := MakeResult(Connectivity)
client, err := smtpDialWithTimeout(hostname, timeout)
client, err := smtpDialWithTimeout(ctx, hostname, timeout)
if err != nil {
result.addCheck(connectivityResult.Error("Could not establish connection: %v", err))
return result
return *result
}
defer client.Close()
defer client.Quit()
result.addCheck(connectivityResult.Success())

result.addCheck(checkStartTLS(client))
if result.Status != Success {
return result
return *result
}
result.addCheck(checkCert(client, domain, hostname))
// result.addCheck(checkTLSCipher(hostname))

// Creates a new connection to check for SSLv2/3 support because we can't call starttls twice.
result.addCheck(checkTLSVersion(client, hostname, timeout))
result.addCheck(checkTLSVersion(ctx, client, hostname, timeout))

return result
return *result
}
24 changes: 12 additions & 12 deletions checker/hostname_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package checker

import (
"bufio"
"context"
"crypto/rand"
"crypto/tls"
"crypto/x509"
Expand Down Expand Up @@ -100,7 +101,7 @@ func TestPolicyMatch(t *testing.T) {
}

func TestNoConnection(t *testing.T) {
result := FullCheckHostname("", "example.com", testTimeout)
result := FullCheckHostname(context.Background(), "", "example.com", testTimeout)

expected := Result{
Status: 3,
Expand All @@ -115,7 +116,7 @@ func TestNoTLS(t *testing.T) {
ln := smtpListenAndServe(t, &tls.Config{})
defer ln.Close()

result := FullCheckHostname("", ln.Addr().String(), testTimeout)
result := FullCheckHostname(context.Background(), "", ln.Addr().String(), testTimeout)

expected := Result{
Status: 2,
Expand All @@ -135,7 +136,7 @@ func TestSelfSigned(t *testing.T) {
ln := smtpListenAndServe(t, &tls.Config{Certificates: []tls.Certificate{cert}})
defer ln.Close()

result := FullCheckHostname("", ln.Addr().String(), testTimeout)
result := FullCheckHostname(context.Background(), "", ln.Addr().String(), testTimeout)

expected := Result{
Status: 2,
Expand All @@ -161,7 +162,7 @@ func TestNoTLS12(t *testing.T) {
})
defer ln.Close()

result := FullCheckHostname("", ln.Addr().String(), testTimeout)
result := FullCheckHostname(context.Background(), "", ln.Addr().String(), testTimeout)

expected := Result{
Status: 2,
Expand Down Expand Up @@ -194,7 +195,7 @@ func TestSuccessWithFakeCA(t *testing.T) {
// conserving the port number.
addrParts := strings.Split(ln.Addr().String(), ":")
port := addrParts[len(addrParts)-1]
result := FullCheckHostname("", "localhost:"+port, testTimeout)
result := FullCheckHostname(context.Background(), "", "localhost:"+port, testTimeout)
expected := Result{
Status: 0,
Checks: map[string]*Result{
Expand All @@ -217,7 +218,7 @@ func TestSuccessWithDelayedGreeting(t *testing.T) {
defer ln.Close()
go ServeDelayedGreeting(ln, t)

client, err := smtpDialWithTimeout(ln.Addr().String(), testTimeout)
client, err := smtpDialWithTimeout(context.Background(), ln.Addr().String(), testTimeout)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -269,7 +270,7 @@ func TestFailureWithBadHostname(t *testing.T) {
// conserving the port number.
addrParts := strings.Split(ln.Addr().String(), ":")
port := addrParts[len(addrParts)-1]
result := FullCheckHostname("", "localhost:"+port, testTimeout)
result := FullCheckHostname(context.Background(), "", "localhost:"+port, testTimeout)
expected := Result{
Status: 2,
Checks: map[string]*Result{
Expand Down Expand Up @@ -309,7 +310,7 @@ func TestAdvertisedCiphers(t *testing.T) {

ln := smtpListenAndServe(t, tlsConfig)
defer ln.Close()
FullCheckHostname("", ln.Addr().String(), testTimeout)
FullCheckHostname(context.Background(), "", ln.Addr().String(), testTimeout)

// Partial list of ciphers we want to support
expectedCipherSuites := []struct {
Expand Down Expand Up @@ -340,14 +341,13 @@ func compareStatuses(t *testing.T, expected Result, result HostnameResult) {
if result.Status != expected.Status {
t.Errorf("hostname status = %d, want %d", result.Status, expected.Status)
}

if len(result.Checks) > len(expected.Checks) {
t.Errorf("result contains too many checks\n expected %v\n want %v", result.Checks, expected.Checks)
}

for _, c := range expected.Checks {
if got := result.Checks[c.Name].Status; got != c.Status {
t.Errorf("%s status = %d, want %d", c.Name, got, c.Status)
got, ok := result.Checks[c.Name]
if !ok || got.Status != c.Status {
t.Errorf("%s check result was %v, want status %d", c.Name, got, c.Status)
}
}
}
Expand Down
Loading