Polish OpenSamlAuthenticationRequestFactory

- Refactored to use SAMLMetadataSignatureSigningParametersResolver

Issue gh-7758
This commit is contained in:
Josh Cummings 2020-09-25 16:27:01 -06:00
parent 2ee455b7bf
commit a36baffb3a
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
2 changed files with 117 additions and 51 deletions

View File

@ -21,11 +21,14 @@ import java.security.PrivateKey;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
import java.time.Clock; import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.util.Collection; import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.UUID; import java.util.UUID;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import net.shibboleth.utilities.java.support.xml.SerializeSupport; import net.shibboleth.utilities.java.support.xml.SerializeSupport;
import org.joda.time.DateTime; import org.joda.time.DateTime;
import org.opensaml.core.config.ConfigurationService; import org.opensaml.core.config.ConfigurationService;
@ -37,15 +40,18 @@ import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder; import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder;
import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller; import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller;
import org.opensaml.saml.saml2.core.impl.IssuerBuilder; import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver;
import org.opensaml.security.SecurityException; import org.opensaml.security.SecurityException;
import org.opensaml.security.credential.BasicCredential; import org.opensaml.security.credential.BasicCredential;
import org.opensaml.security.credential.Credential; import org.opensaml.security.credential.Credential;
import org.opensaml.security.credential.CredentialSupport; import org.opensaml.security.credential.CredentialSupport;
import org.opensaml.security.credential.UsageType; import org.opensaml.security.credential.UsageType;
import org.opensaml.xmlsec.SignatureSigningParameters; import org.opensaml.xmlsec.SignatureSigningParameters;
import org.opensaml.xmlsec.SignatureSigningParametersResolver;
import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion;
import org.opensaml.xmlsec.crypto.XMLSigningUtil; import org.opensaml.xmlsec.crypto.XMLSigningUtil;
import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration;
import org.opensaml.xmlsec.signature.support.SignatureConstants; import org.opensaml.xmlsec.signature.support.SignatureConstants;
import org.opensaml.xmlsec.signature.support.SignatureException;
import org.opensaml.xmlsec.signature.support.SignatureSupport; import org.opensaml.xmlsec.signature.support.SignatureSupport;
import org.w3c.dom.Element; import org.w3c.dom.Element;
@ -58,6 +64,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UriUtils; import org.springframework.web.util.UriUtils;
/** /**
@ -105,9 +112,17 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
request.getAssertionConsumerServiceUrl(), this.protocolBindingResolver.convert(null)); request.getAssertionConsumerServiceUrl(), this.protocolBindingResolver.convert(null));
for (org.springframework.security.saml2.credentials.Saml2X509Credential credential : request.getCredentials()) { for (org.springframework.security.saml2.credentials.Saml2X509Credential credential : request.getCredentials()) {
if (credential.isSigningCredential()) { if (credential.isSigningCredential()) {
Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), X509Certificate certificate = credential.getCertificate();
request.getIssuer()); PrivateKey privateKey = credential.getPrivateKey();
return serialize(sign(authnRequest, cred)); BasicCredential cred = CredentialSupport.getSimpleCredential(certificate, privateKey);
cred.setEntityId(request.getIssuer());
cred.setUsageType(UsageType.SIGNING);
SignatureSigningParameters parameters = new SignatureSigningParameters();
parameters.setSigningCredential(cred);
parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256);
parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);
return serialize(sign(authnRequest, parameters));
} }
} }
throw new IllegalArgumentException("No signing credential provided"); throw new IllegalArgumentException("No signing credential provided");
@ -132,16 +147,13 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml)); String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
result.samlRequest(deflatedAndEncoded).relayState(context.getRelayState()); result.samlRequest(deflatedAndEncoded).relayState(context.getRelayState());
if (context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()) { if (context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
Collection<Saml2X509Credential> signingCredentials = context.getRelyingPartyRegistration() Map<String, String> parameters = new LinkedHashMap<>();
.getSigningX509Credentials(); parameters.put("SAMLRequest", deflatedAndEncoded);
for (Saml2X509Credential credential : signingCredentials) { if (StringUtils.hasText(context.getRelayState())) {
Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), ""); parameters.put("RelayState", context.getRelayState());
Map<String, String> signedParams = signQueryParameters(cred, deflatedAndEncoded,
context.getRelayState());
return result.samlRequest(signedParams.get("SAMLRequest")).relayState(signedParams.get("RelayState"))
.sigAlg(signedParams.get("SigAlg")).signature(signedParams.get("Signature")).build();
} }
throw new Saml2Exception("No signing credential provided"); sign(parameters, context.getRelyingPartyRegistration());
return result.sigAlg(parameters.get("SigAlg")).signature(parameters.get("Signature")).build();
} }
return result.build(); return result.build();
} }
@ -211,59 +223,39 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
} }
private AuthnRequest sign(AuthnRequest authnRequest, RelyingPartyRegistration relyingPartyRegistration) { private AuthnRequest sign(AuthnRequest authnRequest, RelyingPartyRegistration relyingPartyRegistration) {
for (Saml2X509Credential credential : relyingPartyRegistration.getSigningX509Credentials()) { SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration);
Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), return sign(authnRequest, parameters);
relyingPartyRegistration.getEntityId());
return sign(authnRequest, cred);
}
throw new IllegalArgumentException("No signing credential provided");
} }
private AuthnRequest sign(AuthnRequest authnRequest, Credential credential) { private AuthnRequest sign(AuthnRequest authnRequest, SignatureSigningParameters parameters) {
SignatureSigningParameters parameters = new SignatureSigningParameters();
parameters.setSigningCredential(credential);
parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256);
parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);
try { try {
SignatureSupport.signObject(authnRequest, parameters); SignatureSupport.signObject(authnRequest, parameters);
return authnRequest; return authnRequest;
} }
catch (MarshallingException | SignatureException | SecurityException ex) { catch (Exception ex) {
throw new Saml2Exception(ex); throw new Saml2Exception(ex);
} }
} }
private Credential getSigningCredential(X509Certificate certificate, PrivateKey privateKey, String entityId) { private void sign(Map<String, String> components, RelyingPartyRegistration relyingPartyRegistration) {
BasicCredential cred = CredentialSupport.getSimpleCredential(certificate, privateKey); SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration);
cred.setEntityId(entityId); sign(components, parameters);
cred.setUsageType(UsageType.SIGNING);
return cred;
} }
private Map<String, String> signQueryParameters(Credential credential, String samlRequest, String relayState) { private void sign(Map<String, String> components, SignatureSigningParameters parameters) {
Assert.notNull(samlRequest, "samlRequest cannot be null"); Credential credential = parameters.getSigningCredential();
String algorithmUri = SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256; String algorithmUri = parameters.getSignatureAlgorithm();
StringBuilder queryString = new StringBuilder(); components.put("SigAlg", algorithmUri);
queryString.append("SAMLRequest").append("=").append(UriUtils.encode(samlRequest, StandardCharsets.ISO_8859_1)) UriComponentsBuilder builder = UriComponentsBuilder.newInstance();
.append("&"); for (Map.Entry<String, String> component : components.entrySet()) {
if (StringUtils.hasText(relayState)) { builder.queryParam(component.getKey(), UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1));
queryString.append("RelayState").append("=")
.append(UriUtils.encode(relayState, StandardCharsets.ISO_8859_1)).append("&");
} }
queryString.append("SigAlg").append("=").append(UriUtils.encode(algorithmUri, StandardCharsets.ISO_8859_1)); String queryString = builder.build(true).toString().substring(1);
try { try {
byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri, byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri,
queryString.toString().getBytes(StandardCharsets.UTF_8)); queryString.getBytes(StandardCharsets.UTF_8));
String b64Signature = Saml2Utils.samlEncode(rawSignature); String b64Signature = Saml2Utils.samlEncode(rawSignature);
Map<String, String> result = new LinkedHashMap<>(); components.put("Signature", b64Signature);
result.put("SAMLRequest", samlRequest);
if (StringUtils.hasText(relayState)) {
result.put("RelayState", relayState);
}
result.put("SigAlg", algorithmUri);
result.put("Signature", b64Signature);
return result;
} }
catch (SecurityException ex) { catch (SecurityException ex) {
throw new Saml2Exception(ex); throw new Saml2Exception(ex);
@ -280,4 +272,40 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
} }
} }
private SignatureSigningParameters resolveSigningParameters(RelyingPartyRegistration relyingPartyRegistration) {
List<Credential> credentials = resolveSigningCredentials(relyingPartyRegistration);
List<String> algorithms = Collections.singletonList(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
List<String> digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256);
String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS;
SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver();
CriteriaSet criteria = new CriteriaSet();
BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration();
signingConfiguration.setSigningCredentials(credentials);
signingConfiguration.setSignatureAlgorithms(algorithms);
signingConfiguration.setSignatureReferenceDigestMethods(digests);
signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization);
criteria.add(new SignatureSigningConfigurationCriterion(signingConfiguration));
try {
SignatureSigningParameters parameters = resolver.resolveSingle(criteria);
Assert.notNull(parameters, "Failed to resolve any signing credential");
return parameters;
}
catch (Exception ex) {
throw new Saml2Exception(ex);
}
}
private List<Credential> resolveSigningCredentials(RelyingPartyRegistration relyingPartyRegistration) {
List<Credential> credentials = new ArrayList<>();
for (Saml2X509Credential x509Credential : relyingPartyRegistration.getSigningX509Credentials()) {
X509Certificate certificate = x509Credential.getCertificate();
PrivateKey privateKey = x509Credential.getPrivateKey();
BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey);
credential.setEntityId(relyingPartyRegistration.getEntityId());
credential.setUsageType(UsageType.SIGNING);
credentials.add(credential);
}
return credentials;
}
} }

View File

@ -26,16 +26,20 @@ import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.AuthnRequest; import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller; import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
import org.opensaml.xmlsec.signature.support.SignatureConstants;
import org.w3c.dom.Document; import org.w3c.dom.Document;
import org.w3c.dom.Element; import org.w3c.dom.Element;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.Saml2Exception;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials; import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -110,6 +114,28 @@ public class OpenSamlAuthenticationRequestFactoryTests {
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT); assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
} }
@Test
public void createRedirectAuthenticationRequestWhenSignRequestThenSignatureIsPresent() {
this.context = this.contextBuilder.relayState("Relay State Value")
.relyingPartyRegistration(this.relyingPartyRegistration).build();
Saml2RedirectAuthenticationRequest request = this.factory.createRedirectAuthenticationRequest(this.context);
assertThat(request.getRelayState()).isEqualTo("Relay State Value");
assertThat(request.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
assertThat(request.getSignature()).isNotNull();
}
@Test
public void createRedirectAuthenticationRequestWhenSignRequestThenCredentialIsRequired() {
Saml2X509Credential credential = org.springframework.security.saml2.core.TestSaml2X509Credentials
.relyingPartyVerifyingCredential();
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials()
.assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build();
this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration)
.build();
assertThatExceptionOfType(Saml2Exception.class)
.isThrownBy(() -> this.factory.createPostAuthenticationRequest(this.context));
}
@Test @Test
public void createPostAuthenticationRequestWhenNotSignRequestThenNoSignatureIsPresent() { public void createPostAuthenticationRequestWhenNotSignRequestThenNoSignatureIsPresent() {
this.context = this.contextBuilder.relayState("Relay State Value") this.context = this.contextBuilder.relayState("Relay State Value")
@ -139,6 +165,18 @@ public class OpenSamlAuthenticationRequestFactoryTests {
.contains("ds:Signature"); .contains("ds:Signature");
} }
@Test
public void createPostAuthenticationRequestWhenSignRequestThenCredentialIsRequired() {
Saml2X509Credential credential = org.springframework.security.saml2.core.TestSaml2X509Credentials
.relyingPartyVerifyingCredential();
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials()
.assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build();
this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration)
.build();
assertThatExceptionOfType(Saml2Exception.class)
.isThrownBy(() -> this.factory.createPostAuthenticationRequest(this.context));
}
@Test @Test
public void createAuthenticationRequestWhenDefaultThenReturnsPostBinding() { public void createAuthenticationRequestWhenDefaultThenReturnsPostBinding() {
AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST); AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST);