diff --git a/shibuya/api/collection.go b/shibuya/api/collection.go index e12ffe7a..3c8ba4e8 100644 --- a/shibuya/api/collection.go +++ b/shibuya/api/collection.go @@ -25,7 +25,7 @@ func getCollection(collectionID string) (*model.Collection, error) { } func (s *ShibuyaAPI) collectionConfigGetHandler(w http.ResponseWriter, req *http.Request, params httprouter.Params) { - collection, err := checkCollectionOwnership(req, params) + collection, err := hasCollectionOwnership(req, params) if err != nil { s.handleErrors(w, err) return diff --git a/shibuya/api/errors.go b/shibuya/api/errors.go index 8771e2a1..b3cbb44f 100644 --- a/shibuya/api/errors.go +++ b/shibuya/api/errors.go @@ -32,3 +32,11 @@ func makeInternalServerError(message string) error { func makeInvalidResourceError(resource string) error { return fmt.Errorf("%winvalid %s", invalidRequestErr, resource) } + +func makeProjectOwnershipError() error { + return fmt.Errorf("%w%s", noPermissionErr, "You don't own the project") +} + +func makeCollectionOwnershipError() error { + return fmt.Errorf("%w%s", noPermissionErr, "You don't own the collection") +} diff --git a/shibuya/api/main.go b/shibuya/api/main.go index 5d102e07..3dfb8d50 100644 --- a/shibuya/api/main.go +++ b/shibuya/api/main.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "strconv" + "strings" "time" "github.com/julienschmidt/httprouter" @@ -89,11 +90,7 @@ func (s *ShibuyaAPI) handleErrors(w http.ResponseWriter, err error) { } func (s *ShibuyaAPI) projectsGetHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - account := model.GetAccountBySession(r) - if account == nil { - s.makeFailMessage(w, "Need to login", http.StatusForbidden) - return - } + account := r.Context().Value(accountKey).(*model.Account) qs := r.URL.Query() var includeCollections, includePlans bool var err error @@ -145,11 +142,7 @@ func (s *ShibuyaAPI) projectUpdateHandler(w http.ResponseWriter, _ *http.Request } func (s *ShibuyaAPI) projectCreateHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - account := model.GetAccountBySession(r) - if account == nil { - s.handleErrors(w, makeLoginError()) - return - } + account := r.Context().Value(accountKey).(*model.Account) r.ParseForm() name := r.Form.Get("name") if name == "" { @@ -191,18 +184,14 @@ func (s *ShibuyaAPI) projectCreateHandler(w http.ResponseWriter, r *http.Request } func (s *ShibuyaAPI) projectDeleteHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - account := model.GetAccountBySession(r) - if account == nil { - s.handleErrors(w, makeLoginError()) - return - } + account := r.Context().Value(accountKey).(*model.Account) project, err := getProject(params.ByName("project_id")) if err != nil { s.handleErrors(w, err) return } - if _, ok := account.MLMap[project.Owner]; !ok { - s.handleErrors(w, noPermissionErr) + if r := hasProjectOwnership(project, account); !r { + s.handleErrors(w, makeProjectOwnershipError()) return } collectionIDs, err := project.GetCollections() @@ -260,11 +249,7 @@ func (s *ShibuyaAPI) collectionAdminGetHandler(w http.ResponseWriter, r *http.Re } func (s *ShibuyaAPI) planCreateHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - account := model.GetAccountBySession(r) - if account == nil { - s.handleErrors(w, makeLoginError()) - return - } + account := r.Context().Value(accountKey).(*model.Account) r.ParseForm() projectID := r.Form.Get("project_id") project, err := getProject(projectID) @@ -272,8 +257,8 @@ func (s *ShibuyaAPI) planCreateHandler(w http.ResponseWriter, r *http.Request, _ s.handleErrors(w, err) return } - if _, ok := account.MLMap[project.Owner]; !ok { - s.handleErrors(w, makeNoPermissionErr("You don't own the project")) + if r := hasProjectOwnership(project, account); !r { + s.handleErrors(w, makeProjectOwnershipError()) return } name := r.Form.Get("name") @@ -294,11 +279,7 @@ func (s *ShibuyaAPI) planCreateHandler(w http.ResponseWriter, r *http.Request, _ } func (s *ShibuyaAPI) planDeleteHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - account := model.GetAccountBySession(r) - if account == nil { - s.handleErrors(w, makeLoginError()) - return - } + account := r.Context().Value(accountKey).(*model.Account) plan, err := getPlan(params.ByName("plan_id")) if err != nil { s.handleErrors(w, err) @@ -309,8 +290,8 @@ func (s *ShibuyaAPI) planDeleteHandler(w http.ResponseWriter, r *http.Request, p s.handleErrors(w, err) return } - if _, ok := account.MLMap[project.Owner]; !ok { - s.handleErrors(w, makeLoginError()) + if r := hasProjectOwnership(project, account); !r { + s.handleErrors(w, makeProjectOwnershipError()) return } using, err := plan.IsBeingUsed() @@ -355,7 +336,7 @@ func (s *ShibuyaAPI) collectionFilesGetHandler(w http.ResponseWriter, _ *http.Re } func (s *ShibuyaAPI) collectionFilesUploadHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - collection, err := checkCollectionOwnership(r, params) + collection, err := hasCollectionOwnership(r, params) if err != nil { s.handleErrors(w, err) return @@ -375,7 +356,7 @@ func (s *ShibuyaAPI) collectionFilesUploadHandler(w http.ResponseWriter, r *http } func (s *ShibuyaAPI) collectionFilesDeleteHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - collection, err := checkCollectionOwnership(r, params) + collection, err := hasCollectionOwnership(r, params) if err != nil { s.handleErrors(w, err) return @@ -415,11 +396,7 @@ func (s *ShibuyaAPI) planFilesDeleteHandler(w http.ResponseWriter, r *http.Reque } func (s *ShibuyaAPI) collectionCreateHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - account := model.GetAccountBySession(r) - if account == nil { - s.handleErrors(w, makeLoginError()) - return - } + account := r.Context().Value(accountKey).(*model.Account) r.ParseForm() collectionName := r.Form.Get("name") if collectionName == "" { @@ -432,8 +409,8 @@ func (s *ShibuyaAPI) collectionCreateHandler(w http.ResponseWriter, r *http.Requ s.handleErrors(w, err) return } - if _, ok := account.MLMap[project.Owner]; !ok { - s.handleErrors(w, makeNoPermissionErr("You don't have the permission")) + if r := hasProjectOwnership(project, account); !r { + s.handleErrors(w, makeProjectOwnershipError()) return } collectionID, err := model.CreateCollection(collectionName, project.ID) @@ -450,7 +427,7 @@ func (s *ShibuyaAPI) collectionCreateHandler(w http.ResponseWriter, r *http.Requ } func (s *ShibuyaAPI) collectionDeleteHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - collection, err := checkCollectionOwnership(r, params) + collection, err := hasCollectionOwnership(r, params) if err != nil { s.handleErrors(w, err) return @@ -480,7 +457,7 @@ func (s *ShibuyaAPI) collectionDeleteHandler(w http.ResponseWriter, r *http.Requ } func (s *ShibuyaAPI) collectionGetHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - collection, err := checkCollectionOwnership(r, params) + collection, err := hasCollectionOwnership(r, params) if err != nil { s.handleErrors(w, err) return @@ -519,7 +496,7 @@ func (s *ShibuyaAPI) collectionUpdateHandler(w http.ResponseWriter, _ *http.Requ } func (s *ShibuyaAPI) collectionUploadHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - collection, err := checkCollectionOwnership(r, params) + collection, err := hasCollectionOwnership(r, params) if err != nil { s.handleErrors(w, err) return @@ -613,7 +590,7 @@ func (s *ShibuyaAPI) collectionUploadHandler(w http.ResponseWriter, r *http.Requ } func (s *ShibuyaAPI) collectionEnginesDetailHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - collection, err := checkCollectionOwnership(r, params) + collection, err := hasCollectionOwnership(r, params) if err != nil { s.handleErrors(w, err) return @@ -627,7 +604,7 @@ func (s *ShibuyaAPI) collectionEnginesDetailHandler(w http.ResponseWriter, r *ht } func (s *ShibuyaAPI) collectionDeploymentHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - collection, err := checkCollectionOwnership(r, params) + collection, err := hasCollectionOwnership(r, params) if err != nil { s.handleErrors(w, err) return @@ -644,7 +621,7 @@ func (s *ShibuyaAPI) collectionDeploymentHandler(w http.ResponseWriter, r *http. } func (s *ShibuyaAPI) collectionTriggerHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - collection, err := checkCollectionOwnership(r, params) + collection, err := hasCollectionOwnership(r, params) if err != nil { s.handleErrors(w, err) return @@ -656,7 +633,7 @@ func (s *ShibuyaAPI) collectionTriggerHandler(w http.ResponseWriter, r *http.Req } func (s *ShibuyaAPI) collectionTermHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - collection, err := checkCollectionOwnership(r, params) + collection, err := hasCollectionOwnership(r, params) if err != nil { s.handleErrors(w, err) return @@ -668,7 +645,7 @@ func (s *ShibuyaAPI) collectionTermHandler(w http.ResponseWriter, r *http.Reques } func (s *ShibuyaAPI) collectionStatusHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - collection, err := checkCollectionOwnership(r, params) + collection, err := hasCollectionOwnership(r, params) if err != nil { s.handleErrors(w, err) return @@ -681,7 +658,7 @@ func (s *ShibuyaAPI) collectionStatusHandler(w http.ResponseWriter, r *http.Requ } func (s *ShibuyaAPI) collectionPurgeHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - collection, err := checkCollectionOwnership(r, params) + collection, err := hasCollectionOwnership(r, params) if err != nil { s.handleErrors(w, err) return @@ -714,7 +691,7 @@ func (s *ShibuyaAPI) planLogHandler(w http.ResponseWriter, r *http.Request, para } func (s *ShibuyaAPI) streamCollectionMetrics(w http.ResponseWriter, r *http.Request, params httprouter.Params) { - collection, err := checkCollectionOwnership(r, params) + collection, err := hasCollectionOwnership(r, params) if err != nil { s.handleErrors(w, err) return @@ -789,7 +766,7 @@ type Route struct { type Routes []*Route func (s *ShibuyaAPI) InitRoutes() Routes { - return Routes{ + routes := Routes{ &Route{"get_projects", "GET", "/api/projects", s.projectsGetHandler}, &Route{"create_project", "POST", "/api/projects", s.projectCreateHandler}, &Route{"delete_project", "DELETE", "/api/projects/:project_id", s.projectDeleteHandler}, @@ -833,4 +810,12 @@ func (s *ShibuyaAPI) InitRoutes() Routes { &Route{"admin_collections", "GET", "/api/admin/collections", s.collectionAdminGetHandler}, } + for _, r := range routes { + // TODO! We don't require auth for usage endpoint for now. + if strings.Contains(r.Path, "usage") { + continue + } + r.HandlerFunc = s.authRequired(r.HandlerFunc) + } + return routes } diff --git a/shibuya/api/middlewares.go b/shibuya/api/middlewares.go new file mode 100644 index 00000000..32c971a5 --- /dev/null +++ b/shibuya/api/middlewares.go @@ -0,0 +1,40 @@ +package api + +import ( + "context" + "errors" + "net/http" + + "github.com/julienschmidt/httprouter" + "github.com/rakutentech/shibuya/shibuya/model" +) + +const ( + accountKey = "account" +) + +func authWithSession(r *http.Request) (*model.Account, error) { + account := model.GetAccountBySession(r) + if account == nil { + return nil, makeLoginError() + } + return account, nil +} + +// TODO add JWT token auth in the future +func authWithToken(_ *http.Request) (*model.Account, error) { + return nil, errors.New("No token presented") +} + +func (s *ShibuyaAPI) authRequired(next httprouter.Handle) httprouter.Handle { + return httprouter.Handle(func(w http.ResponseWriter, r *http.Request, params httprouter.Params) { + var account *model.Account + var err error + account, err = authWithSession(r) + if err != nil { + s.handleErrors(w, err) + return + } + next(w, r.WithContext(context.WithValue(r.Context(), accountKey, account)), params) + }) +} diff --git a/shibuya/api/networkutils.go b/shibuya/api/networkutils.go new file mode 100644 index 00000000..bfe56440 --- /dev/null +++ b/shibuya/api/networkutils.go @@ -0,0 +1,14 @@ +package api + +import ( + "net/http" + "strings" +) + +func retrieveClientIP(r *http.Request) string { + t := r.Header.Get("x-forwarded-for") + if t == "" { + return r.RemoteAddr + } + return strings.Split(t, ",")[0] +} diff --git a/shibuya/api/ownership.go b/shibuya/api/ownership.go new file mode 100644 index 00000000..f5e1152d --- /dev/null +++ b/shibuya/api/ownership.go @@ -0,0 +1,33 @@ +package api + +import ( + "net/http" + + "github.com/julienschmidt/httprouter" + "github.com/rakutentech/shibuya/shibuya/model" +) + +func hasProjectOwnership(project *model.Project, account *model.Account) bool { + if _, ok := account.MLMap[project.Owner]; !ok { + if !account.IsAdmin() { + return false + } + } + return true +} + +func hasCollectionOwnership(r *http.Request, params httprouter.Params) (*model.Collection, error) { + collection, err := getCollection(params.ByName("collection_id")) + if err != nil { + return nil, err + } + account := r.Context().Value(accountKey).(*model.Account) + project, err := model.GetProject(collection.ProjectID) + if err != nil { + return nil, err + } + if r := hasProjectOwnership(project, account); !r { + return nil, makeCollectionOwnershipError() + } + return collection, nil +} diff --git a/shibuya/api/utils.go b/shibuya/api/utils.go deleted file mode 100644 index d8b463c2..00000000 --- a/shibuya/api/utils.go +++ /dev/null @@ -1,38 +0,0 @@ -package api - -import ( - "net/http" - "strings" - - "github.com/julienschmidt/httprouter" - "github.com/rakutentech/shibuya/shibuya/model" -) - -func retrieveClientIP(r *http.Request) string { - t := r.Header.Get("x-forwarded-for") - if t == "" { - return r.RemoteAddr - } - return strings.Split(t, ",")[0] -} - -func checkCollectionOwnership(r *http.Request, params httprouter.Params) (*model.Collection, error) { - account := model.GetAccountBySession(r) - if account == nil { - return nil, makeLoginError() - } - collection, err := getCollection(params.ByName("collection_id")) - if err != nil { - return nil, err - } - project, err := model.GetProject(collection.ProjectID) - if err != nil { - return nil, err - } - if _, ok := account.MLMap[project.Owner]; !ok { - if !account.IsAdmin() { - return nil, makeNoPermissionErr("You are not the owner of the collection") - } - } - return collection, nil -}