@@ -82,10 +82,12 @@ def _match_attr_name(attr, ava):
8282 try :
8383 friendly_name = attr ["friendly_name" ]
8484 except KeyError :
85- friendly_name = get_local_name (acs , attr ["name" ], attr ["name_format" ])
85+ friendly_name = get_local_name (acs , attr ["name" ],
86+ attr ["name_format" ])
8687
8788 _fn = _match (friendly_name , ava )
88- if not _fn : # In the unlikely case that someone has provided us with URIs as attribute names
89+ if not _fn : # In the unlikely case that someone has provided us with
90+ # URIs as attribute names
8991 _fn = _match (attr ["name" ], ava )
9092
9193 return _fn
@@ -152,8 +154,8 @@ def filter_on_demands(ava, required=None, optional=None):
152154 for val in vals :
153155 if val not in ava [lava [attr ]]:
154156 raise MissingValue (
155- "Required attribute value missing: %s,%s" % (attr ,
156- val ))
157+ "Required attribute value missing: %s,%s" % (attr ,
158+ val ))
157159 else :
158160 raise MissingValue ("Required attribute missing: %s" % (attr ,))
159161
@@ -266,6 +268,11 @@ def restriction_from_attribute_spec(attributes):
266268
267269def post_entity_categories (maps , ** kwargs ):
268270 restrictions = {}
271+ try :
272+ required = [d ['friendly_name' ].lower () for d in kwargs ['required' ]]
273+ except KeyError :
274+ required = []
275+
269276 if kwargs ["mds" ]:
270277 try :
271278 ecs = kwargs ["mds" ].entity_categories (kwargs ["sp_entity_id" ])
@@ -275,19 +282,25 @@ def post_entity_categories(maps, **kwargs):
275282 restrictions [attr ] = None
276283 else :
277284 for ec_map in maps :
278- for key , val in ec_map .items ():
285+ for key , ( atlist , only_required ) in ec_map .items ():
279286 if key == "" : # always released
280- attrs = val
287+ attrs = atlist
281288 elif isinstance (key , tuple ):
282- attrs = val
289+ if only_required :
290+ attrs = [a for a in atlist if a in required ]
291+ else :
292+ attrs = atlist
283293 for _key in key :
284294 try :
285295 assert _key in ecs
286296 except AssertionError :
287297 attrs = []
288298 break
289299 elif key in ecs :
290- attrs = val
300+ if only_required :
301+ attrs = [a for a in atlist if a in required ]
302+ else :
303+ attrs = atlist
291304 else :
292305 attrs = []
293306
@@ -332,10 +345,15 @@ def compile(self, restrictions):
332345 ecs = []
333346 for cat in items :
334347 _mod = importlib .import_module (
335- "saml2.entity_category.%s" % cat )
348+ "saml2.entity_category.%s" % cat )
336349 _ec = {}
337350 for key , items in _mod .RELEASE .items ():
338- _ec [key ] = [k .lower () for k in items ]
351+ alist = [k .lower () for k in items ]
352+ try :
353+ _only_required = _mod .ONLY_REQUIRED [key ]
354+ except (AttributeError , KeyError ):
355+ _only_required = False
356+ _ec [key ] = (alist , _only_required )
339357 ecs .append (_ec )
340358 spec ["entity_categories" ] = ecs
341359 try :
@@ -444,15 +462,15 @@ def entity_category_attributes(self, ec):
444462 pass
445463 return []
446464
447- def get_entity_categories (self , sp_entity_id , mds ):
465+ def get_entity_categories (self , sp_entity_id , mds , required ):
448466 """
449467
450468 :param sp_entity_id:
451469 :param mds: MetadataStore instance
452470 :return: A dictionary with restrictions
453471 """
454472
455- kwargs = {"mds" : mds }
473+ kwargs = {"mds" : mds , 'required' : required }
456474
457475 return self .get ("entity_categories" , sp_entity_id , default = {},
458476 post_func = post_entity_categories , ** kwargs )
@@ -483,19 +501,15 @@ def filter(self, ava, sp_entity_id, mdstore, required=None, optional=None):
483501 """
484502
485503 _ava = None
486- if required or optional :
487- logger .debug ("required: %s, optional: %s" , required , optional )
488- _ava = filter_on_attributes (
489- ava .copy (), required , optional , self .acs ,
490- self .get_fail_on_missing_requested (sp_entity_id ))
491504
492- _rest = self .get_entity_categories (sp_entity_id , mdstore )
505+ _rest = self .get_entity_categories (sp_entity_id , mdstore , required )
493506 if _rest :
494- ava_ec = filter_attribute_value_assertions (ava .copy (), _rest )
495- if _ava is None :
496- _ava = ava_ec
497- else :
498- _ava .update (ava_ec )
507+ _ava = filter_attribute_value_assertions (ava .copy (), _rest )
508+ elif required or optional :
509+ logger .debug ("required: %s, optional: %s" , required , optional )
510+ _ava = filter_on_attributes (
511+ ava .copy (), required , optional , self .acs ,
512+ self .get_fail_on_missing_requested (sp_entity_id ))
499513
500514 _rest = self .get_attribute_restrictions (sp_entity_id )
501515 if _rest :
@@ -537,9 +551,9 @@ def conditions(self, sp_entity_id):
537551 # How long might depend on who's getting it
538552 not_on_or_after = self .not_on_or_after (sp_entity_id ),
539553 audience_restriction = [factory (
540- saml .AudienceRestriction ,
541- audience = [factory (saml .Audience ,
542- text = sp_entity_id )])])
554+ saml .AudienceRestriction ,
555+ audience = [factory (saml .Audience ,
556+ text = sp_entity_id )])])
543557
544558 def get_sign (self , sp_entity_id ):
545559 """
@@ -569,7 +583,7 @@ def _authn_context_class_ref(authn_class, authn_auth=None):
569583 return factory (saml .AuthnContext ,
570584 authn_context_class_ref = cntx_class ,
571585 authenticating_authority = factory (
572- saml .AuthenticatingAuthority , text = authn_auth ))
586+ saml .AuthenticatingAuthority , text = authn_auth ))
573587 else :
574588 return factory (saml .AuthnContext ,
575589 authn_context_class_ref = cntx_class )
@@ -585,7 +599,7 @@ def _authn_context_decl(decl, authn_auth=None):
585599 return factory (saml .AuthnContext ,
586600 authn_context_decl = decl ,
587601 authenticating_authority = factory (
588- saml .AuthenticatingAuthority , text = authn_auth ))
602+ saml .AuthenticatingAuthority , text = authn_auth ))
589603
590604
591605def _authn_context_decl_ref (decl_ref , authn_auth = None ):
@@ -598,7 +612,7 @@ def _authn_context_decl_ref(decl_ref, authn_auth=None):
598612 return factory (saml .AuthnContext ,
599613 authn_context_decl_ref = decl_ref ,
600614 authenticating_authority = factory (
601- saml .AuthenticatingAuthority , text = authn_auth ))
615+ saml .AuthenticatingAuthority , text = authn_auth ))
602616
603617
604618def authn_statement (authn_class = None , authn_auth = None ,
@@ -624,29 +638,29 @@ def authn_statement(authn_class=None, authn_auth=None,
624638
625639 if authn_class :
626640 res = factory (
627- saml .AuthnStatement ,
628- authn_instant = _instant ,
629- session_index = sid (),
630- authn_context = _authn_context_class_ref (
631- authn_class , authn_auth ))
641+ saml .AuthnStatement ,
642+ authn_instant = _instant ,
643+ session_index = sid (),
644+ authn_context = _authn_context_class_ref (
645+ authn_class , authn_auth ))
632646 elif authn_decl :
633647 res = factory (
634- saml .AuthnStatement ,
635- authn_instant = _instant ,
636- session_index = sid (),
637- authn_context = _authn_context_decl (authn_decl , authn_auth ))
648+ saml .AuthnStatement ,
649+ authn_instant = _instant ,
650+ session_index = sid (),
651+ authn_context = _authn_context_decl (authn_decl , authn_auth ))
638652 elif authn_decl_ref :
639653 res = factory (
640- saml .AuthnStatement ,
641- authn_instant = _instant ,
642- session_index = sid (),
643- authn_context = _authn_context_decl_ref (authn_decl_ref ,
644- authn_auth ))
654+ saml .AuthnStatement ,
655+ authn_instant = _instant ,
656+ session_index = sid (),
657+ authn_context = _authn_context_decl_ref (authn_decl_ref ,
658+ authn_auth ))
645659 else :
646660 res = factory (
647- saml .AuthnStatement ,
648- authn_instant = _instant ,
649- session_index = sid ())
661+ saml .AuthnStatement ,
662+ authn_instant = _instant ,
663+ session_index = sid ())
650664
651665 if subject_locality :
652666 res .subject_locality = saml .SubjectLocality (text = subject_locality )
@@ -688,7 +702,8 @@ def do_subject(policy, sp_entity_id, name_id, **farg):
688702 specs = farg ['subject_confirmation' ]
689703
690704 if isinstance (specs , list ):
691- res = [do_subject_confirmation (policy , sp_entity_id , ** s ) for s in specs ]
705+ res = [do_subject_confirmation (policy , sp_entity_id , ** s ) for s in
706+ specs ]
692707 else :
693708 res = [do_subject_confirmation (policy , sp_entity_id , ** specs )]
694709
@@ -736,7 +751,7 @@ def construct(self, sp_entity_id, attrconvs, policy, issuer, farg,
736751 _name_format = NAME_FORMAT_URI
737752
738753 attr_statement = saml .AttributeStatement (attribute = from_local (
739- attrconvs , self , _name_format ))
754+ attrconvs , self , _name_format ))
740755
741756 if encrypt == "attributes" :
742757 for attr in attr_statement .attribute :
0 commit comments