|
8 | 8 |
|
9 | 9 | from saml2 import saml |
10 | 10 | from saml2 import xmlenc |
11 | | -from saml2.attribute_converter import from_local, get_local_name |
| 11 | +from saml2.attribute_converter import from_local, ac_factory |
| 12 | +from saml2.attribute_converter import get_local_name |
12 | 13 | from saml2.s_utils import assertion_factory |
13 | 14 | from saml2.s_utils import factory |
14 | | -from saml2.s_utils import sid, MissingValue |
| 15 | +from saml2.s_utils import sid |
| 16 | +from saml2.s_utils import MissingValue |
15 | 17 | from saml2.saml import NAME_FORMAT_URI |
16 | | -from saml2.time_util import instant, in_a_while |
| 18 | +from saml2.time_util import instant |
| 19 | +from saml2.time_util import in_a_while |
17 | 20 |
|
18 | 21 | logger = logging.getLogger(__name__) |
19 | 22 |
|
@@ -78,15 +81,22 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None, |
78 | 81 | """ |
79 | 82 |
|
80 | 83 | def _match_attr_name(attr, ava): |
81 | | - |
82 | | - local_name = get_local_name(acs, attr["name"], attr["name_format"]) |
83 | | - if not local_name: |
84 | | - try: |
85 | | - local_name = attr["friendly_name"] |
86 | | - except KeyError: |
87 | | - pass |
| 84 | + local_name = None |
| 85 | + |
| 86 | + for a in ['name_format', 'friendly_name']: |
| 87 | + _val = attr.get(a) |
| 88 | + if _val: |
| 89 | + if a == 'name_format': |
| 90 | + local_name = get_local_name(acs, attr['name'], _val) |
| 91 | + else: |
| 92 | + local_name = _val |
| 93 | + break |
| 94 | + |
| 95 | + if local_name: |
| 96 | + _fn = _match(local_name, ava) |
| 97 | + else: |
| 98 | + _fn = None |
88 | 99 |
|
89 | | - _fn = _match(local_name, ava) |
90 | 100 | if not _fn: # In the unlikely case that someone has provided us with |
91 | 101 | # URIs as attribute names |
92 | 102 | _fn = _match(attr["name"], ava) |
@@ -117,8 +127,7 @@ def _apply_attr_value_restrictions(attr, res, must=False): |
117 | 127 | if _fn: |
118 | 128 | _apply_attr_value_restrictions(attr, res, True) |
119 | 129 | elif fail_on_unfulfilled_requirements: |
120 | | - desc = "Required attribute missing: '%s' (%s)" % (attr["name"], |
121 | | - _fn) |
| 130 | + desc = "Required attribute missing: '%s'" % (attr["name"]) |
122 | 131 | raise MissingValue(desc) |
123 | 132 |
|
124 | 133 | if optional is None: |
@@ -502,6 +511,9 @@ def filter(self, ava, sp_entity_id, mdstore, required=None, optional=None): |
502 | 511 |
|
503 | 512 | _ava = None |
504 | 513 |
|
| 514 | + if not self.acs: # acs MUST have a value, fall back to default. |
| 515 | + self.acs = ac_factory() |
| 516 | + |
505 | 517 | _rest = self.get_entity_categories(sp_entity_id, mdstore, required) |
506 | 518 | if _rest: |
507 | 519 | _ava = filter_attribute_value_assertions(ava.copy(), _rest) |
|
0 commit comments