Add AssertionConsumerServiceBinding

Closes gh-8776
This commit is contained in:
Josh Cummings 2020-07-16 16:22:38 -06:00
parent 2c960d2ad1
commit 44ec061f05
No known key found for this signature in database
GPG Key ID: 49EF60DD7FF83443
4 changed files with 97 additions and 10 deletions

View File

@ -29,6 +29,7 @@ import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Issuer;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.credentials.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest.Builder;
import org.springframework.util.Assert;
@ -43,7 +44,14 @@ import static org.springframework.security.saml2.provider.service.authentication
public class OpenSamlAuthenticationRequestFactory implements Saml2AuthenticationRequestFactory {
private Clock clock = Clock.systemUTC();
private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
private String protocolBinding = SAMLConstants.SAML2_POST_BINDING_URI;
private Converter<Saml2AuthenticationRequestContext, String> protocolBindingResolver =
context -> {
if (context == null) {
return SAMLConstants.SAML2_POST_BINDING_URI;
}
return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn();
};
private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver
= context -> authnRequest -> {};
@ -52,7 +60,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
@Deprecated
public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
AuthnRequest authnRequest = createAuthnRequest(request.getIssuer(),
request.getDestination(), request.getAssertionConsumerServiceUrl());
request.getDestination(), request.getAssertionConsumerServiceUrl(),
this.protocolBindingResolver.convert(null));
return this.saml.serialize(authnRequest, request.getCredentials());
}
@ -101,12 +110,14 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(),
context.getDestination(), context.getAssertionConsumerServiceUrl());
context.getDestination(), context.getAssertionConsumerServiceUrl(),
this.protocolBindingResolver.convert(context));
this.authnRequestConsumerResolver.apply(context).accept(authnRequest);
return authnRequest;
}
private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl) {
private AuthnRequest createAuthnRequest
(String issuer, String destination, String assertionConsumerServiceUrl, String protocolBinding) {
AuthnRequest auth = this.saml.buildSamlObject(AuthnRequest.DEFAULT_ELEMENT_NAME);
auth.setID("ARQ" + UUID.randomUUID().toString().substring(1));
auth.setIssueInstant(new DateTime(this.clock.millis()));
@ -155,13 +166,16 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
* @param protocolBinding either {@link SAMLConstants#SAML2_POST_BINDING_URI} or
* {@link SAMLConstants#SAML2_REDIRECT_BINDING_URI}
* @throws IllegalArgumentException if the protocolBinding is not valid
* @deprecated Use {@link org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.Builder#assertionConsumerServiceBinding}
* instead
*/
@Deprecated
public void setProtocolBinding(String protocolBinding) {
boolean isAllowedBinding = SAMLConstants.SAML2_POST_BINDING_URI.equals(protocolBinding) ||
SAMLConstants.SAML2_REDIRECT_BINDING_URI.equals(protocolBinding);
if (!isAllowedBinding) {
throw new IllegalArgumentException("Invalid protocol binding: " + protocolBinding);
}
this.protocolBinding = protocolBinding;
this.protocolBindingResolver = context -> protocolBinding;
}
}

View File

@ -68,6 +68,7 @@ public class RelyingPartyRegistration {
private final String registrationId;
private final String entityId;
private final String assertionConsumerServiceLocation;
private final Saml2MessageBinding assertionConsumerServiceBinding;
private final ProviderDetails providerDetails;
private final List<Saml2X509Credential> credentials;
@ -75,12 +76,14 @@ public class RelyingPartyRegistration {
String registrationId,
String entityId,
String assertionConsumerServiceLocation,
Saml2MessageBinding assertionConsumerServiceBinding,
ProviderDetails providerDetails,
List<Saml2X509Credential> credentials) {
Assert.hasText(registrationId, "registrationId cannot be empty");
Assert.hasText(entityId, "entityId cannot be empty");
Assert.hasText(assertionConsumerServiceLocation, "assertionConsumerServiceLocation cannot be empty");
Assert.notNull(assertionConsumerServiceBinding, "assertionConsumerServiceBinding cannot be null");
Assert.notNull(providerDetails, "providerDetails cannot be null");
Assert.notEmpty(credentials, "credentials cannot be empty");
for (Saml2X509Credential c : credentials) {
@ -89,6 +92,7 @@ public class RelyingPartyRegistration {
this.registrationId = registrationId;
this.entityId = entityId;
this.assertionConsumerServiceLocation = assertionConsumerServiceLocation;
this.assertionConsumerServiceBinding = assertionConsumerServiceBinding;
this.providerDetails = providerDetails;
this.credentials = Collections.unmodifiableList(new LinkedList<>(credentials));
}
@ -138,6 +142,18 @@ public class RelyingPartyRegistration {
return this.assertionConsumerServiceLocation;
}
/**
* Get the AssertionConsumerService Binding.
* Equivalent to the value found in &lt;AssertionConsumerService Binding="..."/&gt;
* in the relying party's &lt;SPSSODescriptor&gt;.
*
* @return the AssertionConsumerService Binding
* @since 5.4
*/
public Saml2MessageBinding getAssertionConsumerServiceBinding() {
return this.assertionConsumerServiceBinding;
}
/**
* Get the configuration details for the Asserting Party
*
@ -280,6 +296,7 @@ public class RelyingPartyRegistration {
return withRegistrationId(registration.getRegistrationId())
.entityId(registration.getEntityId())
.assertionConsumerServiceLocation(registration.getAssertionConsumerServiceLocation())
.assertionConsumerServiceBinding(registration.getAssertionConsumerServiceBinding())
.assertingPartyDetails(c -> c
.entityId(registration.getAssertingPartyDetails().getEntityId())
.wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned())
@ -575,6 +592,7 @@ public class RelyingPartyRegistration {
private String registrationId;
private String entityId = "{baseUrl}/saml2/service-provider-metadata/{registrationId}";
private String assertionConsumerServiceLocation;
private Saml2MessageBinding assertionConsumerServiceBinding = Saml2MessageBinding.POST;
private ProviderDetails.Builder providerDetails = new ProviderDetails.Builder();
private List<Saml2X509Credential> credentials = new LinkedList<>();
@ -633,6 +651,23 @@ public class RelyingPartyRegistration {
return this;
}
/**
* Set the <a href="https://wiki.shibboleth.net/confluence/display/CONCEPT/AssertionConsumerService">AssertionConsumerService</a>
* Binding.
*
* <p>
* Equivalent to the value found in &lt;AssertionConsumerService Binding="..."/&gt;
* in the relying party's &lt;SPSSODescriptor&gt;
*
* @param assertionConsumerServiceBinding
* @return the {@link Builder} for further configuration
* @since 5.4
*/
public Builder assertionConsumerServiceBinding(Saml2MessageBinding assertionConsumerServiceBinding) {
this.assertionConsumerServiceBinding = assertionConsumerServiceBinding;
return this;
}
/**
* Apply this {@link Consumer} to further configure the Asserting Party details
*
@ -738,6 +773,7 @@ public class RelyingPartyRegistration {
this.registrationId,
this.entityId,
this.assertionConsumerServiceLocation,
this.assertionConsumerServiceBinding,
this.providerDetails.build(),
this.credentials
);

View File

@ -39,6 +39,7 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
@ -52,19 +53,21 @@ public class OpenSamlAuthenticationRequestFactoryTests {
private Saml2AuthenticationRequestContext.Builder contextBuilder;
private Saml2AuthenticationRequestContext context;
private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder;
private RelyingPartyRegistration relyingPartyRegistration;
@Rule
public ExpectedException exception = ExpectedException.none();
private RelyingPartyRegistration relyingPartyRegistration;
@Before
public void setUp() {
relyingPartyRegistration = RelyingPartyRegistration.withRegistrationId("id")
this.relyingPartyRegistrationBuilder = RelyingPartyRegistration.withRegistrationId("id")
.assertionConsumerServiceLocation("template")
.providerDetails(c -> c.webSsoUrl("https://destination/sso"))
.providerDetails(c -> c.entityId("remote-entity-id"))
.localEntityIdTemplate("local-entity-id")
.credentials(c -> c.add(relyingPartySigningCredential()))
.build();
.credentials(c -> c.add(relyingPartySigningCredential()));
this.relyingPartyRegistration = this.relyingPartyRegistrationBuilder.build();
contextBuilder = Saml2AuthenticationRequestContext.builder()
.issuer("https://issuer")
.relyingPartyRegistration(relyingPartyRegistration)
@ -195,6 +198,20 @@ public class OpenSamlAuthenticationRequestFactoryTests {
.isInstanceOf(IllegalArgumentException.class);
}
@Test
public void createPostAuthenticationRequestWhenAssertionConsumerServiceBindingThenUses() {
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationBuilder
.assertionConsumerServiceBinding(REDIRECT)
.build();
Saml2AuthenticationRequestContext context = this.contextBuilder
.relyingPartyRegistration(relyingPartyRegistration)
.build();
Saml2PostAuthenticationRequest request = this.factory.createPostAuthenticationRequest(context);
String samlRequest = request.getSamlRequest();
String inflated = new String(samlDecode(samlRequest));
assertThat(inflated).contains("ProtocolBinding=\"" + SAMLConstants.SAML2_REDIRECT_BINDING_URI + "\"");
}
private AuthnRequest getAuthNRequest(Saml2MessageBinding binding) {
AbstractSaml2AuthenticationRequest result = (binding == REDIRECT) ?
factory.createRedirectAuthenticationRequest(context) :
@ -202,7 +219,7 @@ public class OpenSamlAuthenticationRequestFactoryTests {
String samlRequest = result.getSamlRequest();
assertThat(samlRequest).isNotEmpty();
if (result.getBinding() == REDIRECT) {
samlRequest = Saml2Utils.samlInflate(samlDecode(samlRequest));
samlRequest = samlInflate(samlDecode(samlRequest));
}
else {
samlRequest = new String(samlDecode(samlRequest), UTF_8);

View File

@ -21,6 +21,8 @@ import org.junit.Test;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRegistrationId;
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
@ -31,6 +33,7 @@ public class RelyingPartyRegistrationTests {
RelyingPartyRegistration registration = relyingPartyRegistration()
.providerDetails(p -> p.binding(POST))
.providerDetails(p -> p.signAuthNRequest(false))
.assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT)
.build();
RelyingPartyRegistration copy = RelyingPartyRegistration.withRelyingPartyRegistration(registration).build();
compareRegistrations(registration, copy);
@ -76,5 +79,22 @@ public class RelyingPartyRegistrationTests {
.isEqualTo(copy.getAssertingPartyDetails().getWantAuthnRequestsSigned())
.isEqualTo(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned())
.isFalse();
assertThat(copy.getAssertionConsumerServiceBinding())
.isEqualTo(registration.getAssertionConsumerServiceBinding());
}
@Test
public void buildWhenUsingDefaultsThenAssertionConsumerServiceBindingDefaultsToPost() {
RelyingPartyRegistration relyingPartyRegistration = withRegistrationId("id")
.entityId("entity-id")
.assertionConsumerServiceLocation("location")
.assertingPartyDetails(assertingParty -> assertingParty
.entityId("entity-id")
.singleSignOnServiceLocation("location"))
.credentials(c -> c.add(relyingPartyVerifyingCredential()))
.build();
assertThat(relyingPartyRegistration.getAssertionConsumerServiceBinding())
.isEqualTo(POST);
}
}