diff --git a/mint/invoicesub.go b/mint/invoicesub.go index 64d1370..67545e6 100644 --- a/mint/invoicesub.go +++ b/mint/invoicesub.go @@ -1,7 +1,9 @@ package mint import ( + "context" "encoding/json" + "errors" "time" "github.com/elnosh/gonuts/cashu/nuts/nut04" @@ -10,14 +12,14 @@ import ( // checkInvoicePaid should be called in a different goroutine to check in the background // if the invoice for the quoteId gets paid and update it in the db. -func (m *Mint) checkInvoicePaid(quoteId string) { +func (m *Mint) checkInvoicePaid(ctx context.Context, quoteId string) { mintQuote, err := m.db.GetMintQuote(quoteId) if err != nil { m.logErrorf("could not get mint quote '%v' from db: %v", quoteId, err) return } - invoiceSub, err := m.lightningClient.SubscribeInvoice(mintQuote.PaymentHash) + invoiceSub, err := m.lightningClient.SubscribeInvoice(ctx, mintQuote.PaymentHash) if err != nil { m.logErrorf("could not subscribe to invoice changes for mint quote '%v': %v", quoteId, err) return @@ -56,7 +58,11 @@ func (m *Mint) checkInvoicePaid(quoteId string) { m.publisher.Publish(BOLT11_MINT_QUOTE_TOPIC, jsonQuote) } case err := <-errChan: - m.logErrorf("error reading from invoice subscription: %v", err) + if errors.Is(ctx.Err(), context.Canceled) { + m.logDebugf("canceling invoice subscription for quote '%v'. Context canceled", mintQuote.Id) + } else { + m.logErrorf("error reading from invoice subscription: %v", err) + } case <-time.After(time.Second * time.Duration(timeUntilExpiry)): // cancel when quote reaches expiry time m.logDebugf("canceling invoice subscription for quote '%v'. Reached deadline", mintQuote.Id) diff --git a/mint/lightning/fakebackend.go b/mint/lightning/fakebackend.go index 645f48c..80e0b36 100644 --- a/mint/lightning/fakebackend.go +++ b/mint/lightning/fakebackend.go @@ -157,7 +157,7 @@ func (fb *FakeBackend) FeeReserve(amount uint64) uint64 { return 0 } -func (fb *FakeBackend) SubscribeInvoice(paymentHash string) (InvoiceSubscriptionClient, error) { +func (fb *FakeBackend) SubscribeInvoice(ctx context.Context, paymentHash string) (InvoiceSubscriptionClient, error) { return &FakeInvoiceSub{ paymentHash: paymentHash, fb: fb, diff --git a/mint/lightning/lightning.go b/mint/lightning/lightning.go index 9566249..4ae2377 100644 --- a/mint/lightning/lightning.go +++ b/mint/lightning/lightning.go @@ -11,7 +11,7 @@ type Client interface { PayPartialAmount(ctx context.Context, request string, amountMsat uint64, maxFee uint64) (PaymentStatus, error) OutgoingPaymentStatus(ctx context.Context, hash string) (PaymentStatus, error) FeeReserve(amount uint64) uint64 - SubscribeInvoice(paymentHash string) (InvoiceSubscriptionClient, error) + SubscribeInvoice(ctx context.Context, paymentHash string) (InvoiceSubscriptionClient, error) } type Invoice struct { diff --git a/mint/lightning/lnd.go b/mint/lightning/lnd.go index 8dfb093..7a6b1a7 100644 --- a/mint/lightning/lnd.go +++ b/mint/lightning/lnd.go @@ -249,7 +249,7 @@ func (lnd *LndClient) FeeReserve(amount uint64) uint64 { return uint64(fee) } -func (lnd *LndClient) SubscribeInvoice(paymentHash string) (InvoiceSubscriptionClient, error) { +func (lnd *LndClient) SubscribeInvoice(ctx context.Context, paymentHash string) (InvoiceSubscriptionClient, error) { hash, err := hex.DecodeString(paymentHash) if err != nil { return nil, err @@ -257,7 +257,7 @@ func (lnd *LndClient) SubscribeInvoice(paymentHash string) (InvoiceSubscriptionC invoiceSubRequest := &invoicesrpc.SubscribeSingleInvoiceRequest{ RHash: hash, } - lndInvoiceClient, err := lnd.invoicesClient.SubscribeSingleInvoice(context.Background(), invoiceSubRequest) + lndInvoiceClient, err := lnd.invoicesClient.SubscribeSingleInvoice(ctx, invoiceSubRequest) if err != nil { return nil, err } diff --git a/mint/mint.go b/mint/mint.go index 7e62d5c..df2af06 100644 --- a/mint/mint.go +++ b/mint/mint.go @@ -60,6 +60,8 @@ type Mint struct { mppEnabled bool publisher *pubsub.PubSub + ctx context.Context + cancel context.CancelFunc } func LoadMint(config Config) (*Mint, error) { @@ -107,6 +109,7 @@ func LoadMint(config Config) (*Mint, error) { return nil, err } logger.Info(fmt.Sprintf("setting active keyset '%v' with fee %v", activeKeyset.Id, activeKeyset.InputFeePpk)) + ctx, cancel := context.WithCancel(context.Background()) mint := &Mint{ db: db, @@ -115,6 +118,8 @@ func LoadMint(config Config) (*Mint, error) { logger: logger, mppEnabled: config.EnableMPP, publisher: pubsub.NewPubSub(), + ctx: ctx, + cancel: cancel, } dbKeysets, err := mint.db.GetKeysets() @@ -248,6 +253,11 @@ func (m *Mint) logDebugf(format string, args ...any) { _ = m.logger.Handler().Handle(context.Background(), r) } +func (m *Mint) Shutdown() error { + m.cancel() + return m.db.Close() +} + // RequestMintQuote will process a request to mint tokens // and returns a mint quote or an error. // The request to mint a token is explained in @@ -306,7 +316,7 @@ func (m *Mint) RequestMintQuote(mintQuoteRequest nut04.PostMintQuoteBolt11Reques } // goroutine to check in the background when invoice gets paid and update db if so - go m.checkInvoicePaid(quoteId) + go m.checkInvoicePaid(m.ctx, quoteId) return mintQuote, nil } diff --git a/mint/server.go b/mint/server.go index 0580200..c343f2a 100644 --- a/mint/server.go +++ b/mint/server.go @@ -64,10 +64,18 @@ func SetupMintServer(config Config) (*MintServer, error) { return mintServer, nil } -func (ms *MintServer) Shutdown() { +func (ms *MintServer) Shutdown() error { ms.mint.logger.Info("starting shutdown") - ms.mint.db.Close() - ms.httpServer.Shutdown(context.Background()) + if err := ms.mint.Shutdown(); err != nil { + return err + } + if err := ms.websocketManager.Shutdown(); err != nil { + return err + } + if err := ms.httpServer.Shutdown(context.Background()); err != nil { + return err + } + return nil } func (ms *MintServer) setupHttpServer(port int) error { diff --git a/mint/storage/sqlite/sqlite.go b/mint/storage/sqlite/sqlite.go index bb555ba..d8cd3cf 100644 --- a/mint/storage/sqlite/sqlite.go +++ b/mint/storage/sqlite/sqlite.go @@ -99,8 +99,8 @@ func InitSQLite(path string) (*SQLiteDB, error) { return &SQLiteDB{db: db}, nil } -func (sqlite *SQLiteDB) Close() { - sqlite.db.Close() +func (sqlite *SQLiteDB) Close() error { + return sqlite.db.Close() } func (sqlite *SQLiteDB) GetBalance() (uint64, error) { diff --git a/mint/storage/storage.go b/mint/storage/storage.go index 34c0cea..5e2f97f 100644 --- a/mint/storage/storage.go +++ b/mint/storage/storage.go @@ -38,7 +38,7 @@ type MintDB interface { GetBlindSignature(B_ string) (cashu.BlindedSignature, error) GetBlindSignatures(B_s []string) (cashu.BlindedSignatures, error) - Close() + Close() error } type DBKeyset struct { diff --git a/mint/websocket.go b/mint/websocket.go index 7407c31..ee0c91f 100644 --- a/mint/websocket.go +++ b/mint/websocket.go @@ -30,8 +30,8 @@ var upgrader = websocket.Upgrader{ type WebsocketManager struct { clients map[*Client]bool - sync.RWMutex - mint *Mint + mu sync.RWMutex + mint *Mint } func NewWebSocketManager(mint *Mint) *WebsocketManager { @@ -58,18 +58,30 @@ func (wm *WebsocketManager) serveWS(w http.ResponseWriter, r *http.Request) { } func (wm *WebsocketManager) addClient(client *Client) { - wm.Lock() + wm.mu.Lock() wm.clients[client] = true - wm.Unlock() + wm.mu.Unlock() } -func (wm *WebsocketManager) removeClient(client *Client) { - wm.Lock() +func (wm *WebsocketManager) removeClient(client *Client) error { + wm.mu.Lock() if _, ok := wm.clients[client]; ok { - client.close() + if err := client.close(); err != nil { + return err + } delete(wm.clients, client) } - wm.Unlock() + wm.mu.Unlock() + return nil +} + +func (wm *WebsocketManager) Shutdown() error { + for client := range wm.clients { + if err := wm.removeClient(client); err != nil { + return err + } + } + return nil } type Client struct {