diff --git a/.vimbin.yaml b/.vimbin.yaml new file mode 100644 index 0000000..101614c --- /dev/null +++ b/.vimbin.yaml @@ -0,0 +1,4 @@ +server: + api: + address: http://localhost:8080 + token: mytoken diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..487c638 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,20 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Launch file", + "type": "go", + "request": "launch", + "mode": "debug", + "program": "main.go", + "args": [ + "--config" + ,"./.vimbin.yaml" + ,"--trace" + ] + } + ] +} \ No newline at end of file diff --git a/cmd/root.go b/cmd/root.go index 3082dc6..665bb2a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -19,6 +19,7 @@ import ( "fmt" "os" "time" + "vimbin/internal/config" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -28,7 +29,7 @@ import ( ) const ( - version = "v0.0.4" + version = "v0.0.5" ) var ( @@ -84,7 +85,7 @@ var rootCmd = &cobra.Command{ os.Exit(0) } else { // Display the root command's help message - cmd.Help() + _ = cmd.Help() // Make the linter happy } }, } @@ -127,8 +128,7 @@ func initConfig() { viper.AutomaticEnv() // Read in environment variables that match - // If a config file is found, read it in. - if err := viper.ReadInConfig(); err == nil { + if err := config.App.Read(viper.ConfigFileUsed()); err != nil { log.Fatal().Msgf("Error reading config file: %v", err) } } diff --git a/cmd/serve.go b/cmd/serve.go index 53b9f74..0fb9449 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -63,7 +63,7 @@ var serveCmd = &cobra.Command{ // Collect handlers and start the server handlers.Collect() - server.Run(config.App.Server.Web.Address) + server.Run(config.App.Server.Web.Address, config.App.Server.Api.Token.Get()) }, } diff --git a/deploy/ingress.yaml b/deploy/ingress.yaml index 715b88a..996d568 100644 --- a/deploy/ingress.yaml +++ b/deploy/ingress.yaml @@ -16,4 +16,4 @@ spec: service: name: vimbin port: - name: vimbin + name: http diff --git a/internal/config/parse.go b/internal/config/parse.go index cd3bf57..ade38ac 100644 --- a/internal/config/parse.go +++ b/internal/config/parse.go @@ -5,6 +5,8 @@ import ( "os" "path" "vimbin/internal/utils" + + "github.com/rs/zerolog/log" ) // Parse reads and processes the configuration settings. @@ -48,5 +50,13 @@ func (c *Config) Parse() (err error) { return fmt.Errorf("Unable to extract hostname and port: %s", err) } + // Check if the API token is valid + if c.Server.Api.Token.Get() == "" { + if err := c.Server.Api.Token.Generate(32); err != nil { + return fmt.Errorf("Unable to generate API token: %s", err) + } + log.Debug().Msgf("Generated API token: %s", c.Server.Api.Token.Get()) + } + return nil } diff --git a/internal/config/parse_test.go b/internal/config/parse_test.go index 7eca81e..74b6080 100644 --- a/internal/config/parse_test.go +++ b/internal/config/parse_test.go @@ -52,6 +52,4 @@ func TestParse(t *testing.T) { if cfg.Storage.Path != tempStoragePath { t.Errorf("Storage path not set correctly. Expected: %s, Got: %s", tempStoragePath, cfg.Storage.Path) } - - // Add more assertions based on your specific requirements } diff --git a/internal/config/read.go b/internal/config/read.go index 85d9153..ea45583 100644 --- a/internal/config/read.go +++ b/internal/config/read.go @@ -27,8 +27,12 @@ func (c *Config) Read(configPath string) error { return fmt.Errorf("Failed to read config file: %v", err) } - // Unmarshal the config into the Config struct - if err := viper.Unmarshal(c, func(d *mapstructure.DecoderConfig) { d.ZeroFields = true }); err != nil { + if err := viper.Unmarshal(c, func(d *mapstructure.DecoderConfig) { + d.ZeroFields = true // Zero out any existing fields + d.DecodeHook = mapstructure.ComposeDecodeHookFunc( + customTokenDecodeHook, // Custom decoder hook for the Token field + ) + }); err != nil { return fmt.Errorf("Failed to unmarshal config file: %v", err) } diff --git a/internal/config/read_test.go b/internal/config/read_test.go index c474d94..202c7e9 100644 --- a/internal/config/read_test.go +++ b/internal/config/read_test.go @@ -68,6 +68,6 @@ server: cfg := &Config{} err = cfg.Read(filePath.Name()) assert.Error(t, err) // We expect an error because the file has incorrect content - assert.Contains(t, err.Error(), "Failed to unmarshal config file: 1 error(s) decoding:\n\n* cannot parse 'Server.Api.SkipInsecureVerify' as bool: strconv.ParseBool: parsing \"not_a_boolean\": invalid syntax") + assert.EqualError(t, err, "Failed to unmarshal config file: 1 error(s) decoding:\n\n* cannot parse 'server.api.skipInsecureVerify' as bool: strconv.ParseBool: parsing \"not_a_boolean\": invalid syntax") }) } diff --git a/internal/config/structs.go b/internal/config/structs.go index 9661f83..0c051c4 100644 --- a/internal/config/structs.go +++ b/internal/config/structs.go @@ -3,6 +3,7 @@ package config import ( "sync" "text/template" + "vimbin/internal/utils" ) // App is the global configuration instance. @@ -10,40 +11,84 @@ var App Config // Config represents the application configuration. type Config struct { - HtmlTemplate *template.Template `yaml:"-"` // HtmlTemplate contains the HTML template content. - Server Server `yaml:"server"` // Server represents the server configuration. - Storage Storage `yaml:"storage"` // Storage represents the storage configuration. + HtmlTemplate *template.Template `mapstructure:"-"` // HtmlTemplate contains the HTML template content. + Server Server `mapstructure:"server"` // Server represents the server configuration. + Storage Storage `mapstructure:"storage"` // Storage represents the storage configuration. } // Web represents the web configuration. type Web struct { - Theme string `yaml:"server"` // Theme is the theme to use for the web interface. - Address string `yaml:"address"` // Address is the address to listen on for HTTP requests. + Theme string `mapstructure:"server"` // Theme is the theme to use for the web interface. + Address string `mapstructure:"address"` // Address is the address to listen on for HTTP requests. } +// Token represents the API token. +type Token struct { + value string +} + +// Get retrieves the current token value. +// +// Returns: +// - string +// The current token value. +func (t *Token) Get() string { + return t.value +} + +// Set sets the token to the specified value. +// +// Parameters: +// - token: string +// The value to set as the token. +func (t *Token) Set(token string) { + t.value = token +} + +// Generate generates a new random token of the specified length and updates the token value. +// +// Parameters: +// - len: int +// The length of the new token. +// +// Returns: +// - error +// An error, if any, encountered during token generation. +func (t *Token) Generate(len int) error { + tokenString, err := utils.GenerateRandomToken(len) + if err != nil { + return err + } + t.value = tokenString + + return nil +} + +// Api represents the api configuration. type Api struct { - SkipInsecureVerify bool `yaml:"skipInsecureVerify"` // SkipInsecureVerify skips the verification of TLS certificates. - Address string `yaml:"address"` // Address is the address to push/fetch content from. + Token Token `mapstructure:"token"` // Token is the API token. + SkipInsecureVerify bool `mapstructure:"skipInsecureVerify"` // SkipInsecureVerify skips the verification of TLS certificates. + Address string `mapstructure:"address"` // Address is the address to push/fetch content from. } // Server represents the server configuration. type Server struct { - Web Web `yaml:"web"` // Web represents the web configuration. - Api Api `yaml:"api"` // Api represents the api configuration. + Web Web `mapstructure:"web"` // Web represents the web configuration. + Api Api `mapstructure:"api"` // Api represents the api configuration. } // Storage represents the storage configuration. type Storage struct { - Name string `yaml:"name"` // Name is the name of the storage file. - Directory string `yaml:"directory"` // Directory is the directory path for storage file. - Path string `yaml:"-"` // Path is the full path to the storage file. - Content Content `yaml:"-"` // Content represents the content stored in the storage file. + Name string `mapstructure:"name"` // Name is the name of the storage file. + Directory string `mapstructure:"directory"` // Directory is the directory path for storage file. + Path string `mapstructure:"-"` // Path is the full path to the storage file. + Content Content `mapstructure:"-"` // Content represents the content stored in the storage file. } // Content represents the content stored in the storage with thread-safe methods. type Content struct { - text string `yaml:"-"` // text is the stored content. - mutext sync.RWMutex `yaml:"-"` // mutext is a read-write mutex for concurrent access control. + text string `mapstructure:"-"` // text is the stored content. + mutext sync.RWMutex `mapstructure:"-"` // mutext is a read-write mutex for concurrent access control. } // Set sets the content to the specified text. diff --git a/internal/config/utils.go b/internal/config/utils.go index 581d909..b7f4a51 100644 --- a/internal/config/utils.go +++ b/internal/config/utils.go @@ -3,7 +3,9 @@ package config import ( "fmt" "os" + "reflect" + "github.com/mitchellh/mapstructure" "github.com/rs/zerolog/log" ) @@ -58,3 +60,37 @@ func checkStorageFile(filePath string) error { return nil } + +// customTokenDecodeHook is a custom mapstructure DecodeHookFunc for decoding YAML data +// into the Token struct. It converts the data into a string and initializes a Token with it. +// +// Parameters: +// - from: reflect.Type +// The type of the source data. +// - to: reflect.Type +// The type of the target data. +// - data: interface{} +// The data to be decoded. +// +// Returns: +// - interface{} +// The decoded data. +// - error +// An error, if any, encountered during the decoding process. +func customTokenDecodeHook(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) { + // If the target type is not Token, return the data as is + if to != reflect.TypeOf(Token{}) { + return data, nil + } + + var tokenValue string + // Decode the data into a string + if err := mapstructure.Decode(data, &tokenValue); err != nil { + return nil, fmt.Errorf("Unable to decode Token. %v", err) + } + + // Initialize a Token with the decoded string + var token Token + token.Set(tokenValue) + return token, nil +} diff --git a/internal/config/utils_test.go b/internal/config/utils_test.go index cb85473..7d07400 100644 --- a/internal/config/utils_test.go +++ b/internal/config/utils_test.go @@ -2,6 +2,7 @@ package config import ( "os" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -60,3 +61,42 @@ func TestCheckStorageFile(t *testing.T) { assert.Equal(t, err.Error(), "Unable to create storage file: open /non_existent_path/test_storage.txt: no such file or directory") }) } + +func TestCustomTokenDecodeHook(t *testing.T) { + t.Run("Decode hook converts string to Token successfully", func(t *testing.T) { + data := "mytoken" + fromType := reflect.TypeOf(data) + toType := reflect.TypeOf(Token{}) + + result, err := customTokenDecodeHook(fromType, toType, data) + assert.NoError(t, err) + + // Check if the result is a Token with the correct value + token, ok := result.(Token) + assert.True(t, ok) + assert.Equal(t, "mytoken", token.Get()) + }) + + t.Run("Decode hook passes through non-Token types", func(t *testing.T) { + data := 42 + fromType := reflect.TypeOf(data) + toType := reflect.TypeOf(42) + + result, err := customTokenDecodeHook(fromType, toType, data) + assert.NoError(t, err) + + // Check if the result is the same as the input data + assert.Equal(t, data, result) + }) + + t.Run("Decode hook returns error for invalid Token value", func(t *testing.T) { + data := 42 + fromType := reflect.TypeOf(data) + toType := reflect.TypeOf(Token{}) + + result, err := customTokenDecodeHook(fromType, toType, data) + assert.Error(t, err) + assert.Nil(t, result) + assert.EqualError(t, err, "Unable to decode Token. '' expected type 'string', got unconvertible type 'int', value: '42'") + }) +} diff --git a/internal/handlers/append.go b/internal/handlers/append.go index 4ab8710..d67de41 100644 --- a/internal/handlers/append.go +++ b/internal/handlers/append.go @@ -11,7 +11,7 @@ import ( ) func init() { - server.Register("/append", Append, "Append content to storage file", "POST") + server.Register("/append", "Append content to storage file", true, Append, "POST") } // Append handles HTTP requests for appending content to a file. diff --git a/internal/handlers/fetch.go b/internal/handlers/fetch.go index 5be29eb..033fead 100644 --- a/internal/handlers/fetch.go +++ b/internal/handlers/fetch.go @@ -9,7 +9,7 @@ import ( ) func init() { - server.Register("/fetch", Fetch, "Fetch content from storage file", "GET") + server.Register("/fetch", "Fetch content from storage file", true, Fetch, "GET") } // Fetch handles HTTP requests for fetching content. @@ -25,7 +25,7 @@ func init() { // - r: *http.Request // The HTTP request being processed. func Fetch(w http.ResponseWriter, r *http.Request) { - LogRequest(r) + log.Trace().Msg(generateHTTPRequestLogEntry(r)) w.Header().Set("Content-Type", "application/text") diff --git a/internal/handlers/home.go b/internal/handlers/home.go index 926d643..3e81d92 100644 --- a/internal/handlers/home.go +++ b/internal/handlers/home.go @@ -4,10 +4,12 @@ import ( "net/http" "vimbin/internal/config" "vimbin/internal/server" + + "github.com/rs/zerolog/log" ) func init() { - server.Register("/", Home, "Home site with editor", "GET") + server.Register("/", "Home site with editor", false, Home, "GET") } // Home handles HTTP requests for the home page. @@ -22,11 +24,12 @@ func init() { // - r: *http.Request // The HTTP request being processed. func Home(w http.ResponseWriter, r *http.Request) { - LogRequest(r) + log.Trace().Msg(generateHTTPRequestLogEntry(r)) page := Page{ Title: "vimbin - a pastebin with vim motion", Content: config.App.Storage.Content.Get(), + Token: config.App.Server.Api.Token.Get(), } if err := config.App.HtmlTemplate.Execute(w, page); err != nil { diff --git a/internal/handlers/save.go b/internal/handlers/save.go index a560426..d4cf485 100644 --- a/internal/handlers/save.go +++ b/internal/handlers/save.go @@ -10,7 +10,7 @@ import ( ) func init() { - server.Register("/save", Save, "Save content to storage file", "POST") + server.Register("/save", "Save content to storage file", true, Save, "POST") } // Save handles HTTP requests for saving content to a file. diff --git a/internal/handlers/structs.go b/internal/handlers/structs.go index 55f13b4..171a4e2 100644 --- a/internal/handlers/structs.go +++ b/internal/handlers/structs.go @@ -7,4 +7,5 @@ package handlers type Page struct { Title string // Title is the title of the page. Content string // Content is the content of the page. + Token string // Token is the API token. } diff --git a/internal/handlers/utils.go b/internal/handlers/utils.go index 6864fc3..7913457 100644 --- a/internal/handlers/utils.go +++ b/internal/handlers/utils.go @@ -8,7 +8,6 @@ import ( "strings" "vimbin/internal/config" - "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) @@ -23,16 +22,20 @@ func Collect() {} // filePermission represents the default file permission used in the application. const filePermission = 0644 -// LogRequest logs the details of an HTTP request. +// generateHTTPRequestLogEntry generates a log entry for an HTTP request. +// +// This function takes an HTTP request as input and creates a formatted log entry +// containing the request method, request URI, and sanitized query string. It removes +// newlines, carriage returns, and replaces spaces in the query string to ensure a clean log entry. // // Parameters: // - req: *http.Request -// The HTTP request to log. -func LogRequest(req *http.Request) { - if log.Logger.GetLevel() != zerolog.TraceLevel { - return - } - +// The HTTP request to generate a log entry for. +// +// Returns: +// - string +// A formatted log entry containing the request details. +func generateHTTPRequestLogEntry(req *http.Request) string { query := strings.Map(func(r rune) rune { switch r { case '\n', '\r': // Remove newlines and carriage returns from the query string @@ -44,7 +47,7 @@ func LogRequest(req *http.Request) { } }, req.URL.RawQuery) - log.Trace().Msgf("%s %s%s", req.Method, req.RequestURI, query) + return fmt.Sprintf("%s %s%s", req.Method, req.RequestURI, query) } // handleContentRequest handles HTTP requests for updating content. @@ -79,7 +82,7 @@ func handleContentRequest( mergeContentFunc func(string, string) string, saveContentFunc func(*config.Content, string), ) { - LogRequest(r) + log.Trace().Msg(generateHTTPRequestLogEntry(r)) // Parse JSON request body var requestData map[string]string diff --git a/internal/server/auth.go b/internal/server/auth.go new file mode 100644 index 0000000..4a15028 --- /dev/null +++ b/internal/server/auth.go @@ -0,0 +1,20 @@ +package server + +import ( + "net/http" + + "github.com/rs/zerolog/log" +) + +// ApiTokenMiddleware checks for the presence and validity of the API token. +func ApiTokenMiddleware(next http.HandlerFunc, token string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + apiToken := r.Header.Get("X-API-Token") + if apiToken != token { + log.Error().Msgf("Unauthorized API token: %s", apiToken) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + next(w, r) + } +} diff --git a/internal/server/register.go b/internal/server/register.go index af197ca..83f8602 100644 --- a/internal/server/register.go +++ b/internal/server/register.go @@ -5,8 +5,9 @@ import "net/http" // Handler is a struct that contains a route and a handler function. type Handler struct { Path string // Path is the URL route for the handler. - Handler func(http.ResponseWriter, *http.Request) // Handler is the function that handles HTTP requests for the route. Description string // Description provides a brief explanation of the handler's purpose. + Handler func(http.ResponseWriter, *http.Request) // Handler is the function that handles HTTP requests for the route. + NeedsToken bool // NeedsToken indicates if the handler requires a valid API token. Methods []string // Methods is a list of HTTP methods supported by the handler. } @@ -26,14 +27,17 @@ var Handlers = []Handler{} // The function that handles HTTP requests for the route. // - description: string // A brief explanation of the handler's purpose. +// - needsToken: bool +// Indicates if the handler requires a valid API token. // - methods: ...string // Optional list of HTTP methods supported by the handler. -func Register(path string, h func(http.ResponseWriter, *http.Request), description string, methods ...string) { +func Register(path, description string, needsToken bool, handler func(http.ResponseWriter, *http.Request), methods ...string) { Handlers = append(Handlers, Handler{ Path: path, Description: description, - Handler: h, + Handler: handler, + NeedsToken: needsToken, Methods: methods, }) } diff --git a/internal/server/register_test.go b/internal/server/register_test.go index 877ef76..ed96115 100644 --- a/internal/server/register_test.go +++ b/internal/server/register_test.go @@ -17,7 +17,7 @@ func TestRegister(t *testing.T) { description := "Example handler description" handlerFunc := func(http.ResponseWriter, *http.Request) {} - Register(path, handlerFunc, description) + Register(path, description, false, handlerFunc) // Check if the handler is registered correctly assert.Len(t, Handlers, 1) @@ -36,7 +36,7 @@ func TestRegister(t *testing.T) { handlerFunc := func(http.ResponseWriter, *http.Request) {} methods := []string{"GET", "POST"} - Register(path, handlerFunc, description, methods...) + Register(path, description, false, handlerFunc, methods...) // Check if the handler is registered correctly assert.Len(t, Handlers, 1) @@ -46,6 +46,24 @@ func TestRegister(t *testing.T) { assert.Equal(t, methods, Handlers[0].Methods) }) + t.Run("Register a handler token set to true", func(t *testing.T) { + // Clear existing handlers + Handlers = nil + + path := "/example" + description := "Example handler description" + handlerFunc := func(http.ResponseWriter, *http.Request) {} + + Register(path, description, true, handlerFunc) + + // Check if the handler is registered correctly + assert.Len(t, Handlers, 1) + assert.Equal(t, path, Handlers[0].Path) + assert.Equal(t, description, Handlers[0].Description) + assert.Equal(t, reflect.ValueOf(handlerFunc).Pointer(), reflect.ValueOf(Handlers[0].Handler).Pointer()) + assert.Empty(t, Handlers[0].Methods) + }) + t.Run("Register multiple handlers", func(t *testing.T) { // Clear existing handlers Handlers = nil @@ -60,11 +78,17 @@ func TestRegister(t *testing.T) { handlerFunc2 := func(http.ResponseWriter, *http.Request) {} methods2 := []string{"POST"} - Register(path1, handlerFunc1, description1, methods1...) - Register(path2, handlerFunc2, description2, methods2...) + path3 := "/example3" + description3 := "Example handler 3 description" + handlerFunc3 := func(http.ResponseWriter, *http.Request) {} + methods3 := []string{"GET", "POST"} + + Register(path1, description1, false, handlerFunc1, methods1...) + Register(path2, description2, false, handlerFunc2, methods2...) + Register(path3, description3, true, handlerFunc3, methods3...) // Check if both handlers are registered correctly - assert.Len(t, Handlers, 2) + assert.Len(t, Handlers, 3) assert.Equal(t, path1, Handlers[0].Path) assert.Equal(t, description1, Handlers[0].Description) @@ -75,5 +99,11 @@ func TestRegister(t *testing.T) { assert.Equal(t, description2, Handlers[1].Description) assert.Equal(t, reflect.ValueOf(handlerFunc2).Pointer(), reflect.ValueOf(Handlers[1].Handler).Pointer()) assert.Equal(t, methods2, Handlers[1].Methods) + + assert.Equal(t, path3, Handlers[2].Path) + assert.Equal(t, description3, Handlers[2].Description) + assert.Equal(t, reflect.ValueOf(handlerFunc3).Pointer(), reflect.ValueOf(Handlers[2].Handler).Pointer()) + assert.Equal(t, methods3, Handlers[2].Methods) + }) } diff --git a/internal/server/server.go b/internal/server/server.go index f30b30b..92a32d2 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -21,7 +21,7 @@ var StaticFS embed.FS // Parameters: // - listenAddress: string // The address on which the server should listen (e.g., ":8080"). -func Run(listenAddress string) { +func Run(listenAddress string, token string) { // Use a buffered channel for runChan to prevent signal drops runChan := make(chan os.Signal, 1) signal.Notify(runChan, os.Interrupt, syscall.SIGTERM) @@ -31,7 +31,7 @@ func Run(listenAddress string) { defer cancel() // Create the router and configure routes - router := newRouter() + router := newRouter(token) // Create the HTTP server server := &http.Server{ @@ -72,7 +72,7 @@ func Run(listenAddress string) { // Returns: // - *mux.Router // A configured instance of the Gorilla Mux router. -func newRouter() *mux.Router { +func newRouter(token string) *mux.Router { router := mux.NewRouter() // Handler for embed static files @@ -84,7 +84,12 @@ func newRouter() *mux.Router { // Add the handlers to the router for _, h := range Handlers { + if h.NeedsToken { + router.Handle(h.Path, ApiTokenMiddleware(h.Handler, token)).Methods(h.Methods...) + continue + } router.HandleFunc(h.Path, h.Handler).Methods(h.Methods...) + } // Custom 404 handler diff --git a/internal/server/server_test.go b/internal/server/server_test.go index e8e27b7..181820b 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -17,7 +17,7 @@ func TestRun(t *testing.T) { testPort := "127.0.0.1:0" // Run the server in a goroutine - go Run(testPort) + go Run(testPort, "token") // Allow some time for the server to start time.Sleep(500 * time.Millisecond) @@ -45,36 +45,102 @@ func stopServerGracefully() { } } +func TestNotFoundHandler(t *testing.T) { + t.Run("Custom 404 handler returns a 404 response", func(t *testing.T) { + recorder := httptest.NewRecorder() + request := httptest.NewRequest("GET", "/nonexistent", nil) + notFoundHandler(recorder, request) + + // Check if the response has a 404 status code + assert.Equal(t, http.StatusNotFound, recorder.Code) + + // Check if the response body contains the expected content + expectedContent := "

404 Not Found

" + assert.Contains(t, recorder.Body.String(), expectedContent) + }) +} + func TestNewRouter(t *testing.T) { - t.Run("Router configuration", func(t *testing.T) { - // Set up a router - Handlers = []Handler{ - { - Path: "/example", - Description: "Example handler description", - Handler: func(r http.ResponseWriter, w *http.Request) { - r.Write([]byte("Test file content")) - r.Header().Set("Content-Type", "text/plain") - r.WriteHeader(http.StatusOK) - }, - Methods: []string{"GET"}, - }, - } - - router := newRouter() - - // Create a test request - req := httptest.NewRequest("GET", "/example", nil) - w := httptest.NewRecorder() - - // Serve the request using the router - router.ServeHTTP(w, req) - - // Check if the request is handled correctly - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "Test file content") - - // Cleanup: clear the Handlers slice - Handlers = nil + // Mock handler for testing + mockHandler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + + // Mock handlers for testing + Handlers = []Handler{ + { + Path: "/mock", + Handler: mockHandler, + Methods: []string{"GET"}, + NeedsToken: false, + Description: "Mock handler without token", + }, + { + Path: "/mock-with-token", + Handler: mockHandler, + Methods: []string{"GET"}, + NeedsToken: true, + Description: "Mock handler with token", + }, + } + + // Set up the router + router := newRouter("mock-token") + + // Test handler without token + t.Run("Handler without token", func(t *testing.T) { + request := httptest.NewRequest("GET", "/mock", nil) + responseRecorder := httptest.NewRecorder() + + router.ServeHTTP(responseRecorder, request) + + // Check the status code of the response + assert.Equal(t, http.StatusOK, responseRecorder.Code) + }) + + // Test handler with valid token + t.Run("Handler with valid token", func(t *testing.T) { + request := httptest.NewRequest("GET", "/mock-with-token", nil) + request.Header.Set("X-API-Token", "mock-token") + responseRecorder := httptest.NewRecorder() + + router.ServeHTTP(responseRecorder, request) + + // Check the status code of the response + assert.Equal(t, http.StatusOK, responseRecorder.Code) + }) + + // Test handler with invalid token + t.Run("Handler with invalid token", func(t *testing.T) { + request := httptest.NewRequest("GET", "/mock-with-token", nil) + request.Header.Set("X-API-Token", "invalid-token") + responseRecorder := httptest.NewRecorder() + + router.ServeHTTP(responseRecorder, request) + + // Check the status code of the response + assert.Equal(t, http.StatusUnauthorized, responseRecorder.Code) + }) + + // Test static file serving + t.Run("Static file serving", func(t *testing.T) { + request := httptest.NewRequest("GET", "/static/css/vimbin.css", nil) + responseRecorder := httptest.NewRecorder() + + router.ServeHTTP(responseRecorder, request) + + // Check the status code of the response + assert.Equal(t, http.StatusOK, responseRecorder.Code) + }) + + // Test 404 handler + t.Run("404 handler", func(t *testing.T) { + request := httptest.NewRequest("GET", "/non-existent", nil) + responseRecorder := httptest.NewRecorder() + + router.ServeHTTP(responseRecorder, request) + + // Check the status code of the response + assert.Equal(t, http.StatusNotFound, responseRecorder.Code) }) } diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 3bac373..c5735ca 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -1,7 +1,10 @@ package utils import ( + "crypto/rand" "crypto/tls" + "encoding/base64" + "fmt" "net" "net/http" "strconv" @@ -80,3 +83,37 @@ func CreateHTTPClient(insecureSkipVerify bool) *http.Client { return httpClient } + +// GenerateRandomToken generates a random token of the specified length. +// +// Parameters: +// - length: int +// The desired length of the generated token. +// +// Returns: +// - string +// A random token of the specified length. +// - error +// An error if the random token generation fails. +func GenerateRandomToken(length int) (string, error) { + if length <= 0 { + return "", fmt.Errorf("Invalid token length '%d'. Must be at minimum 1", length) + } + // Calculate the number of bytes needed to create the token + numBytes := (length * 3) / 4 + + // Generate random bytes + randomBytes := make([]byte, numBytes) + _, err := rand.Read(randomBytes) + if err != nil { + return "", err + } + + // Encode random bytes to base64 to create the token + token := base64.URLEncoding.EncodeToString(randomBytes) + + // Trim the padding '=' characters + token = token[:length] + + return token, nil +} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go index fd58f6c..5f41bba 100644 --- a/internal/utils/utils_test.go +++ b/internal/utils/utils_test.go @@ -79,3 +79,33 @@ func TestCreateHTTPClient(t *testing.T) { assert.True(t, client.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify) }) } + +func TestGenerateRandomToken(t *testing.T) { + t.Run("Generate random token with length 16", func(t *testing.T) { + token, err := GenerateRandomToken(16) + + assert.Nil(t, err) + assert.Equal(t, 16, len(token)) + }) + + t.Run("Generate random token with length 32", func(t *testing.T) { + token, err := GenerateRandomToken(32) + + assert.Nil(t, err) + assert.Equal(t, 32, len(token)) + }) + + t.Run("Generate random token with length 64", func(t *testing.T) { + token, err := GenerateRandomToken(64) + + assert.Nil(t, err) + assert.Equal(t, 64, len(token)) + }) + + t.Run("Error on token generation with invalid length", func(t *testing.T) { + _, err := GenerateRandomToken(-1) + + assert.NotNil(t, err) + assert.EqualError(t, err, "Invalid token length '-1'. Must be at minimum 1") + }) +} diff --git a/web/static/js/vimbin.js b/web/static/js/vimbin.js index 9f23f73..51f728a 100644 --- a/web/static/js/vimbin.js +++ b/web/static/js/vimbin.js @@ -78,6 +78,7 @@ document.addEventListener("DOMContentLoaded", function () { method: "POST", headers: { "Content-Type": "application/json", + "X-API-Token": apiToken, }, body: JSON.stringify({ content: editor.getValue() }), }); diff --git a/web/templates/index.html b/web/templates/index.html index 7416133..2d6abd8 100644 --- a/web/templates/index.html +++ b/web/templates/index.html @@ -33,6 +33,9 @@

{{.Title}}

+