From fcee9c12b9d9e81b3c257af9c37fb549b2774821 Mon Sep 17 00:00:00 2001
From: war <thomas@warwaris.at>
Date: Mon, 8 Jan 2024 04:19:33 +0100
Subject: [PATCH] allow signature and digest algorithm as parameters for
 MD-generation. Also use sha256 as default to prevent breaking on systems with
 disabled sha1

---
 src/satosa/metadata_creation/saml_metadata.py |  9 +++-
 src/satosa/scripts/satosa_saml_metadata.py    | 46 ++++++++++++++-----
 2 files changed, 42 insertions(+), 13 deletions(-)

diff --git a/src/satosa/metadata_creation/saml_metadata.py b/src/satosa/metadata_creation/saml_metadata.py
index f88bbaaec..dfc2a4c7f 100644
--- a/src/satosa/metadata_creation/saml_metadata.py
+++ b/src/satosa/metadata_creation/saml_metadata.py
@@ -134,11 +134,14 @@ def create_signed_entities_descriptor(entity_descriptors, security_context, vali
     return xmldoc
 
 
-def create_signed_entity_descriptor(entity_descriptor, security_context, valid_for=None):
+def create_signed_entity_descriptor(entity_descriptor, security_context, valid_for=None, sign_alg=None,
+                                    digest_alg=None):
     """
     :param entity_descriptor: the entity descriptor to sign
     :param security_context: security context for the signature
     :param valid_for: number of hours the metadata should be valid
+    :param sign_alg: signature algorithm from saml2.xmldsig
+    :param digest_alg: digest algorithm from saml2.xmldsig
     :return: the signed XML document
 
     :type entity_descriptor: saml2.md.EntityDescriptor]
@@ -148,7 +151,9 @@ def create_signed_entity_descriptor(entity_descriptor, security_context, valid_f
     if valid_for:
         entity_descriptor.valid_until = in_a_while(hours=valid_for)
 
-    entity_desc, xmldoc = sign_entity_descriptor(entity_descriptor, None, security_context)
+    entity_desc, xmldoc = sign_entity_descriptor(entity_descriptor, None, security_context,
+                                                 sign_alg=sign_alg,
+                                                 digest_alg=digest_alg)
 
     if not valid_instance(entity_desc):
         raise ValueError("Could not construct valid EntityDescriptor tag")
diff --git a/src/satosa/scripts/satosa_saml_metadata.py b/src/satosa/scripts/satosa_saml_metadata.py
index c0638d8b7..193e8e653 100644
--- a/src/satosa/scripts/satosa_saml_metadata.py
+++ b/src/satosa/scripts/satosa_saml_metadata.py
@@ -3,6 +3,7 @@
 import click
 from saml2.config import Config
 from saml2.sigver import security_context
+from saml2 import xmldsig
 
 from ..metadata_creation.saml_metadata import create_entity_descriptors
 from ..metadata_creation.saml_metadata import create_entity_descriptor_metadata
@@ -17,12 +18,22 @@ def _get_security_context(key, cert):
     return security_context(conf)
 
 
-def _create_split_entity_descriptors(entities, secc, valid, sign=True):
+def _get_sign_and_digest_alg(signature_algorithm, digest_algorithm):
+    sign_alg = digest_alg = None
+    if signature_algorithm:
+        sign_alg = getattr(xmldsig, signature_algorithm)
+    if digest_algorithm:
+        digest_alg = getattr(xmldsig, digest_algorithm)
+    return sign_alg, digest_alg
+
+def _create_split_entity_descriptors(entities, secc, valid, sign=True, signature_algorithm=None,
+                                     digest_algorithm=None):
     output = []
+    sign_alg, digest_alg = _get_sign_and_digest_alg(signature_algorithm, digest_algorithm)
     for module_name, eds in entities.items():
         for i, ed in enumerate(eds):
             ed_str = (
-                create_signed_entity_descriptor(ed, secc, valid)
+                create_signed_entity_descriptor(ed, secc, valid, sign_alg=sign_alg, digest_alg=digest_alg)
                 if sign
                 else create_entity_descriptor_metadata(ed, valid)
             )
@@ -31,12 +42,14 @@ def _create_split_entity_descriptors(entities, secc, valid, sign=True):
     return output
 
 
-def _create_merged_entities_descriptors(entities, secc, valid, name, sign=True):
+def _create_merged_entities_descriptors(entities, secc, valid, name, sign=True, signature_algorithm=None,
+                                     digest_algorithm=None):
     output = []
+    sign_alg, digest_alg = _get_sign_and_digest_alg(signature_algorithm, digest_algorithm)
     frontend_entity_descriptors = [e for sublist in entities.values() for e in sublist]
     for frontend in frontend_entity_descriptors:
         ed_str = (
-            create_signed_entity_descriptor(frontend, secc, valid)
+            create_signed_entity_descriptor(frontend, secc, valid, sign_alg=sign_alg, digest_alg=digest_alg)
             if sign
             else create_entity_descriptor_metadata(frontend, valid)
         )
@@ -46,7 +59,8 @@ def _create_merged_entities_descriptors(entities, secc, valid, name, sign=True):
 
 
 def create_and_write_saml_metadata(proxy_conf, key, cert, dir, valid, split_frontend_metadata=False,
-                                   split_backend_metadata=False, sign=True):
+                                   split_backend_metadata=False, sign=True, signature_algorithm=None,
+                                   digest_algorithm=None):
     """
     Generates SAML metadata for the given PROXY_CONF, signed with the given KEY and associated CERT.
     """
@@ -61,14 +75,18 @@ def create_and_write_saml_metadata(proxy_conf, key, cert, dir, valid, split_fron
     output = []
     if frontend_entities:
         if split_frontend_metadata:
-            output.extend(_create_split_entity_descriptors(frontend_entities, secc, valid, sign))
+            output.extend(_create_split_entity_descriptors(frontend_entities, secc, valid, sign,
+                                                           signature_algorithm, digest_algorithm))
         else:
-            output.extend(_create_merged_entities_descriptors(frontend_entities, secc, valid, "frontend.xml", sign))
+            output.extend(_create_merged_entities_descriptors(frontend_entities, secc, valid, "frontend.xml",
+                                                              sign, signature_algorithm, digest_algorithm))
     if backend_entities:
         if split_backend_metadata:
-            output.extend(_create_split_entity_descriptors(backend_entities, secc, valid, sign))
+            output.extend(_create_split_entity_descriptors(backend_entities, secc, valid, sign, signature_algorithm,
+                                                           digest_algorithm))
         else:
-            output.extend(_create_merged_entities_descriptors(backend_entities, secc, valid, "backend.xml", sign))
+            output.extend(_create_merged_entities_descriptors(backend_entities, secc, valid, "backend.xml",
+                                                              sign, signature_algorithm, digest_algorithm))
 
     for metadata, filename in output:
         path = os.path.join(dir, filename)
@@ -92,5 +110,11 @@ def create_and_write_saml_metadata(proxy_conf, key, cert, dir, valid, split_fron
               help="Create one entity descriptor per file for the backend metadata")
 @click.option("--sign/--no-sign", is_flag=True, type=click.BOOL, default=True,
               help="Sign the generated metadata")
-def construct_saml_metadata(proxy_conf, key, cert, dir, valid, split_frontend, split_backend, sign):
-    create_and_write_saml_metadata(proxy_conf, key, cert, dir, valid, split_frontend, split_backend, sign)
+@click.option("--signature-algorithm", type=click.STRING, default="SIG_RSA_SHA256",
+              help="Algorithm to sign metadata, from xmldsig")
+@click.option("--digest-algorithm", type=click.STRING, default="DIGEST_SHA256",
+              help="Algorithm for the metadata digest, from xmldsig")
+def construct_saml_metadata(proxy_conf, key, cert, dir, valid, split_frontend, split_backend, sign,
+                            signature_algorithm, digest_algorithm):
+    create_and_write_saml_metadata(proxy_conf, key, cert, dir, valid, split_frontend, split_backend,
+                                   sign, signature_algorithm, digest_algorithm)