diff --git a/cmd/yopass-server/main.go b/cmd/yopass-server/main.go index 3f0ccc4a6..8e9691fc3 100644 --- a/cmd/yopass-server/main.go +++ b/cmd/yopass-server/main.go @@ -71,7 +71,7 @@ func main() { go func() { addr := fmt.Sprintf("%s:%d", viper.GetString("address"), viper.GetInt("port")) logger.Info("Starting yopass server", zap.String("address", addr)) - y := server.New(db, viper.GetInt("max-length"), registry, viper.GetBool("force-onetime-secrets")) + y := server.New(db, viper.GetInt("max-length"), registry, viper.GetBool("force-onetime-secrets"), logger) errc <- listenAndServe(addr, y.HTTPHandler(), cert, key) }() diff --git a/pkg/server/server.go b/pkg/server/server.go index 731c75777..ac9fb4bb7 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -13,6 +13,7 @@ import ( "github.com/gorilla/mux" "github.com/jhaals/yopass/pkg/yopass" "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" ) // Server struct holding database and settings. @@ -22,15 +23,20 @@ type Server struct { maxLength int registry *prometheus.Registry forceOneTimeSecrets bool + logger *zap.Logger } // New is the main way of creating the server. -func New(db Database, maxLength int, r *prometheus.Registry, forceOneTimeSecrets bool) Server { +func New(db Database, maxLength int, r *prometheus.Registry, forceOneTimeSecrets bool, logger *zap.Logger) Server { + if logger == nil { + logger = zap.NewNop() + } return Server{ db: db, maxLength: maxLength, registry: r, forceOneTimeSecrets: forceOneTimeSecrets, + logger: logger, } } @@ -41,6 +47,7 @@ func (y *Server) createSecret(w http.ResponseWriter, request *http.Request) { decoder := json.NewDecoder(request.Body) var s yopass.Secret if err := decoder.Decode(&s); err != nil { + y.logger.Debug("Unable to decode request", zap.Error(err)) http.Error(w, `{"message": "Unable to parse json"}`, http.StatusBadRequest) return } @@ -63,6 +70,7 @@ func (y *Server) createSecret(w http.ResponseWriter, request *http.Request) { // Generate new UUID uuidVal, err := uuid.NewV4() if err != nil { + y.logger.Error("Unable to generate UUID", zap.Error(err)) http.Error(w, `{"message": "Unable to generate UUID"}`, http.StatusInternalServerError) return } @@ -70,31 +78,44 @@ func (y *Server) createSecret(w http.ResponseWriter, request *http.Request) { // store secret in memcache with specified expiration. if err := y.db.Put(key, s); err != nil { + y.logger.Error("Unable to store secret", zap.Error(err)) http.Error(w, `{"message": "Failed to store secret in database"}`, http.StatusInternalServerError) return } resp := map[string]string{"message": key} - jsonData, _ := json.Marshal(resp) - w.Write(jsonData) + jsonData, err := json.Marshal(resp) + if err != nil { + y.logger.Error("Failed to marshal create secret response", zap.Error(err), zap.String("key", key)) + } + + if _, err = w.Write(jsonData); err != nil { + y.logger.Error("Failed to write response", zap.Error(err), zap.String("key", key)) + } } // getSecret from database func (y *Server) getSecret(w http.ResponseWriter, request *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") - secret, err := y.db.Get(mux.Vars(request)["key"]) + secretKey := mux.Vars(request)["key"] + secret, err := y.db.Get(secretKey) if err != nil { + y.logger.Debug("Secret not found", zap.Error(err), zap.String("key", secretKey)) http.Error(w, `{"message": "Secret not found"}`, http.StatusNotFound) return } data, err := secret.ToJSON() if err != nil { + y.logger.Error("Failed to encode request", zap.Error(err), zap.String("key", secretKey)) http.Error(w, `{"message": "Failed to encode secret"}`, http.StatusInternalServerError) return } - w.Write(data) + + if _, err := w.Write(data); err != nil { + y.logger.Error("Failed to write response", zap.Error(err), zap.String("key", secretKey)) + } } // HTTPHandler containing all routes diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 7f26bb1e9..82f841d63 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -12,6 +12,7 @@ import ( "github.com/jhaals/yopass/pkg/yopass" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" + "go.uber.org/zap/zaptest" ) type mockDB struct{} @@ -103,7 +104,7 @@ func TestCreateSecret(t *testing.T) { t.Run(fmt.Sprintf(tc.name), func(t *testing.T) { req, _ := http.NewRequest("POST", "/secret", tc.body) rr := httptest.NewRecorder() - y := New(tc.db, tc.maxLength, prometheus.NewRegistry(), false) + y := New(tc.db, tc.maxLength, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) y.createSecret(rr, req) var s yopass.Secret json.Unmarshal(rr.Body.Bytes(), &s) @@ -160,7 +161,7 @@ func TestOneTimeEnforcement(t *testing.T) { t.Run(fmt.Sprintf(tc.name), func(t *testing.T) { req, _ := http.NewRequest("POST", "/secret", tc.body) rr := httptest.NewRecorder() - y := New(&mockDB{}, 100, prometheus.NewRegistry(), tc.requireOneTime) + y := New(&mockDB{}, 100, prometheus.NewRegistry(), tc.requireOneTime, zaptest.NewLogger(t)) y.createSecret(rr, req) var s yopass.Secret json.Unmarshal(rr.Body.Bytes(), &s) @@ -204,7 +205,7 @@ func TestGetSecret(t *testing.T) { t.Fatal(err) } rr := httptest.NewRecorder() - y := New(tc.db, 1, prometheus.NewRegistry(), false) + y := New(tc.db, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) y.getSecret(rr, req) var s yopass.Secret json.Unmarshal(rr.Body.Bytes(), &s) @@ -232,7 +233,7 @@ func TestMetrics(t *testing.T) { path: "/secret/invalid-key-format", }, } - y := New(&mockDB{}, 1, prometheus.NewRegistry(), false) + y := New(&mockDB{}, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) h := y.HTTPHandler() for _, r := range requests { @@ -305,7 +306,7 @@ func TestSecurityHeaders(t *testing.T) { }, } - y := New(&mockDB{}, 1, prometheus.NewRegistry(), false) + y := New(&mockDB{}, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) h := y.HTTPHandler() t.Parallel() diff --git a/pkg/yopass/client_test.go b/pkg/yopass/client_test.go index 37b0e5a66..9366c618a 100644 --- a/pkg/yopass/client_test.go +++ b/pkg/yopass/client_test.go @@ -3,6 +3,7 @@ package yopass_test import ( "errors" "fmt" + "go.uber.org/zap/zaptest" "net/http/httptest" "testing" @@ -13,7 +14,7 @@ import ( func TestFetch(t *testing.T) { db := testDB(map[string]string{}) - y := server.New(&db, 1024, prometheus.NewRegistry(), false) + y := server.New(&db, 1024, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) ts := httptest.NewServer(y.HTTPHandler()) defer ts.Close() @@ -45,7 +46,7 @@ func TestFetchInvalidServer(t *testing.T) { } func TestStore(t *testing.T) { db := testDB(map[string]string{}) - y := server.New(&db, 1024, prometheus.NewRegistry(), false) + y := server.New(&db, 1024, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) ts := httptest.NewServer(y.HTTPHandler()) defer ts.Close()