Skip to content

Commit

Permalink
Upgrade to gorm v2 (#413)
Browse files Browse the repository at this point in the history
<!--  Thanks for sending a pull request!  Here are some tips for you:

1. Run unit tests and ensure that they are passing
2. If your change introduces any API changes, make sure to update the
e2e tests
3. Make sure documentation is updated for your PR!

-->

**What this PR does / why we need it**:
<!-- Explain here the context and why you're making the change. What is
the problem you're trying to solve. --->

Upgrading the dependency on Gorm from `github.com/jinzhu/gorm v1.9.11`
to [Gorm v2](https://gorm.io/) @ `v1.24.0`. Some minor refactoring of
the DB connection layer has been done.

Notable changes:
* Function `gorm.IsRecordNotFoundError` is removed. Checks replaced by:
`errors.Is(err, gorm.ErrRecordNotFound)`
* Gorm's tag `association_foreignkey` is changed to `references`.
* Count() query works with `int64` now, as opposed to `int`.
* `it/database/` moved to -> `database/`. The functions in the new
package are largely similar to
[turing](https://github.com/caraml-dev/turing/tree/main/api/turing/database).
* `api/cmd/api/main.go`, `api/cmd/api/setup.go` - Refactored the running
of DB migrations for main DB and test DB into the same function
`migrateDB` in the `database/` package.
* `github.com/pilagod/gorm-cursor-paginator` library dependency has been
upgraded to match the Gorm version. There are some minor changes in the
API contract and behavior.
* Refactored some of the transaction rollback logic as
`RollbackUnlessCommitted()` is removed in Gorm v2.
* `Select("table_name.*")`, to select all columns of a table, no longer
works. However, the columns selected are automatically filtered
according to the type of the variables being selected into, so I removed
such `Select` calls.

**Which issue(s) this PR fixes**:
<!--
*Automatically closes linked issue when PR is merged.
Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`.
-->

None

**Does this PR introduce a user-facing change?**:
<!--
If no, just write "NONE" in the release-note block below.
If yes, a release note is required. Enter your extended release note in
the block below.
If the PR requires additional action from users switching to the new
release, include the string "action required".

For more information about release notes, see kubernetes' guide here:
http://git.k8s.io/community/contributors/guide/release-notes.md
-->

```release-note
NONE
```

**Checklist**

- [ ] Added unit test, integration, and/or e2e tests
- [x] Tested locally
- [ ] Updated documentation
- [ ] Update Swagger spec if the PR introduce API changes
- [ ] Regenerated Golang and Python client if the PR introduce API
changes
  • Loading branch information
krithika369 authored Jul 20, 2023
1 parent 2a70a79 commit f89860f
Show file tree
Hide file tree
Showing 55 changed files with 529 additions and 389 deletions.
15 changes: 8 additions & 7 deletions api/api/model_endpoint_alerts_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
package api

import (
"errors"
"fmt"
"net/http"

"github.com/jinzhu/gorm"
"gorm.io/gorm"

"github.com/caraml-dev/merlin/models"
)
Expand All @@ -44,7 +45,7 @@ func (c *AlertsController) ListModelEndpointAlerts(r *http.Request, vars map[str

modelEndpointAlerts, err := c.ModelEndpointAlertService.ListModelAlerts(modelID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model endpoint alert not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error listing alerts for model: %v", err))
Expand All @@ -60,7 +61,7 @@ func (c *AlertsController) GetModelEndpointAlert(r *http.Request, vars map[strin

modelEndpointAlert, err := c.ModelEndpointAlertService.GetModelEndpointAlert(modelID, modelEndpointID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model endpoint alert not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting alert for model endpoint: %v", err))
Expand All @@ -84,7 +85,7 @@ func (c *AlertsController) CreateModelEndpointAlert(r *http.Request, vars map[st

model, err := c.ModelsService.FindByID(ctx, modelID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model: %v", err))
Expand All @@ -94,7 +95,7 @@ func (c *AlertsController) CreateModelEndpointAlert(r *http.Request, vars map[st

modelEndpoint, err := c.ModelEndpointsService.FindByID(ctx, modelEndpointID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model endpoint not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model endpoint: %v", err))
Expand Down Expand Up @@ -124,15 +125,15 @@ func (c *AlertsController) UpdateModelEndpointAlert(r *http.Request, vars map[st

model, err := c.ModelsService.FindByID(ctx, modelID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model: %v", err))
}

oldAlert, err := c.ModelEndpointAlertService.GetModelEndpointAlert(modelID, modelEndpointID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model endpoint alert not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting alert for model endpoint: %v", err))
Expand Down
2 changes: 1 addition & 1 deletion api/api/model_endpoint_alerts_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (

"github.com/caraml-dev/merlin/models"
"github.com/caraml-dev/merlin/service/mocks"
"github.com/jinzhu/gorm"
"gorm.io/gorm"
)

func TestListTeams(t *testing.T) {
Expand Down
23 changes: 12 additions & 11 deletions api/api/model_endpoints_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
package api

import (
"errors"
"fmt"
"net/http"

"github.com/caraml-dev/merlin/models"
"github.com/jinzhu/gorm"
"gorm.io/gorm"
)

// ModelEndpointsController controls model endpoints API
Expand All @@ -36,7 +37,7 @@ func (c *ModelEndpointsController) ListModelEndpointInProject(r *http.Request, v

modelEndpoints, err := c.ModelEndpointsService.ListModelEndpointsInProject(ctx, projectID, region)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model endpoints not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error listing model endpoints: %v", err))
Expand All @@ -52,7 +53,7 @@ func (c *ModelEndpointsController) ListModelEndpoints(r *http.Request, vars map[
modelID, _ := models.ParseID(vars["model_id"])
modelEndpoints, err := c.ModelEndpointsService.ListModelEndpoints(ctx, modelID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model endpoints not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error listing model endpoints: %v", err))
Expand All @@ -67,7 +68,7 @@ func (c *ModelEndpointsController) GetModelEndpoint(r *http.Request, vars map[st
modelEndpointID, _ := models.ParseID(vars["model_endpoint_id"])
modelEndpoint, err := c.ModelEndpointsService.FindByID(ctx, modelEndpointID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model endpoint not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model endpoint: %v", err))
Expand All @@ -86,7 +87,7 @@ func (c *ModelEndpointsController) CreateModelEndpoint(r *http.Request, vars map
modelID, _ := models.ParseID(vars["model_id"])
model, err := c.ModelsService.FindByID(ctx, modelID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model: %v", err))
Expand All @@ -109,7 +110,7 @@ func (c *ModelEndpointsController) CreateModelEndpoint(r *http.Request, vars map
// Check environment exists
env, err = c.AppContext.EnvironmentService.GetEnvironment(endpoint.EnvironmentName)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Environment not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting environment: %v", err))
Expand All @@ -135,7 +136,7 @@ func (c *ModelEndpointsController) UpdateModelEndpoint(r *http.Request, vars map
modelEndpointID, _ := models.ParseID(vars["model_endpoint_id"])
model, err := c.ModelsService.FindByID(ctx, modelID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model: %v", err))
Expand All @@ -159,7 +160,7 @@ func (c *ModelEndpointsController) UpdateModelEndpoint(r *http.Request, vars map
// Check environment exists
env, err = c.AppContext.EnvironmentService.GetEnvironment(newEndpoint.EnvironmentName)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Environment not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting environment: %v", err))
Expand All @@ -170,7 +171,7 @@ func (c *ModelEndpointsController) UpdateModelEndpoint(r *http.Request, vars map

currentEndpoint, err := c.ModelEndpointsService.FindByID(ctx, modelEndpointID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model endpoint not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model endpoint: %v", err))
Expand Down Expand Up @@ -207,15 +208,15 @@ func (c *ModelEndpointsController) DeleteModelEndpoint(r *http.Request, vars map

model, err := c.ModelsService.FindByID(ctx, modelID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model: %v", err))
}

modelEndpoint, err := c.ModelEndpointsService.FindByID(ctx, modelEndpointID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model endpoint not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model endpoint: %v", err))
Expand Down
2 changes: 1 addition & 1 deletion api/api/model_endpoints_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import (
"github.com/caraml-dev/merlin/mlp"
"github.com/caraml-dev/merlin/models"
"github.com/caraml-dev/merlin/service/mocks"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/mock"
"gorm.io/gorm"
)

func TestListModelEndpointInProject(t *testing.T) {
Expand Down
7 changes: 4 additions & 3 deletions api/api/models_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
package api

import (
"errors"
"fmt"
"net/http"
"strconv"

"github.com/jinzhu/gorm"
"gorm.io/gorm"

"github.com/caraml-dev/merlin/mlflow"
"github.com/caraml-dev/merlin/models"
Expand Down Expand Up @@ -94,7 +95,7 @@ func (c *ModelsController) GetModel(r *http.Request, vars map[string]string, bod

model, err := c.ModelsService.FindByID(ctx, modelID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model: %v", err))
Expand Down Expand Up @@ -123,7 +124,7 @@ func (c *ModelsController) DeleteModel(r *http.Request, vars map[string]string,

model, err := c.ModelsService.FindByID(ctx, modelID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model is not found: %s", err.Error()))
}
return InternalServerError(fmt.Sprintf("Error getting model: %v", err.Error()))
Expand Down
2 changes: 1 addition & 1 deletion api/api/models_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import (
mlflowDeleteServiceMocks "github.com/caraml-dev/mlp/api/pkg/client/mlflow/mocks"
"github.com/google/uuid"

"github.com/jinzhu/gorm"
"github.com/stretchr/testify/mock"
"gorm.io/gorm"

"github.com/caraml-dev/mlp/api/client"

Expand Down
17 changes: 9 additions & 8 deletions api/api/prediction_job_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
package api

import (
"errors"
"fmt"
"net/http"

"github.com/jinzhu/gorm"
"gorm.io/gorm"

"github.com/caraml-dev/merlin/models"
"github.com/caraml-dev/merlin/service"
Expand All @@ -38,7 +39,7 @@ func (c *PredictionJobController) Create(r *http.Request, vars map[string]string

model, version, err := c.getModelAndVersion(ctx, modelID, versionID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model / version not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model / version: %v", err))
Expand Down Expand Up @@ -71,7 +72,7 @@ func (c *PredictionJobController) List(r *http.Request, vars map[string]string,

model, _, err := c.getModelAndVersion(ctx, modelID, versionID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model / version not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model / version: %v", err))
Expand Down Expand Up @@ -100,7 +101,7 @@ func (c *PredictionJobController) Get(r *http.Request, vars map[string]string, _

model, version, err := c.getModelAndVersion(ctx, modelID, versionID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model / version not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model / version: %v", err))
Expand All @@ -113,7 +114,7 @@ func (c *PredictionJobController) Get(r *http.Request, vars map[string]string, _

job, err := c.PredictionJobService.GetPredictionJob(ctx, env, model, version, id)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Prediction job not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting prediction job: %v", err))
Expand All @@ -132,7 +133,7 @@ func (c *PredictionJobController) Stop(r *http.Request, vars map[string]string,

model, version, err := c.getModelAndVersion(ctx, modelID, versionID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model / version not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model / version: %v", err))
Expand Down Expand Up @@ -161,7 +162,7 @@ func (c *PredictionJobController) ListContainers(r *http.Request, vars map[strin

model, version, err := c.getModelAndVersion(ctx, modelID, versionID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Model / version not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting model / version: %v", err))
Expand All @@ -174,7 +175,7 @@ func (c *PredictionJobController) ListContainers(r *http.Request, vars map[strin

job, err := c.PredictionJobService.GetPredictionJob(ctx, env, model, version, id)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return NotFound(fmt.Sprintf("Prediction job not found: %v", err))
}
return InternalServerError(fmt.Sprintf("Error getting prediction job: %v", err))
Expand Down
2 changes: 1 addition & 1 deletion api/api/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ import (
"github.com/feast-dev/feast/sdk/go/protos/feast/core"
"github.com/go-playground/validator/v10"
"github.com/gorilla/mux"
"github.com/jinzhu/gorm"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"gorm.io/gorm"

"github.com/caraml-dev/mlp/api/pkg/authz/enforcer"
"github.com/caraml-dev/mlp/api/pkg/instrumentation/newrelic"
Expand Down
7 changes: 4 additions & 3 deletions api/api/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,26 @@ package api

import (
"context"
"errors"
"fmt"

"github.com/jinzhu/gorm"
"gorm.io/gorm"

"github.com/caraml-dev/merlin/models"
)

func (c *AppContext) getModelAndVersion(ctx context.Context, modelID models.ID, versionID models.ID) (*models.Model, *models.Version, error) {
model, err := c.ModelsService.FindByID(ctx, modelID)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, fmt.Errorf("model with given id: %d not found", modelID)
}
return nil, nil, fmt.Errorf("error retrieving model with id: %d", modelID)
}

version, err := c.VersionsService.FindByID(ctx, modelID, versionID, c.FeatureToggleConfig.MonitoringConfig)
if err != nil {
if gorm.IsRecordNotFoundError(err) {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, fmt.Errorf("model version with given id: %d not found", versionID)
}
return nil, nil, fmt.Errorf("error retrieving model version with id: %d", versionID)
Expand Down
Loading

0 comments on commit f89860f

Please sign in to comment.