From 7df07384eb6a5b8ef134212483bd576008dd7afe Mon Sep 17 00:00:00 2001 From: Liam Beckman Date: Thu, 25 Jan 2024 13:28:03 -0800 Subject: [PATCH] Update auth handling to match TES Compliance Suite --- server/auth.go | 4 ++-- server/server.go | 26 +++++++++++++++++++++----- tests/core/basic_test.go | 2 +- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/server/auth.go b/server/auth.go index c52cf5844..7750ac096 100644 --- a/server/auth.go +++ b/server/auth.go @@ -50,12 +50,12 @@ func authorize(ctx context.Context, user, password string) error { if requser == user && reqpass == password { return nil } - return status.Errorf(codes.PermissionDenied, "") + return status.Errorf(codes.PermissionDenied, "AUTH DENIED") } } } - return status.Errorf(codes.Unauthenticated, "") + return status.Errorf(codes.Unauthenticated, "UNAUTHENTICATED") } // parseBasicAuth parses an HTTP Basic Authentication string. diff --git a/server/server.go b/server/server.go index aa64c6f6d..f598db761 100644 --- a/server/server.go +++ b/server/server.go @@ -6,6 +6,7 @@ import ( "net/http" "strings" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" @@ -58,7 +59,7 @@ func newDebugInterceptor(log *logger.Logger) grpc.UnaryServerInterceptor { // customErrorHandler is a custom error handler for the gRPC gateway // Returns '400' for invalid backend parameters and '500' for all other errors -// Required for Compliance Tests +// Required for TES Compliance Tests func customErrorHandler(ctx context.Context, mux *runtime.ServeMux, marshaler runtime.Marshaler, w http.ResponseWriter, r *http.Request, err error) { const fallback = `{"error": "failed to process the request"}` @@ -70,10 +71,25 @@ func customErrorHandler(ctx context.Context, mux *runtime.ServeMux, marshaler ru } // Map specific gRPC error codes to HTTP status codes - if (strings.Contains(st.Message(), "backend parameters not supported")) { - w.WriteHeader(http.StatusBadRequest) - } else { - w.WriteHeader(http.StatusInternalServerError) + switch st.Code() { + case codes.NotFound: + // Special case for missing tasks (TES Compliance Suite) + if (strings.Contains(st.Message(), "task not found")) { + w.WriteHeader(http.StatusInternalServerError) // 500 + } else { + w.WriteHeader(http.StatusNotFound) // 404 + } + case codes.PermissionDenied: + w.WriteHeader(http.StatusForbidden) // 403 + case codes.Unauthenticated: + w.WriteHeader(http.StatusUnauthorized) // 401 + default: + // Special case for missing backend parameters (TODO: send error codes from backends?) + if (strings.Contains(st.Message(), "backend parameters not supported")) { + w.WriteHeader(http.StatusBadRequest) // 400 + } else { + w.WriteHeader(http.StatusInternalServerError) // 500 + } } // Write the error message diff --git a/tests/core/basic_test.go b/tests/core/basic_test.go index 261d7e24d..bbcca85f0 100644 --- a/tests/core/basic_test.go +++ b/tests/core/basic_test.go @@ -39,7 +39,7 @@ func TestGetUnknownTask(t *testing.T) { Id: "nonexistent-task-id", View: tes.View_MINIMAL.String(), }) - if err == nil || !strings.Contains(err.Error(), "STATUS CODE - 404") { + if err == nil || !strings.Contains(err.Error(), "STATUS CODE - 500") { t.Error("expected not found error", err) }