diff --git a/auth/cookie.go b/auth/cookie.go index a8093e483..f93679470 100644 --- a/auth/cookie.go +++ b/auth/cookie.go @@ -9,10 +9,11 @@ import ( "net/url" "time" - "github.com/flyteorg/flyteadmin/auth/interfaces" "github.com/flyteorg/flytestdlib/errors" "github.com/flyteorg/flytestdlib/logger" "github.com/gorilla/securecookie" + + "github.com/flyteorg/flyteadmin/auth/interfaces" ) const ( @@ -52,25 +53,18 @@ func HashCsrfState(csrf string) string { } func NewSecureCookie(cookieName, value string, hashKey, blockKey []byte, domain string, sameSiteMode http.SameSite) (http.Cookie, error) { - var s = securecookie.New(hashKey, blockKey) + s := securecookie.New(hashKey, blockKey) encoded, err := s.Encode(cookieName, value) - if err == nil { - if len(domain) > 0 { - return http.Cookie{ - Name: cookieName, - Value: encoded, - Domain: domain, - SameSite: sameSiteMode, - }, nil - } - return http.Cookie{ - Name: cookieName, - Value: encoded, - SameSite: sameSiteMode, - }, nil - } - - return http.Cookie{}, errors.Wrapf(ErrSecureCookie, err, "Error creating secure cookie") + if err != nil { + return http.Cookie{}, errors.Wrapf(ErrSecureCookie, err, "Error creating secure cookie") + } + + return http.Cookie{ + Name: cookieName, + Value: encoded, + Domain: domain, + SameSite: sameSiteMode, + }, nil } func retrieveSecureCookie(ctx context.Context, request *http.Request, cookieName string, hashKey, blockKey []byte) (string, error) { diff --git a/auth/cookie_manager.go b/auth/cookie_manager.go index a32c773d1..8b6eca36b 100644 --- a/auth/cookie_manager.go +++ b/auth/cookie_manager.go @@ -8,12 +8,12 @@ import ( "net/http" "time" - "github.com/flyteorg/flyteadmin/auth/config" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flytestdlib/errors" "github.com/flyteorg/flytestdlib/logger" - "golang.org/x/oauth2" + + "github.com/flyteorg/flyteadmin/auth/config" ) type CookieManager struct { @@ -175,29 +175,31 @@ func (c CookieManager) SetTokenCookies(ctx context.Context, writer http.Response return nil } -func getLogoutAccessCookie() *http.Cookie { +func (c *CookieManager) getLogoutAccessCookie() *http.Cookie { return &http.Cookie{ Name: accessTokenCookieName, Value: "", + Domain: c.domain, MaxAge: 0, HttpOnly: true, Expires: time.Now().Add(-1 * time.Hour), } } -func getLogoutRefreshCookie() *http.Cookie { +func (c *CookieManager) getLogoutRefreshCookie() *http.Cookie { return &http.Cookie{ Name: refreshTokenCookieName, Value: "", + Domain: c.domain, MaxAge: 0, HttpOnly: true, Expires: time.Now().Add(-1 * time.Hour), } } -func (c CookieManager) DeleteCookies(ctx context.Context, writer http.ResponseWriter) { - http.SetCookie(writer, getLogoutAccessCookie()) - http.SetCookie(writer, getLogoutRefreshCookie()) +func (c CookieManager) DeleteCookies(_ context.Context, writer http.ResponseWriter) { + http.SetCookie(writer, c.getLogoutAccessCookie()) + http.SetCookie(writer, c.getLogoutRefreshCookie()) } func (c CookieManager) getHTTPSameSitePolicy() http.SameSite { diff --git a/auth/cookie_manager_test.go b/auth/cookie_manager_test.go index ce6b3b827..5bb11f5c7 100644 --- a/auth/cookie_manager_test.go +++ b/auth/cookie_manager_test.go @@ -1,17 +1,20 @@ package auth import ( + "bytes" "context" + "encoding/base64" "fmt" "net/http" "net/http/httptest" "testing" "time" - "github.com/flyteorg/flyteadmin/auth/config" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/oauth2" + + "github.com/flyteorg/flyteadmin/auth/config" ) func TestCookieManager_SetTokenCookies(t *testing.T) { @@ -25,124 +28,128 @@ func TestCookieManager_SetTokenCookies(t *testing.T) { } manager, err := NewCookieManager(ctx, hashKeyEncoded, blockKeyEncoded, cookieSetting) assert.NoError(t, err) - token := &oauth2.Token{ AccessToken: "access", RefreshToken: "refresh", } - token = token.WithExtra(map[string]interface{}{ "id_token": "id token", }) - w := httptest.NewRecorder() - _, err = http.NewRequest("GET", "/api/v1/projects", nil) - assert.NoError(t, err) - err = manager.SetTokenCookies(ctx, w, token) - assert.NoError(t, err) - fmt.Println(w.Header().Get("Set-Cookie")) - c := w.Result().Cookies() - assert.Equal(t, "flyte_at", c[0].Name) - assert.Equal(t, "flyte_idt", c[1].Name) - assert.Equal(t, "flyte_rt", c[2].Name) -} + t.Run("invalid_hash_key", func(t *testing.T) { + _, err := NewCookieManager(ctx, "wrong", blockKeyEncoded, cookieSetting) -func TestCookieManager_RetrieveTokenValues(t *testing.T) { - ctx := context.Background() - // These were generated for unit testing only. - hashKeyEncoded := "wG4pE1ccdw/pHZ2ml8wrD5VJkOtLPmBpWbKHmezWXktGaFbRoAhXidWs8OpbA3y7N8vyZhz1B1E37+tShWC7gA" //nolint:goconst - blockKeyEncoded := "afyABVgGOvWJFxVyOvCWCupoTn6BkNl4SOHmahho16Q" //nolint:goconst + assert.EqualError(t, err, "[BINARY_DECODING_FAILED] Error decoding hash key bytes, caused by: illegal base64 data at input byte 4") + }) - cookieSetting := config.CookieSettings{ - SameSitePolicy: config.SameSiteDefaultMode, - Domain: "default", - } + t.Run("invalid_block_key", func(t *testing.T) { + _, err := NewCookieManager(ctx, hashKeyEncoded, "wrong", cookieSetting) - manager, err := NewCookieManager(ctx, hashKeyEncoded, blockKeyEncoded, cookieSetting) - assert.NoError(t, err) + assert.EqualError(t, err, "[BINARY_DECODING_FAILED] Error decoding block key bytes, caused by: illegal base64 data at input byte 4") + }) - token := &oauth2.Token{ - AccessToken: "access", - RefreshToken: "refresh", - } + t.Run("set_token_cookies", func(t *testing.T) { + w := httptest.NewRecorder() - token = token.WithExtra(map[string]interface{}{ - "id_token": "id token", + err = manager.SetTokenCookies(ctx, w, token) + + assert.NoError(t, err) + fmt.Println(w.Header().Get("Set-Cookie")) + c := w.Result().Cookies() + assert.Equal(t, "flyte_at", c[0].Name) + assert.Equal(t, "flyte_idt", c[1].Name) + assert.Equal(t, "flyte_rt", c[2].Name) }) - w := httptest.NewRecorder() - _, err = http.NewRequest("GET", "/api/v1/projects", nil) - assert.NoError(t, err) - err = manager.SetTokenCookies(ctx, w, token) - assert.NoError(t, err) + t.Run("set_token_cookies_wrong_key", func(t *testing.T) { + wrongKey := base64.RawStdEncoding.EncodeToString(bytes.Repeat([]byte("X"), 75)) + wrongManager, err := NewCookieManager(ctx, wrongKey, wrongKey, cookieSetting) + require.NoError(t, err) + w := httptest.NewRecorder() - cookies := w.Result().Cookies() - req, err := http.NewRequest("GET", "/api/v1/projects", nil) - assert.NoError(t, err) - for _, c := range cookies { - req.AddCookie(c) - } + err = wrongManager.SetTokenCookies(ctx, w, token) - idToken, access, refresh, err := manager.RetrieveTokenValues(ctx, req) - assert.NoError(t, err) - assert.Equal(t, "id token", idToken) - assert.Equal(t, "access", access) - assert.Equal(t, "refresh", refresh) -} + assert.EqualError(t, err, "[SECURE_COOKIE_ERROR] Error creating secure cookie, caused by: securecookie: error - caused by: crypto/aes: invalid key size 75") + }) -func TestGetLogoutAccessCookie(t *testing.T) { - cookie := getLogoutAccessCookie() - assert.True(t, time.Now().After(cookie.Expires)) -} + t.Run("retrieve_token_values", func(t *testing.T) { + w := httptest.NewRecorder() -func TestGetLogoutRefreshCookie(t *testing.T) { - cookie := getLogoutRefreshCookie() - assert.True(t, time.Now().After(cookie.Expires)) -} + err = manager.SetTokenCookies(ctx, w, token) + assert.NoError(t, err) -func TestCookieManager_DeleteCookies(t *testing.T) { - ctx := context.Background() + cookies := w.Result().Cookies() + req, err := http.NewRequest("GET", "/api/v1/projects", nil) + assert.NoError(t, err) + for _, c := range cookies { + req.AddCookie(c) + } - // These were generated for unit testing only. - hashKeyEncoded := "wG4pE1ccdw/pHZ2ml8wrD5VJkOtLPmBpWbKHmezWXktGaFbRoAhXidWs8OpbA3y7N8vyZhz1B1E37+tShWC7gA" //nolint:goconst - blockKeyEncoded := "afyABVgGOvWJFxVyOvCWCupoTn6BkNl4SOHmahho16Q" //nolint:goconst - cookieSetting := config.CookieSettings{ - SameSitePolicy: config.SameSiteDefaultMode, - Domain: "default", - } + idToken, access, refresh, err := manager.RetrieveTokenValues(ctx, req) - manager, err := NewCookieManager(ctx, hashKeyEncoded, blockKeyEncoded, cookieSetting) - assert.NoError(t, err) + assert.NoError(t, err) + assert.Equal(t, "id token", idToken) + assert.Equal(t, "access", access) + assert.Equal(t, "refresh", refresh) + }) - w := httptest.NewRecorder() - manager.DeleteCookies(ctx, w) - cookies := w.Result().Cookies() - assert.Equal(t, 2, len(cookies)) - assert.True(t, time.Now().After(cookies[0].Expires)) - assert.True(t, time.Now().After(cookies[1].Expires)) -} + t.Run("retrieve_token_values_wrong_key", func(t *testing.T) { + wrongKey := base64.RawStdEncoding.EncodeToString(bytes.Repeat([]byte("X"), 75)) + wrongManager, err := NewCookieManager(ctx, wrongKey, wrongKey, cookieSetting) + require.NoError(t, err) -func TestGetHTTPSameSitePolicy(t *testing.T) { - ctx := context.Background() + w := httptest.NewRecorder() - // These were generated for unit testing only. - hashKeyEncoded := "wG4pE1ccdw/pHZ2ml8wrD5VJkOtLPmBpWbKHmezWXktGaFbRoAhXidWs8OpbA3y7N8vyZhz1B1E37+tShWC7gA" //nolint:goconst - blockKeyEncoded := "afyABVgGOvWJFxVyOvCWCupoTn6BkNl4SOHmahho16Q" //nolint:goconst - cookieSetting := config.CookieSettings{ - SameSitePolicy: config.SameSiteDefaultMode, - Domain: "default", - } + err = manager.SetTokenCookies(ctx, w, token) + assert.NoError(t, err) - manager, err := NewCookieManager(ctx, hashKeyEncoded, blockKeyEncoded, cookieSetting) - assert.NoError(t, err) - assert.Equal(t, http.SameSiteDefaultMode, manager.getHTTPSameSitePolicy()) + cookies := w.Result().Cookies() + req, err := http.NewRequest("GET", "/api/v1/projects", nil) + assert.NoError(t, err) + for _, c := range cookies { + req.AddCookie(c) + } + + _, _, _, err = wrongManager.RetrieveTokenValues(ctx, req) + + assert.EqualError(t, err, "[EMPTY_OAUTH_TOKEN] Error reading existing secure cookie [flyte_idt]. Error: [SECURE_COOKIE_ERROR] Error reading secure cookie flyte_idt, caused by: securecookie: error - caused by: crypto/aes: invalid key size 75") + }) + + t.Run("logout_access_cookie", func(t *testing.T) { + cookie := manager.getLogoutAccessCookie() + + assert.True(t, time.Now().After(cookie.Expires)) + assert.Equal(t, cookieSetting.Domain, cookie.Domain) + }) + + t.Run("logout_refresh_cookie", func(t *testing.T) { + cookie := manager.getLogoutRefreshCookie() - manager.sameSitePolicy = config.SameSiteLaxMode - assert.Equal(t, http.SameSiteLaxMode, manager.getHTTPSameSitePolicy()) + assert.True(t, time.Now().After(cookie.Expires)) + assert.Equal(t, cookieSetting.Domain, cookie.Domain) + }) - manager.sameSitePolicy = config.SameSiteStrictMode - assert.Equal(t, http.SameSiteStrictMode, manager.getHTTPSameSitePolicy()) + t.Run("delete_cookies", func(t *testing.T) { + w := httptest.NewRecorder() - manager.sameSitePolicy = config.SameSiteNoneMode - assert.Equal(t, http.SameSiteNoneMode, manager.getHTTPSameSitePolicy()) + manager.DeleteCookies(ctx, w) + + cookies := w.Result().Cookies() + require.Equal(t, 2, len(cookies)) + assert.True(t, time.Now().After(cookies[0].Expires)) + assert.Equal(t, cookieSetting.Domain, cookies[0].Domain) + assert.True(t, time.Now().After(cookies[1].Expires)) + assert.Equal(t, cookieSetting.Domain, cookies[1].Domain) + }) + + t.Run("get_http_same_site_policy", func(t *testing.T) { + manager.sameSitePolicy = config.SameSiteLaxMode + assert.Equal(t, http.SameSiteLaxMode, manager.getHTTPSameSitePolicy()) + + manager.sameSitePolicy = config.SameSiteStrictMode + assert.Equal(t, http.SameSiteStrictMode, manager.getHTTPSameSitePolicy()) + + manager.sameSitePolicy = config.SameSiteNoneMode + assert.Equal(t, http.SameSiteNoneMode, manager.getHTTPSameSitePolicy()) + }) } diff --git a/auth/handlers.go b/auth/handlers.go index 9604d90ec..0ee2cc776 100644 --- a/auth/handlers.go +++ b/auth/handlers.go @@ -8,6 +8,9 @@ import ( "strings" "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flytestdlib/errors" + "github.com/flyteorg/flytestdlib/logger" "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" "golang.org/x/oauth2" "google.golang.org/grpc" @@ -20,9 +23,6 @@ import ( "github.com/flyteorg/flyteadmin/auth/interfaces" "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/plugins" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" - "github.com/flyteorg/flytestdlib/errors" - "github.com/flyteorg/flytestdlib/logger" ) const (