Skip to content

Commit

Permalink
Merge pull request #432 from REANNZ/filter_attributes_mdscope
Browse files Browse the repository at this point in the history
Filter attributes by shibmd_scope
  • Loading branch information
c00kiemon5ter authored Mar 21, 2023
2 parents 4e8d27c + f8529f1 commit 497aa9c
Show file tree
Hide file tree
Showing 3 changed files with 365 additions and 15 deletions.
29 changes: 29 additions & 0 deletions example/plugins/microservices/filter_attributes.yaml.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,35 @@ module: satosa.micro_services.attribute_modifications.FilterAttributeValues
name: AttributeFilter
config:
attribute_filters:
# default rules for any IdentityProvider
"":
# default rules for any requester
"":
# enforce controlled vocabulary (via simple notation)
eduPersonAffiliation: "^(faculty|student|staff|alum|member|affiliate|employee|library-walk-in)$"
eduPersonPrimaryAffiliation: "^(faculty|student|staff|alum|member|affiliate|employee|library-walk-in)$"
eduPersonScopedAffiliation:
# enforce controlled vocabulary (via extended notation)
regexp: "^(faculty|student|staff|alum|member|affiliate|employee|library-walk-in)@"
# enforce correct scope
shibmdscope_match_scope:
eduPersonPrincipalName:
# enforce correct scope
shibmdscope_match_scope:
subject-id:
# enforce attribute syntax
regexp: "^[0-9A-Za-z][-=0-9A-Za-z]{0,126}@[0-9A-Za-z][-.0-9A-Za-z]{0,126}\\Z"
# enforce correct scope
shibmdscope_match_scope:
pairwise-id:
# enforce attribute syntax
regexp: "^[0-9A-Za-z][-=0-9A-Za-z]{0,126}@[0-9A-Za-z][-.0-9A-Za-z]{0,126}\\Z"
# enforce correct scope
shibmdscope_match_scope:
schacHomeOrganization:
# enforce scoping rule on attribute value
shibmdscope_match_value:

target_provider1:
requester1:
attr1: "^foo:bar$"
68 changes: 53 additions & 15 deletions src/satosa/micro_services/attribute_modifications.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import re
import logging

from .base import ResponseMicroService
from ..context import Context
from ..exception import SATOSAError

logger = logging.getLogger(__name__)

class AddStaticAttributes(ResponseMicroService):
"""
Expand Down Expand Up @@ -29,28 +33,62 @@ def __init__(self, config, *args, **kwargs):
def process(self, context, data):
# apply default filters
provider_filters = self.attribute_filters.get("", {})
self._apply_requester_filters(data.attributes, provider_filters, data.requester)
target_provider = data.auth_info.issuer
self._apply_requester_filters(data.attributes, provider_filters, data.requester, context, target_provider)

# apply target provider specific filters
target_provider = data.auth_info.issuer
provider_filters = self.attribute_filters.get(target_provider, {})
self._apply_requester_filters(data.attributes, provider_filters, data.requester)
self._apply_requester_filters(data.attributes, provider_filters, data.requester, context, target_provider)
return super().process(context, data)

def _apply_requester_filters(self, attributes, provider_filters, requester):
def _apply_requester_filters(self, attributes, provider_filters, requester, context, target_provider):
# apply default requester filters
default_requester_filters = provider_filters.get("", {})
self._apply_filter(attributes, default_requester_filters)
self._apply_filters(attributes, default_requester_filters, context, target_provider)

# apply requester specific filters
requester_filters = provider_filters.get(requester, {})
self._apply_filter(attributes, requester_filters)

def _apply_filter(self, attributes, attribute_filters):
for attribute_name, attribute_filter in attribute_filters.items():
regex = re.compile(attribute_filter)
if attribute_name == "": # default filter for all attributes
for attribute, values in attributes.items():
attributes[attribute] = list(filter(regex.search, attributes[attribute]))
elif attribute_name in attributes:
attributes[attribute_name] = list(filter(regex.search, attributes[attribute_name]))
self._apply_filters(attributes, requester_filters, context, target_provider)

def _apply_filters(self, attributes, attribute_filters, context, target_provider):
for attribute_name, attribute_filters in attribute_filters.items():
if type(attribute_filters) == str:
# convert simple notation to filter list
attribute_filters = {'regexp': attribute_filters}

for filter_type, filter_value in attribute_filters.items():

if filter_type == "regexp":
filter_func = re.compile(filter_value).search
elif filter_type == "shibmdscope_match_scope":
mdstore = context.get_decoration(Context.KEY_METADATA_STORE)
md_scopes = list(mdstore.shibmd_scopes(target_provider,"idpsso_descriptor")) if mdstore else []
filter_func = lambda v: self._shibmdscope_match_scope(v, md_scopes)
elif filter_type == "shibmdscope_match_value":
mdstore = context.get_decoration(Context.KEY_METADATA_STORE)
md_scopes = list(mdstore.shibmd_scopes(target_provider,"idpsso_descriptor")) if mdstore else []
filter_func = lambda v: self._shibmdscope_match_value(v, md_scopes)
else:
raise SATOSAError("Unknown filter type")

if attribute_name == "": # default filter for all attributes
for attribute, values in attributes.items():
attributes[attribute] = list(filter(filter_func, attributes[attribute]))
elif attribute_name in attributes:
attributes[attribute_name] = list(filter(filter_func, attributes[attribute_name]))

def _shibmdscope_match_value(self, value, md_scopes):
for md_scope in md_scopes:
if not md_scope['regexp'] and md_scope['text'] == value:
return True
elif md_scope['regexp'] and re.fullmatch(md_scope['text'], value):
return True
return False

def _shibmdscope_match_scope(self, value, md_scopes):
split_value = value.split('@')
if len(split_value) != 2:
logger.info(f"Discarding invalid scoped value {value}")
return False
value_scope = split_value[1]
return self._shibmdscope_match_value(value_scope, md_scopes)
Loading

0 comments on commit 497aa9c

Please sign in to comment.