Skip to content

Commit

Permalink
Fix Dynamic Validation Flow for Workload Identity
Browse files Browse the repository at this point in the history
  • Loading branch information
rajdeepc2792 committed Oct 8, 2024
1 parent 69378fb commit 177aa40
Show file tree
Hide file tree
Showing 11 changed files with 163 additions and 13 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@ run-rp: ci-rp podman-secrets
-e ARO_ADOPT_BY_HIVE="true" \
-e MOCK_MSI_TENANT_ID \
-e MOCK_MSI_CLIENT_ID \
-e MOCK_MSI_OBJECT_ID \
-e MOCK_MSI_CERT \
--secret aks.kubeconfig,target=/app/secrets/aks.kubeconfig \
--secret proxy-client.key,target=/app/secrets/proxy-client.key \
Expand Down
1 change: 1 addition & 0 deletions docs/deploy-development-rp.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ mock a cluster MSI. This script will also create the platform identities, platfo
- `RP_MODE`: Set to `development` to use a development RP running at
https://localhost:8443/.
- `MOCK_MSI_CLIENT_ID`: Client ID for service principal that mocks cluster MSI (see previous step).
- `MOCK_MSI_OBJECT_ID`: Object ID for service principal that mocks cluster MSI (see previous step).
- `MOCK_MSI_CERT`: Base64 encoded certificate for service principal that mocks cluster MSI (see previous step).
- `MOCK_MSI_TENANT_ID`: Tenant ID for service principal that mocks cluster MSI (see previous step).
- `PLATFORM_WORKLOAD_IDENTITY_ROLE_SETS`: The platform workload identity role sets (see previous step or value in `local_dev_env.sh`).
Expand Down
1 change: 1 addition & 0 deletions env.example
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export ARO_IMAGE=arointsvc.azurecr.io/aro:latest
# out or remove them from your env file if you're only going to be creating service principal
# clusters.
export MOCK_MSI_CLIENT_ID="replace_with_value_output_by_hack/devtools/msi.sh"
export MOCK_MSI_OBJECT_ID="replace_with_value_output_by_hack/devtools/msi.sh"
export MOCK_MSI_CERT="replace_with_value_output_by_hack/devtools/msi.sh"
export MOCK_MSI_TENANT_ID="replace_with_value_output_by_hack/devtools/msi.sh"
export PLATFORM_WORKLOAD_IDENTITY_ROLE_SETS="replace_with_value_output_by_hack/devtools/msi.sh"
Expand Down
6 changes: 6 additions & 0 deletions hack/devtools/local_dev_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ get_mock_msi_tenantID() {
echo "$1" | jq -r .tenant
}

get_mock_msi_clientID() {
az ad sp list --all --filter "appId eq '$1'" | jq -r ".[] | .id"
}

get_mock_msi_cert() {
certFilePath=$(echo "$1" | jq -r '.fileWithCertAndPrivateKey')
base64EncodedCert=$(base64 -w 0 "$certFilePath")
Expand Down Expand Up @@ -234,6 +238,7 @@ create_miwi_env_file() {
mockClientID=$(get_mock_msi_clientID "$mockMSI")
mockTenantID=$(get_mock_msi_tenantID "$mockMSI")
mockCert=$(get_mock_msi_cert "$mockMSI")
mockObjectID=$(get_mock_msi_objectID "$mockMSI")

setup_platform_identity
cluster_msi_role_assignment "${mockClientID}"
Expand All @@ -243,6 +248,7 @@ export LOCATION=eastus
export ARO_IMAGE=arointsvc.azurecr.io/aro:latest
export RP_MODE=development # to use a development RP running at https://localhost:8443/
export MOCK_MSI_CLIENT_ID="$mockClientID"
export "MOCK_MSI_OBJECT_ID"="$mockObjectID"
export MOCK_MSI_TENANT_ID="$mockTenantID"
export MOCK_MSI_CERT="$mockCert"
export PLATFORM_WORKLOAD_IDENTITY_ROLE_SETS="$PLATFORM_WORKLOAD_IDENTITY_ROLE_SETS"
Expand Down
2 changes: 1 addition & 1 deletion pkg/cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func New(ctx context.Context, log *logrus.Entry, _env env.Interface, db database
return nil, err
}

armRoleDefinitionsClient, err := armauthorization.NewArmRoleDefinitionsClient(fpCredClusterTenant, clientOptions)
armRoleDefinitionsClient, err := armauthorization.NewArmRoleDefinitionsClient(fpCredClusterTenant, r.SubscriptionID, clientOptions)
if err != nil {
return nil, err
}
Expand Down
7 changes: 6 additions & 1 deletion pkg/cluster/deploybaseresources.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,16 @@ func (m *manager) deployBaseResourceTemplate(ctx context.Context) error {
m.storageAccount(m.doc.OpenShiftCluster.Properties.ImageRegistryStorageAccountName, azureRegion, ocpSubnets, true),
m.storageAccountBlobContainer(m.doc.OpenShiftCluster.Properties.ImageRegistryStorageAccountName, "image-registry"),
m.clusterNSG(infraID, azureRegion),
m.clusterServicePrincipalRBAC(),
m.networkPrivateLinkService(azureRegion),
m.networkInternalLoadBalancer(azureRegion),
}

if !m.doc.OpenShiftCluster.UsesWorkloadIdentity() {
resources = append(resources,
m.clusterServicePrincipalRBAC(),
)
}

// Create a public load balancer routing if needed
if m.doc.OpenShiftCluster.Properties.NetworkProfile.OutboundType == api.OutboundTypeLoadbalancer {
m.newPublicLoadBalancer(ctx, &resources)
Expand Down
3 changes: 2 additions & 1 deletion pkg/env/prod.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ func (p *prod) MsiDataplaneClientOptions(msiResourceId *arm.ResourceID) (*policy
if p.FeatureIsSet(FeatureUseMockMsiRp) {
keysToValidate := []string{
"MOCK_MSI_CLIENT_ID",
"MOCK_MSI_OBJECT_ID",
"MOCK_MSI_CERT",
"MOCK_MSI_TENANT_ID",
}
Expand All @@ -425,6 +426,7 @@ func (p *prod) MsiDataplaneClientOptions(msiResourceId *arm.ResourceID) (*policy
ClientID: ptr.To(os.Getenv("MOCK_MSI_CLIENT_ID")),
ClientSecret: ptr.To(os.Getenv("MOCK_MSI_CERT")),
TenantID: ptr.To(os.Getenv("MOCK_MSI_TENANT_ID")),
ObjectID: ptr.To(os.Getenv("MOCK_MSI_OBJECT_ID")),
ResourceID: ptr.To(msiResourceId.String()),
AuthenticationEndpoint: ptr.To(p.Environment().Cloud.ActiveDirectoryAuthorityHost),
CannotRenewAfter: &placeholder,
Expand All @@ -437,7 +439,6 @@ func (p *prod) MsiDataplaneClientOptions(msiResourceId *arm.ResourceID) (*policy
XMSAzNwperimid: []*string{&placeholder},
XMSAzTm: &placeholder,
},
ObjectID: &placeholder,
},
},
},
Expand Down
127 changes: 127 additions & 0 deletions pkg/util/azblob2/manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package azblob2

// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.

import (
"context"
"io"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/sas"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service"
"github.com/sirupsen/logrus"
)

type AZBlobClient interface {
DownloadStream(ctx context.Context, containerName string, blobName string, o *azblob.DownloadStreamOptions) ([]byte, error)
UploadBuffer(ctx context.Context, containerName string, blobName string, buffer []byte) error
DeleteBlob(ctx context.Context, containerName string, blobName string) error
Exists(ctx context.Context, container string, blobPath string) (bool, error)
}

type azBlobClient struct {
client *azblob.Client
}

func NewAZBlobClient(ctx context.Context, blobContainerURL string, credential azcore.TokenCredential, options *azblob.ClientOptions, isUserDesignated bool, log *logrus.Entry) (AZBlobClient, error) {
client, err := azblob.NewClient(blobContainerURL, credential, options)
if err != nil {
return nil, err
}
blobClient := &azBlobClient{client: client}
if isUserDesignated {
sasURL, err := blobClient.signBlobURL(ctx, blobContainerURL, time.Now().UTC().Add(2*time.Hour))
log.Printf("sasURL -------- %s", sasURL)
if err != nil {
return nil, err
}

client, err = azblob.NewClientWithNoCredential(sasURL, options)
if err != nil {
return nil, err
}
blobClient = &azBlobClient{client: client}
}
return blobClient, nil
}

func (azBlobClient *azBlobClient) DownloadStream(ctx context.Context, containerName string, blobName string, o *azblob.DownloadStreamOptions) ([]byte, error) {
response, err := azBlobClient.client.DownloadStream(ctx, containerName, blobName, o)
if err != nil {
return nil, err
}
defer response.Body.Close()
return io.ReadAll(response.Body)
}

func (azBlobClient *azBlobClient) UploadBuffer(ctx context.Context, containerName string, blobName string, buffer []byte) error {
_, err := azBlobClient.client.UploadBuffer(ctx, containerName, blobName, buffer, &azblob.UploadBufferOptions{})
return err
}

func (azBlobClient *azBlobClient) DeleteBlob(ctx context.Context, containerName string, blobName string) error {
_, err := azBlobClient.client.DeleteBlob(ctx, containerName, blobName, &azblob.DeleteBlobOptions{})
return err
}

func (azBlobClient *azBlobClient) signBlobURL(ctx context.Context, blobURL string, expires time.Time) (string, error) {
urlParts, err := sas.ParseURL(blobURL)
if err != nil {
return "", err
}
// perms := sas.BlobPermissions{Read: true, Write: true, Create: true}
perms := sas.BlobPermissions{Read: true, Write: true, Create: true}
signatureValues := sas.BlobSignatureValues{
Protocol: sas.ProtocolHTTPS,
StartTime: time.Now().UTC().Add(-10 * time.Second),
ExpiryTime: expires,
Permissions: perms.String(),
ContainerName: "aro",
BlobName: "graph",
}
urlParts.SAS, err = azBlobClient.sign(ctx, &signatureValues)
if err != nil {
return "", err
}
return urlParts.String(), nil
}

func (azBlobClient *azBlobClient) sign(ctx context.Context, signatureValues *sas.BlobSignatureValues) (sas.QueryParameters, error) {
currentTime := time.Now().UTC().Add(-10 * time.Second)
expiryTime := currentTime.Add(2 * time.Hour)

info := service.KeyInfo{
Start: to.Ptr(currentTime.UTC().Format(sas.TimeFormat)),
Expiry: to.Ptr(expiryTime.UTC().Format(sas.TimeFormat)),
}

udc, err := azBlobClient.client.ServiceClient().GetUserDelegationCredential(ctx, info, nil)
if err != nil {
return sas.QueryParameters{}, err
}
return signatureValues.SignWithUserDelegation(udc)
}

func (azBlobClient *azBlobClient) Exists(ctx context.Context, container string, blobPath string) (bool, error) {
blobRef := azBlobClient.client.ServiceClient().NewContainerClient(container).NewBlobClient(blobPath)
_, err := blobRef.GetProperties(ctx, nil)
if err != nil {
if bloberror.HasCode(
err,
bloberror.BlobNotFound,
bloberror.ContainerNotFound,
bloberror.ResourceNotFound,
bloberror.CannotVerifyCopySource,
) {
return false, nil
} else {
return false, err
}
}
return true, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package armauthorization

import (
"context"
"fmt"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
Expand All @@ -17,13 +18,20 @@ type RoleDefinitionsClient interface {

type ArmRoleDefinitionsClient struct {
*armauthorization.RoleDefinitionsClient
subscriptionID string
}

var _ RoleDefinitionsClient = &ArmRoleDefinitionsClient{}

func NewArmRoleDefinitionsClient(credential azcore.TokenCredential, options *arm.ClientOptions) (*ArmRoleDefinitionsClient, error) {
func NewArmRoleDefinitionsClient(credential azcore.TokenCredential, subscriptionID string, options *arm.ClientOptions) (*ArmRoleDefinitionsClient, error) {
client, err := armauthorization.NewRoleDefinitionsClient(credential, options)
return &ArmRoleDefinitionsClient{
RoleDefinitionsClient: client,
subscriptionID: subscriptionID,
}, err
}

func (client ArmRoleDefinitionsClient) GetByID(ctx context.Context, roleDefinitionID string, options *armauthorization.RoleDefinitionsClientGetByIDOptions) (armauthorization.RoleDefinitionsClientGetByIDResponse, error) {
roleID := fmt.Sprintf("/subscriptions/%s%s", client.subscriptionID, roleDefinitionID)
return client.RoleDefinitionsClient.GetByID(ctx, roleID, options)
}
2 changes: 1 addition & 1 deletion pkg/validate/dynamic/platformworkloadidentityprofile.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (dv *dynamic) validateClusterMSI(ctx context.Context, oc *api.OpenShiftClus

// Validate that the cluster MSI has all permissions specified in AzureRedHatOpenShiftFederatedCredentialRole over each platform managed identity
func (dv *dynamic) validateClusterMSIPermissions(ctx context.Context, oid string, platformIdentities []api.PlatformWorkloadIdentity, roleDefinitions armauthorization.RoleDefinitionsClient) error {
actions, err := getActionsForRoleDefinition(ctx, rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole, roleDefinitions)
actions, err := getActionsForRoleDefinition(ctx, fmt.Sprintf("/providers/Microsoft.Authorization/roleDefinitions/%s", rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole), roleDefinitions)
if err != nil {
return err
}
Expand Down
16 changes: 8 additions & 8 deletions pkg/validate/dynamic/platformworkloadidentityprofile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func TestValidatePlatformWorkloadIdentityProfile(t *testing.T) {
},
},
mocks: func(roleDefinitions *mock_armauthorization.MockRoleDefinitionsClient) {
roleDefinitions.EXPECT().GetByID(ctx, rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole, &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, nil)
roleDefinitions.EXPECT().GetByID(ctx, fmt.Sprintf("/providers/Microsoft.Authorization/roleDefinitions/%s", rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole), &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, nil)
roleDefinitions.EXPECT().GetByID(ctx, gomock.Any(), &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).AnyTimes().Return(platformIdentityRequiredPermissions, nil)
},
checkAccessMocks: func(cancel context.CancelFunc, pdpClient *mock_remotepdp.MockRemotePDPClient, tokenCred *mock_azcore.MockTokenCredential) {
Expand Down Expand Up @@ -232,7 +232,7 @@ func TestValidatePlatformWorkloadIdentityProfile(t *testing.T) {
},
},
mocks: func(roleDefinitions *mock_armauthorization.MockRoleDefinitionsClient) {
roleDefinitions.EXPECT().GetByID(ctx, rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole, &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, nil)
roleDefinitions.EXPECT().GetByID(ctx, fmt.Sprintf("/providers/Microsoft.Authorization/roleDefinitions/%s", rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole), &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, nil)
roleDefinitions.EXPECT().GetByID(ctx, gomock.Any(), &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).AnyTimes().Return(platformIdentityRequiredPermissions, nil)
},
checkAccessMocks: func(cancel context.CancelFunc, pdpClient *mock_remotepdp.MockRemotePDPClient, tokenCred *mock_azcore.MockTokenCredential) {
Expand Down Expand Up @@ -267,7 +267,7 @@ func TestValidatePlatformWorkloadIdentityProfile(t *testing.T) {
},
},
mocks: func(roleDefinitions *mock_armauthorization.MockRoleDefinitionsClient) {
roleDefinitions.EXPECT().GetByID(ctx, rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole, &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, nil)
roleDefinitions.EXPECT().GetByID(ctx, fmt.Sprintf("/providers/Microsoft.Authorization/roleDefinitions/%s", rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole), &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, nil)
roleDefinitions.EXPECT().GetByID(ctx, gomock.Any(), &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).AnyTimes().Return(platformIdentityRequiredPermissions, nil)
},
checkAccessMocks: func(cancel context.CancelFunc, pdpClient *mock_remotepdp.MockRemotePDPClient, tokenCred *mock_azcore.MockTokenCredential) {
Expand Down Expand Up @@ -431,7 +431,7 @@ func TestValidatePlatformWorkloadIdentityProfile(t *testing.T) {
},
},
mocks: func(roleDefinitions *mock_armauthorization.MockRoleDefinitionsClient) {
roleDefinitions.EXPECT().GetByID(ctx, rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole, &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, errors.New("Generic Error"))
roleDefinitions.EXPECT().GetByID(ctx, fmt.Sprintf("/providers/Microsoft.Authorization/roleDefinitions/%s", rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole), &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, errors.New("Generic Error"))
},
wantErr: "Generic Error",
},
Expand All @@ -458,7 +458,7 @@ func TestValidatePlatformWorkloadIdentityProfile(t *testing.T) {
},
},
mocks: func(roleDefinitions *mock_armauthorization.MockRoleDefinitionsClient) {
roleDefinitions.EXPECT().GetByID(ctx, rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole, &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, nil)
roleDefinitions.EXPECT().GetByID(ctx, fmt.Sprintf("/providers/Microsoft.Authorization/roleDefinitions/%s", rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole), &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, nil)
},
wantErr: "parsing failed for Invalid UUID. Invalid resource Id format",
},
Expand All @@ -479,7 +479,7 @@ func TestValidatePlatformWorkloadIdentityProfile(t *testing.T) {
},
},
mocks: func(roleDefinitions *mock_armauthorization.MockRoleDefinitionsClient) {
roleDefinitions.EXPECT().GetByID(ctx, rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole, &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, nil)
roleDefinitions.EXPECT().GetByID(ctx, fmt.Sprintf("/providers/Microsoft.Authorization/roleDefinitions/%s", rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole), &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, nil)
roleDefinitions.EXPECT().GetByID(ctx, gomock.Any(), &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).AnyTimes().Return(platformIdentityRequiredPermissions, nil)
},
checkAccessMocks: func(cancel context.CancelFunc, pdpClient *mock_remotepdp.MockRemotePDPClient, tokenCred *mock_azcore.MockTokenCredential) {
Expand Down Expand Up @@ -508,7 +508,7 @@ func TestValidatePlatformWorkloadIdentityProfile(t *testing.T) {
},
},
mocks: func(roleDefinitions *mock_armauthorization.MockRoleDefinitionsClient) {
roleDefinitions.EXPECT().GetByID(ctx, rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole, &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, nil)
roleDefinitions.EXPECT().GetByID(ctx, fmt.Sprintf("/providers/Microsoft.Authorization/roleDefinitions/%s", rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole), &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, nil)
roleDefinitions.EXPECT().GetByID(ctx, gomock.Any(), &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).AnyTimes().Return(platformIdentityRequiredPermissions, nil)
},
checkAccessMocks: func(cancel context.CancelFunc, pdpClient *mock_remotepdp.MockRemotePDPClient, tokenCred *mock_azcore.MockTokenCredential) {
Expand Down Expand Up @@ -537,7 +537,7 @@ func TestValidatePlatformWorkloadIdentityProfile(t *testing.T) {
},
},
mocks: func(roleDefinitions *mock_armauthorization.MockRoleDefinitionsClient) {
roleDefinitions.EXPECT().GetByID(ctx, rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole, &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, nil)
roleDefinitions.EXPECT().GetByID(ctx, fmt.Sprintf("/providers/Microsoft.Authorization/roleDefinitions/%s", rbac.RoleAzureRedHatOpenShiftFederatedCredentialRole), &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).Return(msiRequiredPermissions, nil)
roleDefinitions.EXPECT().GetByID(ctx, gomock.Any(), &sdkauthorization.RoleDefinitionsClientGetByIDOptions{}).AnyTimes().Return(platformIdentityRequiredPermissions, errors.New("Generic Error"))
},
checkAccessMocks: func(cancel context.CancelFunc, pdpClient *mock_remotepdp.MockRemotePDPClient, tokenCred *mock_azcore.MockTokenCredential) {
Expand Down

0 comments on commit 177aa40

Please sign in to comment.