Skip to content

Commit

Permalink
[Connection] Refine workspace connection to require limited permission (
Browse files Browse the repository at this point in the history
#1135)

# Description

When specifying workspace connection provider and get connection with
secrets, sdk will call arm, this path does not requires create pf azure
client, as pf azure client creation requires workspace read permission.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Signed-off-by: Brynn Yin <[email protected]>
Co-authored-by: Zhengfei Wang <[email protected]>
  • Loading branch information
brynn-code and zhengfeiwang authored Nov 14, 2023
1 parent bf3c777 commit 2499daf
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,28 @@
class LocalAzureConnectionOperations(TelemetryMixin):
def __init__(self, connection_provider, **kwargs):

from promptflow.azure._pf_client import PFClient as PFAzureClient

super().__init__(**kwargs)
subscription_id, resource_group, workspace_name = self._extract_workspace(connection_provider)
self._pfazure_client = PFAzureClient(
# TODO: disable interactive credential when starting as a service
credential=self.get_credential(),
subscription_id=subscription_id,
resource_group_name=resource_group,
workspace_name=workspace_name,
)
self._subscription_id, self._resource_group, self._workspace_name = self._extract_workspace(connection_provider)
# Lazy init client as ml_client initialization require workspace read permission
self._pfazure_client = None
self._credential = self._get_credential()

@property
def _client(self):
if self._pfazure_client is None:
from promptflow.azure._pf_client import PFClient as PFAzureClient

self._pfazure_client = PFAzureClient(
# TODO: disable interactive credential when starting as a service
credential=self._credential,
subscription_id=self._subscription_id,
resource_group_name=self._resource_group,
workspace_name=self._workspace_name,
)
return self._pfazure_client

@classmethod
def get_credential(cls):
def _get_credential(cls):
from azure.identity import DefaultAzureCredential, DeviceCodeCredential

if is_from_cli():
Expand Down Expand Up @@ -95,7 +103,7 @@ def list(
logger.warning(
"max_results and all_results are not supported for workspace connection and will be ignored."
)
return self._pfazure_client._connections.list()
return self._client._connections.list()

@monitor_operation(activity_name="pf.connections.azure.get", activity_type=ActivityType.PUBLICAPI)
def get(self, name: str, **kwargs) -> _Connection:
Expand All @@ -108,8 +116,14 @@ def get(self, name: str, **kwargs) -> _Connection:
"""
with_secrets = kwargs.get("with_secrets", False)
if with_secrets:
return self._pfazure_client._arm_connections.get(name)
return self._pfazure_client._connections.get(name)
# Do not use pfazure_client here as it requires workspace read permission
# Get secrets from arm only requires workspace listsecrets permission
from promptflow.azure.operations._arm_connection_operations import ArmConnectionOperations

return ArmConnectionOperations._direct_get(
name, self._subscription_id, self._resource_group, self._workspace_name, self._credential
)
return self._client._connections.get(name)

@monitor_operation(activity_name="pf.connections.azure.delete", activity_type=ActivityType.PUBLICAPI)
def delete(self, name: str) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ def get(self, name, **kwargs):
connection_dict = self.build_connection_dict(name)
return _Connection._from_execution_connection_dict(name=name, data=connection_dict)

@classmethod
def _direct_get(cls, name, subscription_id, resource_group_name, workspace_name, credential):
"""
This method is added for local pf_client with workspace provider to ensure we only require limited
permission(workspace/list secrets). As create azure pf_client requires workspace read permission.
"""
connection_dict = cls._build_connection_dict(
name, subscription_id, resource_group_name, workspace_name, credential
)
return _Connection._from_execution_connection_dict(name=name, data=connection_dict)

@classmethod
def open_url(cls, token, url, action, host="management.azure.com", method="GET", model=None) -> Union[Any, dict]:
"""
Expand Down Expand Up @@ -199,19 +210,33 @@ def build_connection_dict_from_rest_object(cls, name, obj) -> dict:
# Note: Filter empty values out to ensure default values can be picked when init class object.
return {**meta, "value": {k: v for k, v in value.items() if v}}

def build_connection_dict(self, name) -> dict:
def build_connection_dict(self, name):
return self._build_connection_dict(
name,
self._operation_scope.subscription_id,
self._operation_scope.resource_group_name,
self._operation_scope.workspace_name,
self._credential,
)

@classmethod
def _build_connection_dict(cls, name, subscription_id, resource_group_name, workspace_name, credential) -> dict:
"""
:type name: str
:type subscription_id: str
:type resource_group_name: str
:type workspace_name: str
:type credential: azure.identity.TokenCredential
"""
url = GET_CONNECTION_URL.format(
sub=self._operation_scope.subscription_id,
rg=self._operation_scope.resource_group_name,
ws=self._operation_scope.workspace_name,
sub=subscription_id,
rg=resource_group_name,
ws=workspace_name,
name=name,
)
try:
rest_obj: WorkspaceConnectionPropertiesV2BasicResource = self.open_url(
self._credential.get_token("https://management.azure.com/.default").token,
rest_obj: WorkspaceConnectionPropertiesV2BasicResource = cls.open_url(
credential.get_token("https://management.azure.com/.default").token,
url=url,
action="listsecrets",
method="POST",
Expand All @@ -225,7 +250,7 @@ def build_connection_dict(self, name) -> dict:
)
raise OpenURLUserAuthenticationError(message=auth_error_message)
try:
return self.build_connection_dict_from_rest_object(name, rest_obj)
return cls.build_connection_dict_from_rest_object(name, rest_obj)
except Exception as e:
raise BuildConnectionError(
message_format=f"Build connection dict for connection {{name}} failed with {e}.",
Expand Down

0 comments on commit 2499daf

Please sign in to comment.