Skip to content

Commit

Permalink
Add MSAL support (#330)
Browse files Browse the repository at this point in the history
* Use MSAL library support instead of ADAL.

---------

Co-authored-by: Ray Luo <[email protected]>
  • Loading branch information
akharit and rayluo authored Apr 24, 2023
1 parent 7f6e5ea commit b44a68d
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 179 deletions.
3 changes: 2 additions & 1 deletion HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
Release History
===============

0.0.53 (2022-10-26)
0.0.53 (2023-04-11)
+++++++++++++++++++
* Add MSAL support and remove ADAL support.
* Suppress deprecation warning when detecting pyopenssl existence.

0.0.52 (2020-11-25)
Expand Down
4 changes: 0 additions & 4 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,10 @@ Manually (bleeding edge):

* Download the repo from [https://github.com/Azure/azure-data-lake-store-python](https://github.com/Azure/azure-data-lake-store-python)

* check out the `dev` branch

* install the requirements (`pip install -r dev_requirements.txt`)

* install in develop mode (`python setup.py develop`)

* optionally: build the documentation (including this page) by running `make html` in the docs directory.

## Auth

Although users can generate and supply their own tokens to the base file-system
Expand Down
104 changes: 51 additions & 53 deletions azure/datalake/store/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
else:
import urllib

from .retry import ExponentialRetryPolicy, retry_decorator_for_auth
from .retry import ExponentialRetryPolicy

# 3rd party imports
import adal
import msal
import requests
import requests.exceptions

_http_cache = {} # Useful for MSAL. https://msal-python.readthedocs.io/en/latest/#msal.PublicClientApplication.params.http_cache

# this is required due to github issue, to ensure we don't lose perf from openPySSL: https://github.com/pyca/pyopenssl/issues/625
def enforce_no_py_open_ssl():
Expand Down Expand Up @@ -118,13 +119,8 @@ def auth(tenant_id=None, username=None,
if not authority:
authority = 'https://login.microsoftonline.com/'


if not tenant_id:
tenant_id = os.environ.get('azure_tenant_id', "common")

context = adal.AuthenticationContext(authority +
tenant_id)

if tenant_id is None or client_id is None:
raise ValueError("tenant_id and client_id must be supplied for authentication")

Expand All @@ -136,60 +132,59 @@ def auth(tenant_id=None, username=None,

if not client_secret:
client_secret = os.environ.get('azure_client_secret', None)

# You can explicitly authenticate with 2fa, or pass in nothing to the auth call
# and the user will be prompted to login interactively through a browser.

@retry_decorator_for_auth(retry_policy=retry_policy)

scopes = kwargs.get('scopes', ["https://datalake.azure.net/.default"])
def get_token_internal():
# Internal function used so as to use retry decorator
if require_2fa or (username is None and password is None and client_secret is None):
code = context.acquire_user_code(resource, client_id)
print(code['message'])
out = context.acquire_token_with_device_code(resource, code, client_id)
contextPub = msal.PublicClientApplication(client_id=client_id, authority=authority+tenant_id, http_cache=_http_cache)
flow = contextPub.initiate_device_flow(scopes=scopes)
print(flow['message'])
out = contextPub.acquire_token_by_device_flow(flow)
elif username and password:
out = context.acquire_token_with_username_password(resource, username,
password, client_id)
contextPub = msal.PublicClientApplication(client_id=client_id, authority=authority+tenant_id, http_cache=_http_cache)
out = contextPub.acquire_token_by_username_password(username=username, password=password, scopes=scopes)
elif client_id and client_secret:
out = context.acquire_token_with_client_credentials(resource, client_id,
client_secret)
contextClient = msal.ConfidentialClientApplication(client_id=client_id, authority=authority+tenant_id, client_credential=client_secret, http_cache=_http_cache)
out = contextClient.acquire_token_for_client(scopes=scopes)
# for service principal, we store the secret in the credential object for use when refreshing.
out.update({'secret': client_secret})
else:
raise ValueError("No authentication method found for credentials")
return out

out = get_token_internal()

out.update({'access': out['accessToken'], 'resource': resource,
'refresh': out.get('refreshToken', False),
'time': time.time(), 'tenant': tenant_id, 'client': client_id})
if 'error' in out:
msg = "MSAL Error: "+out.get('error_description', "")
err = DatalakeRESTException(msg)
logger.log(logging.ERROR, msg)
raise err

return DataLakeCredential(out)
out.update({'access_token': out['access_token'], 'access': out['access_token'], 'resource': resource,
'refresh': out.get('refresh_token', False),
'time': time.time(), 'tenant': tenant_id, 'client': client_id, 'scopes':scopes})

return DataLakeCredential(out)

class DataLakeCredential:
# Be careful modifying this. DataLakeCredential is a general class in azure, and we have to maintain parity.
def __init__(self, token):
self.token = token

def signed_session(self):
# type: () -> requests.Session
"""Create requests session with any required auth headers applied.
:rtype: requests.Session
"""
session = requests.Session()
if time.time() - self.token['time'] > self.token['expiresIn'] - 100:
if time.time() - self.token['time'] > self.token['expires_in'] - 100:
self.refresh_token()

scheme, token = self.token['tokenType'], self.token['access']
session = requests.Session()
scheme, token = self.token['token_type'], self.token['access_token']
header = "{} {}".format(scheme, token)
session.headers['Authorization'] = header
return session

def refresh_token(self, authority=None):
""" Refresh an expired authorization token
Parameters
----------
authority: string
Expand All @@ -201,25 +196,31 @@ def refresh_token(self, authority=None):
if not authority:
authority = 'https://login.microsoftonline.com/'

context = adal.AuthenticationContext(authority +
self.token['tenant'])

tenant_id = self.token['tenant']
scopes = self.token['scopes']
if self.token.get('secret') and self.token.get('client'):
out = context.acquire_token_with_client_credentials(self.token['resource'],
self.token['client'],
self.token['secret'])
out.update({'secret': self.token['secret']})
client_id = self.token['client']
client_secret = self.token['secret']
contextClient = msal.ConfidentialClientApplication(client_id=client_id, authority=authority+tenant_id, client_credential=client_secret, http_cache=_http_cache)
out = contextClient.acquire_token_for_client(scopes=scopes)
out.update({'secret': client_secret})
else:
out = context.acquire_token_with_refresh_token(self.token['refresh'],
client_id=self.token['client'],
resource=self.token['resource'])
contextPub = msal.PublicClientApplication(client_id=client_id, authority=authority+tenant_id, http_cache=_http_cache)
out = contextPub.client.obtain_token_by_refresh_token(self.token['refresh'], scopes=scopes)

if 'error' in out:
msg = "MSAL Error: "+out.get('error_description', "")
err = DatalakeRESTException(msg)
logger.log(logging.ERROR, msg)
raise err
# common items to update
out.update({'access': out['accessToken'],
out.update({'access_token': out['access_token'], 'access': out['access_token'],
'time': time.time(), 'tenant': self.token['tenant'],
'resource': self.token['resource'], 'client': self.token['client']})
'resource': self.token['resource'], 'client': self.token['client'], 'scopes':self.token['scopes']})

self.token = out


class DatalakeRESTInterface:
""" Call factory for webHDFS endpoints on ADLS
Expand All @@ -228,7 +229,7 @@ class DatalakeRESTInterface:
store_name: str
The name of the Data Lake Store account to execute operations against.
token: dict
from `auth()` or `refresh_token()` or other ADAL source
from `auth()` or `refresh_token()` or other MSAL source
url_suffix: str (None)
Domain to send REST requests to. The end-point URL is constructed
using this and the store_name. If None, use default.
Expand Down Expand Up @@ -309,13 +310,10 @@ def session(self):
return s

def _check_token(self, retry_policy= None):
@retry_decorator_for_auth(retry_policy=retry_policy)
def check_token_internal():
cur_session = self.token.signed_session()
if not self.head or self.head.get('Authorization') != cur_session.headers['Authorization']:
self.head = {'Authorization': cur_session.headers['Authorization']}
self.local.session = None
check_token_internal()
cur_session = self.token.signed_session()
if not self.head or self.head.get('Authorization') != cur_session.headers['Authorization']:
self.head = {'Authorization': cur_session.headers['Authorization']}
self.local.session = None

def _log_request(self, method, url, op, path, params, headers, retry_count):
msg = u"HTTP Request\n{} {}\n".format(method.upper(), url)
Expand Down Expand Up @@ -498,7 +496,7 @@ def __getstate__(self):

"""
Not yet implemented (or not applicable)
http://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-hdfs/WebHDFS.html
https://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-hdfs/WebHDFS.html
GETFILECHECKSUM
GETHOMEDIRECTORY
Expand Down
12 changes: 6 additions & 6 deletions azure/datalake/store/multiprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

def monitor_exception(exception_queue, process_ids):
global GLOBAL_EXCEPTION
logger = logging.getLogger(__name__)
logger = logging.getLogger("azure.datalake.store")

while True:
try:
Expand Down Expand Up @@ -53,8 +53,8 @@ def log_listener_process(queue):
queue.task_done()
if record == END_QUEUE_SENTINEL: # We send this as a sentinel to tell the listener to quit.
break
logger = logging.getLogger(record.name)
logger.handlers.clear()
logger = logging.getLogger("azure.datalake.store")
#logger.handlers.clear()
logger.handle(record) # No level or filter logic applied - just do it!
except Empty: # Try again
pass
Expand All @@ -65,7 +65,7 @@ def log_listener_process(queue):


def multi_processor_change_acl(adl, path=None, method_name="", acl_spec="", number_of_sub_process=None):
logger = logging.getLogger(__name__)
logger = logging.getLogger("azure.datalake.store")

def launch_processes(number_of_processes):
if number_of_processes is None:
Expand Down Expand Up @@ -152,8 +152,8 @@ def walk(walk_path):


def processor(adl, file_path_queue, finish_queue_processing_flag, method_name, acl_spec, log_queue, exception_queue):
logger = logging.getLogger(__name__)

logger = logging.getLogger("azure.datalake.store")
logger.setLevel(logging.DEBUG)
removed_default_acl_spec = ",".join([x for x in acl_spec.split(',') if not x.lower().startswith("default")])

try:
Expand Down
52 changes: 0 additions & 52 deletions azure/datalake/store/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,55 +74,3 @@ def should_retry(self, response, last_exception, retry_count):
def __backoff(self):
time.sleep(self.exponential_retry_interval)
self.exponential_retry_interval *= self.exponential_factor


def retry_decorator_for_auth(retry_policy = None):
    """ Build a decorator that retries an authentication call on ADAL/HTTP errors

    Parameters
    ----------
    retry_policy: object (None)
        Policy whose ``should_retry(response, last_exception, retry_count)``
        decides whether another attempt is made. If None, uses
        ``ExponentialRetryPolicy(max_retries=2)``.

    Returns
    -------
    A decorator. The decorated function returns the wrapped function's result
    on the first successful call. When retrying stops (HTTP 401 or the policy
    declines), the last caught ``AdalError``/``HTTPError`` is re-raised.
    """
    import adal
    from requests import HTTPError
    if retry_policy is None:
        retry_policy = ExponentialRetryPolicy(max_retries=2)

    def deco_retry(func):
        @wraps(func)
        def f_retry(*args, **kwargs):
            retry_count = -1
            while True:
                retry_count += 1
                try:
                    return func(*args, **kwargs)
                except (adal.adal_error.AdalError, HTTPError) as e:
                    # ADAL error corresponds to everything but 429, which bubbles up HTTP error.
                    last_exception = e
                    logger.exception("Retry count " + str(retry_count) + "Exception :" + str(last_exception))

                # Reset on every attempt so that a failed parse never reuses the
                # response of a previous attempt (and the name is always bound).
                response = None
                # We don't want to stop retry for any error in parsing the exception. This is a GET operation.
                try:
                    if hasattr(last_exception, 'error_response'):  # ADAL exception
                        response = response_from_adal_exception(last_exception)
                    if hasattr(last_exception, 'response'):  # HTTP exception i.e 429
                        response = last_exception.response
                except Exception:
                    # Best effort only: fall back to "no response information".
                    response = None

                # 401 = Invalid credentials; retrying cannot help, so give up immediately.
                invalid_credentials = response is not None and response.status_code == 401
                if invalid_credentials or not retry_policy.should_retry(response, last_exception, retry_count):
                    raise last_exception
        return f_retry

    return deco_retry


def response_from_adal_exception(e):
    """ Extract the HTTP status code embedded in an ADAL exception message

    Parameters
    ----------
    e: Exception
        Exception (or any object) whose ``str()`` may contain the text
        ``http error: <digits>``.

    Returns
    -------
    A ``Response`` namedtuple with a single ``status_code`` int field for use
    in ``should_retry``, or None when the message carries no HTTP status code.
    """
    import re
    from collections import namedtuple

    http_code = re.search(r"http error: (\d+)", str(e))
    if http_code is None:
        # No status code in the message: report "unknown" instead of failing
        # on an unbound local variable.
        return None

    # Minimal response object exposing only the status_code attribute.
    Response = namedtuple("Response", ['status_code'])
    return Response(int(http_code.group(1)))

2 changes: 1 addition & 1 deletion azure/datalake/store/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def decrement(self):
self.lock.acquire()
self.val -= 1
if self.val <= 0:
self.lock.notifyAll()
self.lock.notify_all()
self.lock.release()

def total_processed(self):
Expand Down
8 changes: 6 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
description='Azure Data Lake Store Filesystem Client Library for Python',
url='https://github.com/Azure/azure-data-lake-store-python',
author='Microsoft Corporation',
author_email='ptvshelp@microsoft.com',
author_email='Akshat.Harit@microsoft.com',
license='MIT License',
keywords='azure',
classifiers=[
Expand All @@ -36,6 +36,10 @@
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'License :: OSI Approved :: MIT License',
],
packages=find_packages(exclude=['tests',
Expand All @@ -44,7 +48,7 @@
]),
install_requires=[
'cffi',
'adal>=0.4.2',
'msal>=1.16.0,<2', # http_cache was introduced in MSAL 1.16.0
'requests>=2.20.0',
],
extras_require={
Expand Down
Loading

0 comments on commit b44a68d

Please sign in to comment.