Skip to content

Commit 651284d

Browse files
committed
Extract warning and error validators to eIDASConfig
- Extracts common (for SP/IdP) warning and error validators - Moves eIDASSPConfig validate method to eIDASConfig class
1 parent ff8eba7 commit 651284d

File tree

1 file changed

+48
-33
lines changed

1 file changed

+48
-33
lines changed

src/saml2/config.py

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -621,22 +621,9 @@ def get_type_contact_person(contacts, ctype):
621621
def contact_has_email_address(contact):
622622
return not_empty(contact.get("email_address"))
623623

624-
625-
class eIDASSPConfig(SPConfig, eIDASConfig):
626-
def get_endpoint_element(self, element):
627-
return getattr(self, "_sp_endpoints", {}).get(element, None)
628-
629-
def get_application_identifier(self):
630-
return getattr(self, "_sp_application_identifier", None)
631-
632-
def get_protocol_version(self):
633-
return getattr(self, "_sp_protocol_version", None)
634-
635-
def get_node_country(self):
636-
return getattr(self, "_sp_node_country", None)
637-
638-
def validate(self):
639-
warning_validators = {
624+
@property
625+
def warning_validators(self):
626+
return {
640627
"single_logout_service SHOULD NOT be declared":
641628
self.get_endpoint_element("single_logout_service") is None,
642629
"artifact_resolution_service SHOULD NOT be declared":
@@ -661,15 +648,9 @@ def validate(self):
661648
ctype="support")))
662649
}
663650

664-
if not all(warning_validators.values()):
665-
logger.warning(
666-
"Configuration validation warnings occurred: {}".format(
667-
[msg for msg, check in warning_validators.items()
668-
if check is not True]
669-
)
670-
)
671-
672-
error_validators = {
651+
@property
652+
def error_validators(self):
653+
return {
673654
"KeyDescriptor MUST be declared":
674655
self.cert_file or self.encryption_keypairs,
675656
"node_country MUST be declared in ISO 3166-1 alpha-2 format":
@@ -680,21 +661,55 @@ def validate(self):
680661
self.get_application_identifier()),
681662
"entityid MUST be an HTTPS URL pointing to the location of its published "
682663
"metadata":
683-
parse.urlparse(self.entityid).scheme == "https",
684-
"authn_requests_signed MUST be set to True":
685-
getattr(self, "_sp_authn_requests_signed", None) is True,
686-
"sp_type MUST be set to 'public' or 'private'":
687-
getattr(self, "_sp_sp_type", None) in ("public", "private")
664+
parse.urlparse(self.entityid).scheme == "https"
688665
}
689666

690-
if not all(error_validators.values()):
667+
def validate(self):
668+
if not all(self.warning_validators.values()):
669+
logger.warning(
670+
"Configuration validation warnings occurred: {}".format(
671+
[msg for msg, check in self.warning_validators.items()
672+
if check is not True]
673+
)
674+
)
675+
676+
if not all(self.error_validators.values()):
691677
error = "Configuration validation errors occurred:".format(
692-
[msg for msg, check in error_validators.items()
693-
if check is not True])
678+
[msg for msg, check in self.error_validators.items()
679+
if check is not True])
694680
logger.error(error)
695681
raise ConfigValidationError(error)
696682

697683

684+
class eIDASSPConfig(SPConfig, eIDASConfig):
685+
def get_endpoint_element(self, element):
686+
return getattr(self, "_sp_endpoints", {}).get(element, None)
687+
688+
def get_application_identifier(self):
689+
return getattr(self, "_sp_application_identifier", None)
690+
691+
def get_protocol_version(self):
692+
return getattr(self, "_sp_protocol_version", None)
693+
694+
def get_node_country(self):
695+
return getattr(self, "_sp_node_country", None)
696+
697+
@property
698+
def warning_validators(self):
699+
sp_warning_validators = {}
700+
return {**super().warning_validators, **sp_warning_validators}
701+
702+
@property
703+
def error_validators(self):
704+
sp_error_validators = {
705+
"authn_requests_signed MUST be set to True":
706+
getattr(self, "_sp_authn_requests_signed", None) is True,
707+
"sp_type MUST be set to 'public' or 'private'":
708+
getattr(self, "_sp_sp_type", None) in ("public", "private")
709+
}
710+
return {**super().error_validators, **sp_error_validators}
711+
712+
698713
class IdPConfig(Config):
699714
def_context = "idp"
700715

0 commit comments

Comments
 (0)