Skip to content

Commit

Permalink
Refactor DeviceAgent EventLoop to use a StateMachine (#345)
Browse files Browse the repository at this point in the history
Replaced EventLoop with a Statemachine with clearer separation of concerns.
Added happy path tests.

Co-authored-by: Vegar Sechmann Molvig <[email protected]>
  • Loading branch information
mortenlj and sechmann authored Dec 21, 2023
1 parent 6da7574 commit c3cbe06
Show file tree
Hide file tree
Showing 24 changed files with 1,703 additions and 577 deletions.
2 changes: 2 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ packages:
interfaces:
APIServerClient:
APIServer_GetGatewayConfigurationClient:
APIServer_GetDeviceConfigurationClient:
DeviceHelperClient:
github.com/nais/device/internal/wireguard:
interfaces:
NetworkConfigurer:
Expand Down
116 changes: 114 additions & 2 deletions cmd/device-agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package main

import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"os/signal"
"path/filepath"
Expand All @@ -18,13 +20,19 @@ import (
"github.com/nais/device/internal/device-agent/config"
"github.com/nais/device/internal/device-agent/filesystem"
"github.com/nais/device/internal/device-agent/runtimeconfig"
"github.com/nais/device/internal/device-agent/statemachine"
"github.com/nais/device/internal/logger"
"github.com/nais/device/internal/notify"
"github.com/nais/device/internal/pb"
"github.com/nais/device/internal/unixsocket"
"github.com/nais/device/internal/version"
)

const (
healthCheckInterval = 20 * time.Second // how often to healthcheck gateways
versionCheckInterval = 1 * time.Hour // how often to check for a new version of naisdevice
)

func handleSignals(log *logrus.Entry, cancel context.CancelFunc) {
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
Expand Down Expand Up @@ -81,6 +89,9 @@ func main() {
}

func run(ctx context.Context, log *logrus.Entry, cfg *config.Config, notifier notify.Notifier) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

if err := filesystem.EnsurePrerequisites(cfg); err != nil {
return fmt.Errorf("missing prerequisites: %s", err)
}
Expand Down Expand Up @@ -111,18 +122,55 @@ func run(ctx context.Context, log *logrus.Entry, cfg *config.Config, notifier no
client := pb.NewDeviceHelperClient(connection)
defer connection.Close()

go func() {
for ctx.Err() == nil {
select {
case <-ctx.Done():
return
case <-time.After(healthCheckInterval):
err = helperHealthCheck(ctx, client)
if err != nil {
log.WithError(err).Errorf("Unable to communicate with helper. Shutting down")
notifier.Errorf("Unable to communicate with helper. Shutting down.")
cancel()
}
}
}
}()

listener, err := unixsocket.ListenWithFileMode(cfg.GrpcAddress, 0o666)
if err != nil {
return err
}
log.Infof("accepting network connections on unix socket %s", cfg.GrpcAddress)

statusChannel := make(chan *pb.AgentStatus, 32)
stateMachine := statemachine.NewStateMachine(ctx, rc, *cfg, notifier, client, statusChannel, log.WithField("component", "statemachine"))

grpcServer := grpc.NewServer()
das := deviceagent.NewServer(log.WithField("component", "device-agent-server"), client, cfg, rc, notifier)
das := deviceagent.NewServer(ctx, log.WithField("component", "device-agent-server"), cfg, rc, notifier, stateMachine.SendEvent)
pb.RegisterDeviceAgentServer(grpcServer, das)

newVersionChannel := make(chan bool, 1)
go versionChecker(ctx, newVersionChannel, notifier, log)

go func() {
das.EventLoop(ctx)
// This routine forwards status updates from the state machine to the device agent server
newVersionAvailable := false
for ctx.Err() == nil {
select {
case newVersionAvailable = <-newVersionChannel:
case s := <-statusChannel:
s.NewVersionAvailable = newVersionAvailable
das.UpdateAgentStatus(s)
case <-ctx.Done():
}
}
}()

go func() {
stateMachine.Run(ctx)
// after state machine is done, stop the grpcServer
grpcServer.Stop()
}()

Expand All @@ -135,3 +183,67 @@ func run(ctx context.Context, log *logrus.Entry, cfg *config.Config, notifier no

return nil
}

func versionChecker(ctx context.Context, newVersionChannel chan<- bool, notifier notify.Notifier, logger logrus.FieldLogger) {
versionCheckTimer := time.NewTimer(versionCheckInterval)
for ctx.Err() == nil {
select {
case <-ctx.Done():
return
case <-versionCheckTimer.C:
newVersionAvailable, err := checkNewVersionAvailable(ctx)
if err != nil {
logger.Infof("check for new version: %s", err)
break
}

newVersionChannel <- newVersionAvailable
if newVersionAvailable {
notifier.Infof("New version of device agent available: https://doc.nais.io/device/update/")
versionCheckTimer.Stop()
} else {
versionCheckTimer.Reset(versionCheckInterval)
}
}
}
}

func helperHealthCheck(ctx context.Context, client pb.DeviceHelperClient) error {
helperHealthCheckCtx, helperHealthCheckCancel := context.WithTimeout(ctx, 5*time.Second)
defer helperHealthCheckCancel()

if _, err := client.GetSerial(helperHealthCheckCtx, &pb.GetSerialRequest{}); err != nil {
return err
}
return nil
}

func checkNewVersionAvailable(ctx context.Context) (bool, error) {
type response struct {
Tag string `json:"tag_name"`
}

req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/repos/nais/device/releases/latest", nil)
if err != nil {
return false, err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, fmt.Errorf("retrieve current release version: %s", err)
}

defer resp.Body.Close()

res := &response{}
decoder := json.NewDecoder(resp.Body)
err = decoder.Decode(res)
if err != nil {
return false, fmt.Errorf("unmarshal response: %s", err)
}

if version.Version != res.Tag {
return true, nil
}

return false, nil
}
2 changes: 1 addition & 1 deletion internal/device-agent/auth/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type Tokens struct {
IDToken string
}

func GetDeviceAgentToken(ctx context.Context, log *logrus.Entry, conf oauth2.Config, authServer string) (*Tokens, error) {
func GetDeviceAgentToken(ctx context.Context, log logrus.FieldLogger, conf oauth2.Config, authServer string) (*Tokens, error) {
// Ignoring impossible error
codeVerifier, _ := codeverifier.CreateCodeVerifier()

Expand Down
10 changes: 0 additions & 10 deletions internal/device-agent/const_unix.go

This file was deleted.

8 changes: 0 additions & 8 deletions internal/device-agent/const_windows.go

This file was deleted.

53 changes: 19 additions & 34 deletions internal/device-agent/deviceagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package device_agent

import (
"context"
"fmt"
"sync"
"time"

"github.com/google/uuid"
"github.com/sirupsen/logrus"
Expand All @@ -15,42 +13,29 @@ import (

"github.com/nais/device/internal/device-agent/config"
"github.com/nais/device/internal/device-agent/runtimeconfig"
"github.com/nais/device/internal/device-agent/statemachine"
"github.com/nais/device/internal/pb"
)

type DeviceAgentServer struct {
pb.UnimplementedDeviceAgentServer
AgentStatus *pb.AgentStatus
DeviceHelper pb.DeviceHelperClient
lock sync.Mutex
stateChange chan pb.AgentState
statusChannels map[uuid.UUID]chan *pb.AgentStatus
Config *config.Config
rc runtimeconfig.RuntimeConfig
notifier notify.Notifier
rc runtimeconfig.RuntimeConfig
log *logrus.Entry
sendEvent func(statemachine.Event)
}

const maxLoginAttempts = 20

func (das *DeviceAgentServer) Login(ctx context.Context, request *pb.LoginRequest) (*pb.LoginResponse, error) {
var lastStatus pb.AgentState
for attempt := 1; attempt <= maxLoginAttempts; attempt += 1 {
lastStatus = das.AgentStatus.ConnectionState
if lastStatus == pb.AgentState_Disconnected {
das.stateChange <- pb.AgentState_Authenticating
return &pb.LoginResponse{}, nil
}

das.log.Debugf("[attempt %d/%d] device agent server login: agent not in correct state (state=%+v). wait 200ms and retry", attempt, maxLoginAttempts, lastStatus)
time.Sleep(200 * time.Millisecond)
}

return &pb.LoginResponse{}, fmt.Errorf("unable to connect, invalid state: %+v", lastStatus)
das.sendEvent(statemachine.EventLogin)
return &pb.LoginResponse{}, nil
}

func (das *DeviceAgentServer) Logout(ctx context.Context, request *pb.LogoutRequest) (*pb.LogoutResponse, error) {
das.stateChange <- pb.AgentState_Disconnecting
das.sendEvent(statemachine.EventDisconnect)
return &pb.LogoutResponse{}, nil
}

Expand All @@ -70,7 +55,7 @@ func (das *DeviceAgentServer) Status(request *pb.AgentStatusRequest, statusServe
das.log.Debugf("grpc: client connection with device helper closed")
if !request.GetKeepConnectionOnComplete() {
das.log.Debugf("grpc: keepalive not requested, tearing down connections...")
das.stateChange <- pb.AgentState_Disconnecting
das.sendEvent(statemachine.EventDisconnect)
}
das.lock.Lock()
close(agentStatusChan)
Expand Down Expand Up @@ -108,7 +93,6 @@ func (das *DeviceAgentServer) UpdateAgentStatus(status *pb.AgentStatus) {
func (das *DeviceAgentServer) SetAgentConfiguration(ctx context.Context, req *pb.SetAgentConfigurationRequest) (*pb.SetAgentConfigurationResponse, error) {
das.Config.AgentConfiguration = req.Config
das.Config.PersistAgentConfiguration(das.log)
das.stateChange <- pb.AgentState_AgentConfigurationChanged
return &pb.SetAgentConfigurationResponse{}, nil
}

Expand All @@ -120,29 +104,30 @@ func (das *DeviceAgentServer) GetAgentConfiguration(ctx context.Context, req *pb

func (das *DeviceAgentServer) SetActiveTenant(ctx context.Context, req *pb.SetActiveTenantRequest) (*pb.SetActiveTenantResponse, error) {
if err := das.rc.SetActiveTenant(req.Name); err != nil {
das.Notifier().Errorf("while activating tenant: %s", err)
das.stateChange <- pb.AgentState_Disconnecting
das.notifier.Errorf("while activating tenant: %s", err)
das.sendEvent(statemachine.EventDisconnect)
return &pb.SetActiveTenantResponse{}, nil
}

das.stateChange <- pb.AgentState_Disconnecting
das.sendEvent(statemachine.EventDisconnect)
das.log.Infof("activated tenant: %s", req.Name)
return &pb.SetActiveTenantResponse{}, nil
}

func (das *DeviceAgentServer) Notifier() notify.Notifier {
return das.notifier
}

func NewServer(log *logrus.Entry, helper pb.DeviceHelperClient, cfg *config.Config, rc runtimeconfig.RuntimeConfig, notifier notify.Notifier) *DeviceAgentServer {
func NewServer(ctx context.Context,
log *logrus.Entry,
cfg *config.Config,
rc runtimeconfig.RuntimeConfig,
notifier notify.Notifier,
sendEvent func(statemachine.Event),
) *DeviceAgentServer {
return &DeviceAgentServer{
DeviceHelper: helper,
log: log,
AgentStatus: &pb.AgentStatus{ConnectionState: pb.AgentState_Disconnected},
stateChange: make(chan pb.AgentState, 32),
statusChannels: make(map[uuid.UUID]chan *pb.AgentStatus),
Config: cfg,
rc: rc,
notifier: notifier,
log: log,
sendEvent: sendEvent,
}
}
Loading

0 comments on commit c3cbe06

Please sign in to comment.