diff --git a/app-prod.yaml b/app-prod.yaml index cebb5385..d0e13111 100644 --- a/app-prod.yaml +++ b/app-prod.yaml @@ -10,3 +10,4 @@ env_variables: APP_ID: 119816 KEY_SECRET: "gcpsecretmanager://projects/allstar-ossf/secrets/allstar-private-key?decoder=bytes" DO_NOTHING_ON_OPT_OUT: true + ALLSTAR_NUM_WORKERS: 1 diff --git a/app-staging.yaml b/app-staging.yaml index 081b9c48..a9623f7b 100644 --- a/app-staging.yaml +++ b/app-staging.yaml @@ -10,3 +10,4 @@ resources: env_variables: APP_ID: 166485 KEY_SECRET: "gcpsecretmanager://projects/allstar-ossf/secrets/allstar-staging-private-key?decoder=bytes" + ALLSTAR_NUM_WORKERS: 1 diff --git a/cmd/allstar/main.go b/cmd/allstar/main.go index 357a11f0..b57c7b6a 100644 --- a/cmd/allstar/main.go +++ b/cmd/allstar/main.go @@ -62,8 +62,6 @@ func main() { specificPolicyArg := flag.String("policy", "", fmt.Sprintf("Run a specific policy check. Supported policies: %s", supportedPoliciesMsg)) specificRepoArg := flag.String("repo", "", "Run on a specific \"owner/repo\". For example \"ossf/allstar\"") - numWorkersArg := flag.Int("workers", 5, "maximum number of active goroutines for Allstar scans") - flag.Parse() if *specificPolicyArg != "" { @@ -83,7 +81,7 @@ func main() { } if runOnce { - _, err := enforce.EnforceAll(ctx, ghc, *specificPolicyArg, *specificRepoArg, *numWorkersArg) + _, err := enforce.EnforceAll(ctx, ghc, *specificPolicyArg, *specificRepoArg) if err != nil { log.Fatal(). Err(err). @@ -96,7 +94,7 @@ func main() { go func() { defer wg.Done() log.Info(). - Err(enforce.EnforceJob(ctx, ghc, (5 * time.Minute), *specificPolicyArg, *specificRepoArg, *numWorkersArg)). + Err(enforce.EnforceJob(ctx, ghc, (5 * time.Minute), *specificPolicyArg, *specificRepoArg)). Msg("Enforce job shutting down.") }() sigs := make(chan os.Signal, 1) diff --git a/pkg/config/operator/operator.go b/pkg/config/operator/operator.go index d0039fa9..636138a9 100644 --- a/pkg/config/operator/operator.go +++ b/pkg/config/operator/operator.go @@ -95,6 +95,12 @@ const setNoticePingDurationHrs = (24 * time.Hour) var NoticePingDuration time.Duration +// NumWorkers is the number of concurrent orginazations/installations the +// Allstar binary will scan concurrently. +const setNumWorkers = 5 + +var NumWorkers int + var osGetenv func(string) string func init() { @@ -147,4 +153,12 @@ func setVars() { allowedOrgs := osGetenv("GITHUB_ALLOWED_ORGS") AllowedOrganizations = strings.Split(allowedOrgs, ",") + + nws := osGetenv("ALLSTAR_NUM_WORKERS") + nw, err := strconv.Atoi(nws) + if err == nil { + NumWorkers = nw + } else { + NumWorkers = setNumWorkers + } } diff --git a/pkg/enforce/enforce.go b/pkg/enforce/enforce.go index 137b8f1c..684ec28f 100644 --- a/pkg/enforce/enforce.go +++ b/pkg/enforce/enforce.go @@ -67,7 +67,7 @@ func init() { // // TBD: determine if this should remain exported, or if it will only be called // from EnforceJob. -func EnforceAll(ctx context.Context, ghc ghclients.GhClientsInterface, specificPolicyArg string, specificRepoArg string, numWorkersArg int) (EnforceAllResults, error) { +func EnforceAll(ctx context.Context, ghc ghclients.GhClientsInterface, specificPolicyArg string, specificRepoArg string) (EnforceAllResults, error) { var repoCount int var enforceAllResults = make(EnforceAllResults) ac, err := ghc.Get(0) @@ -85,10 +85,13 @@ func EnforceAll(ctx context.Context, ghc ghclients.GhClientsInterface, specificP Msg("Enforcing policies on installations.") g, ctx := errgroup.WithContext(ctx) - g.SetLimit(numWorkersArg) + g.SetLimit(operator.NumWorkers) var mu sync.Mutex for _, i := range insts { + if ctx.Err() != nil { + break + } if i.SuspendedAt != nil { log.Info(). Str("area", "bot"). @@ -302,9 +305,9 @@ func getAppInstallationReposReal(ctx context.Context, ic *github.Client) ([]*git // EnforceJob is a reconciliation job that enforces policies on all repos every // d duration. It runs forever until the context is done. -func EnforceJob(ctx context.Context, ghc *ghclients.GHClients, d time.Duration, specificPolicyArg string, specificRepoArg string, numWorkersArg int) error { +func EnforceJob(ctx context.Context, ghc *ghclients.GHClients, d time.Duration, specificPolicyArg string, specificRepoArg string) error { for { - _, err := EnforceAll(ctx, ghc, specificPolicyArg, specificRepoArg, numWorkersArg) + _, err := EnforceAll(ctx, ghc, specificPolicyArg, specificRepoArg) if err != nil { log.Error(). Err(err). diff --git a/pkg/enforce/enforce_test.go b/pkg/enforce/enforce_test.go index c259d371..28e97d7c 100644 --- a/pkg/enforce/enforce_test.go +++ b/pkg/enforce/enforce_test.go @@ -549,8 +549,7 @@ func TestEnforceAll(t *testing.T) { policy1Results = test.Policy1Results policy2Results = test.Policy2Results - numWorkers := 1 - enforceAllResults, err := EnforceAll(context.Background(), mockGhc, "", "", numWorkers) + enforceAllResults, err := EnforceAll(context.Background(), mockGhc, "", "") if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -582,8 +581,7 @@ func TestSuspendedEnforce(t *testing.T) { } suspended = false gaicalled = false - numWorkers := 1 - if _, err := EnforceAll(context.Background(), &MockGhClients{}, "", "", numWorkers); err != nil { + if _, err := EnforceAll(context.Background(), &MockGhClients{}, "", ""); err != nil { t.Fatalf("Unexpected error: %v", err) } if !gaicalled { @@ -591,7 +589,7 @@ func TestSuspendedEnforce(t *testing.T) { } suspended = true gaicalled = false - if _, err := EnforceAll(context.Background(), &MockGhClients{}, "", "", numWorkers); err != nil { + if _, err := EnforceAll(context.Background(), &MockGhClients{}, "", ""); err != nil { t.Fatalf("Unexpected error: %v", err) } if gaicalled {